import hashlib
import os

import networkx as nx
import pandas as pd
import numpy as np
from numpy.linalg import LinAlgError
import math


class GraphFeatureCalculator:
    """Calculate per-node features of a directed graph.

    Every feature function takes a graph and returns a ``{node: value}``
    mapping. Graph-level properties (number of SCCs, strong connectivity,
    irreflexivity, average degree, aperiodicity) are broadcast to every node
    so that all features can be combined into one node-indexed DataFrame.
    """

    def __init__(self, features_to_calculate=None):
        """
        Initialize the feature calculator with specified features.

        Args:
            features_to_calculate (list): List of feature names to calculate.
                                        If None, calculates all available features.
                                        A single name given as a string is
                                        treated as a one-element list.
        """
        self.all_features = {
            'in_degree': self._calculate_in_degree,
            'out_degree': self._calculate_out_degree,
            'in_katz_centrality': self._calculate_in_katz_centrality,
            'out_katz_centrality': self._calculate_out_katz_centrality,
            'in_closeness_centrality': self._calculate_in_closeness_centrality,
            'out_closeness_centrality': self._calculate_out_closeness_centrality,
            'betweenness_centrality': self._calculate_betweenness_centrality,
            'no_of_sccs': self._calculate_number_of_sccs,
            'scc_size': self._calculate_scc_size,
            'strong_connectivity': self._calculate_strong_connectivity,
            'irreflexive': self._calculate_irreflexivity,
            'avg_degree': self._calculate_avg_degree,
            'aperiodicity': self._calculate_aperiodicity
        }

        if features_to_calculate is None:
            features_to_calculate = self.all_features
        elif isinstance(features_to_calculate, str):
            # A bare string would otherwise be iterated character by character.
            features_to_calculate = [features_to_calculate]

        # Copy so later mutation of the caller's list cannot change the selection.
        self.features_to_calculate = list(features_to_calculate)

    def _create_feature_hash(self, feature_names):
        """
        Create a deterministic hash from a list of feature names.

        Args:
            feature_names (str): Concatenated string of feature names

        Returns:
            str: First 6 characters of the MD5 hash
        """
        return hashlib.md5(feature_names.encode('utf-8')).hexdigest()[:6]

    def get_model_names(self, base_prefix='rf_kwt', model_dir=None):
        """
        Generate model and scaler filenames based on selected features.

        The filename embeds the number of selected features and either the
        sorted, underscore-joined feature names or, when that string is longer
        than 50 characters, a 6-character hash of it.

        Args:
            base_prefix (str): Prefix for the model filename
            model_dir (str, optional): Directory to save/load models from.
                                     If None, uses the current script's directory.

        Returns:
            tuple: (model_filepath, scaler_filepath)
        """
        feature_count = len(self.features_to_calculate)
        feature_names = '_'.join(sorted(self.features_to_calculate))

        # Create a shortened version if the filename would be too long
        if len(feature_names) > 50:  # arbitrary length limit
            suffix = self._create_feature_hash(feature_names)
        else:
            suffix = feature_names

        # Build both names from one stem; deriving the scaler name with
        # str.replace('.pkl', ...) would also rewrite a '.pkl' that happens to
        # occur inside the prefix or a feature name.
        stem = f"{base_prefix}_{feature_count}Features_{suffix}"
        model_name = f"{stem}.pkl"
        scaler_name = f"{stem}_scaler.pkl"

        # Get the appropriate directory
        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))

        return os.path.join(model_dir, model_name), os.path.join(model_dir, scaler_name)

    def calculate_features(self, di_graph, standardize=False, normalize=False):
        """
        Calculate all specified features for the given directed graph.

        Unrecognized feature names are reported with a printed warning and
        skipped. If both scaling options are enabled, standardization is
        applied first and min-max normalization second.

        Args:
            di_graph (nx.DiGraph): NetworkX directed graph object
            standardize (bool): If True, subtract each column's mean and divide
                by its standard deviation (pandas default, i.e. sample std).
                Columns whose standard deviation is not positive are only
                centered.
            normalize (bool): If True, min-max scale each column to [0, 1].
                Columns where all values are equal are set to 1.0.

        Returns:
            pd.DataFrame: DataFrame containing all calculated features
        """
        feature_dict = {}

        for feature_name in self.features_to_calculate:
            calculator = self.all_features.get(feature_name)
            if calculator is None:
                print(f"Warning: Feature '{feature_name}' not recognized and will be skipped.")
                continue
            feature_dict[feature_name] = calculator(di_graph)

        # Convert to DataFrame (rows are nodes, columns are features)
        features_df = pd.DataFrame.from_dict(feature_dict, orient='columns')

        # Apply standardization if requested
        if standardize:
            for column in features_df.columns:
                values = features_df[column]
                centered = values - values.mean()
                std = values.std()
                # A zero (or NaN) standard deviation would make the division
                # meaningless, so such columns are only centered.
                features_df[column] = centered / std if std > 0 else centered

        # Apply min-max normalization if requested
        if normalize:
            for column in features_df.columns:
                min_val = features_df[column].min()
                max_val = features_df[column].max()

                if max_val == min_val:
                    # For static features, set to 1 to preserve the information
                    # that all nodes have this property equally
                    features_df[column] = 1.0
                else:
                    features_df[column] = (features_df[column] - min_val) / (max_val - min_val)

        return features_df

    def _calculate_in_degree(self, di_graph):
        """Return the in-degree of every node."""
        return {node: di_graph.in_degree(node) for node in di_graph.nodes()}

    def _calculate_out_degree(self, di_graph):
        """Return the out-degree of every node."""
        return {node: di_graph.out_degree(node) for node in di_graph.nodes()}

    def _katz_alpha(self, graph):
        """
        Choose the Katz attenuation factor for ``graph``.

        Returns 0.9 / (spectral radius of the unweighted adjacency matrix), or
        0.1 when the spectral radius is zero.
        """
        # An acyclic digraph has a nilpotent adjacency matrix (spectral radius
        # exactly 0), but numerically computed eigenvalues can come out as tiny
        # non-zero numbers; comparing them with == 0 would then produce a huge
        # or negative alpha. Detect this case structurally instead.
        if graph.is_directed() and nx.is_directed_acyclic_graph(graph):
            return 0.1

        # weight=None matches the default of nx.katz_centrality_numpy, so alpha
        # is derived from the same (unweighted) matrix that Katz is solved on.
        eigenvalues = nx.adjacency_spectrum(graph, weight=None)
        spectral_radius = float(np.max(np.abs(eigenvalues)))
        if spectral_radius < 1e-9:
            return 0.1
        return 0.9 / spectral_radius

    def _katz_centrality_with_fallback(self, graph):
        """
        Compute Katz centrality on ``graph``.

        Falls back to alpha=0.9 (with a printed warning) if a LinAlgError is
        raised while computing the spectrum or solving the linear system.
        Returns an empty dict for a graph without nodes.
        """
        if graph.number_of_nodes() == 0:
            return {}
        try:
            return nx.katz_centrality_numpy(graph, alpha=self._katz_alpha(graph))
        except LinAlgError:
            print("Warning: LinAlgError in Katz centrality calculation. Using alpha=0.9")
            return nx.katz_centrality_numpy(graph, alpha=0.9)

    def _calculate_in_katz_centrality(self, di_graph):
        """Return the Katz centrality of every node, computed on the graph as given."""
        return self._katz_centrality_with_fallback(di_graph)

    def _calculate_out_katz_centrality(self, di_graph):
        """Return the Katz centrality of every node, computed on the reversed graph."""
        return self._katz_centrality_with_fallback(di_graph.reverse())

    def _calculate_in_closeness_centrality(self, di_graph):
        """Return closeness centrality computed on the graph as given.

        For directed graphs NetworkX bases this on distances *to* each node.
        """
        return nx.closeness_centrality(di_graph)

    def _calculate_out_closeness_centrality(self, di_graph):
        """Return closeness centrality computed on the reversed graph."""
        return nx.closeness_centrality(di_graph.reverse())

    def _calculate_betweenness_centrality(self, di_graph):
        """Return betweenness centrality; all zeros (with a warning) if any value is NaN."""
        betweenness = nx.betweenness_centrality(di_graph)
        if any(math.isnan(score) for score in betweenness.values()):
            print("Warning: NaN values in betweenness centrality. Setting all values to 0")
            return {node: 0 for node in di_graph.nodes()}
        return betweenness

    def _calculate_scc_size(self, di_graph):
        """Return, for every node, the size of its strongly connected component."""
        size_by_node = {
            node: len(component)
            for component in nx.strongly_connected_components(di_graph)
            for node in component
        }
        # Emit in graph node order, like every other feature, rather than in
        # SCC discovery order.
        return {node: size_by_node[node] for node in di_graph.nodes()}

    def _calculate_strong_connectivity(self, di_graph):
        """Return 1 for every node if the graph is strongly connected, else 0."""
        flag = 1 if nx.is_strongly_connected(di_graph) else 0
        return {node: flag for node in di_graph.nodes()}

    def _calculate_irreflexivity(self, di_graph):
        """Return 1 for every node if the graph has no self-loops, else 0."""
        has_self_loop = any(di_graph.has_edge(node, node) for node in di_graph.nodes())
        flag = 0 if has_self_loop else 1
        return {node: flag for node in di_graph.nodes()}

    def _calculate_number_of_sccs(self, di_graph):
        """Return the number of strongly connected components, repeated for every node."""
        num_sccs = nx.number_strongly_connected_components(di_graph)
        return {node: num_sccs for node in di_graph.nodes()}

    def _calculate_avg_degree(self, di_graph):
        """Return the mean total degree of the graph, repeated for every node."""
        node_count = len(di_graph.nodes())
        if node_count == 0:
            return {}
        avg = sum(degree for _, degree in di_graph.degree()) / node_count
        return {node: avg for node in di_graph.nodes()}

    def _calculate_aperiodicity(self, di_graph):
        """Return 1 for every node if the graph is aperiodic, else 0."""
        flag = 1 if nx.is_aperiodic(di_graph) else 0
        return {node: flag for node in di_graph.nodes()}

    @staticmethod
    def _calculate_katz_centrality(
            G,
            alpha=0.1,
            beta=1.0,
            max_iter=1000,
            tol=1.0e-6,
            nstart=None,
            normalized=True,
            weight=None,
    ):
        """Compute the Katz centrality for the nodes of the graph G.

        Code adapted from networkX to avoid raising Exception if convergence
        fails: when ``max_iter`` iterations are exhausted, the last iterate is
        returned instead of raising PowerIterationFailedConvergence.

        Declared as a static method because it takes the graph, not an
        instance, as its first argument.

        Args:
            G: Graph, or any adjacency mapping ``{node: {neighbor: attrs}}``.
            alpha (float): Attenuation factor.
            beta (float or dict): Constant term; a scalar applied to every
                node, or a dict with one value per node.
            max_iter (int): Maximum number of power iterations.
            tol (float): Per-node error tolerance used in the convergence check.
            nstart (dict, optional): Starting value for each node. Defaults to
                all zeros. It is not modified.
            normalized (bool): If True, scale the result to unit Euclidean norm.
            weight (hashable, optional): Edge attribute used as weight; edges
                without it (and all edges when None) count as 1.

        Returns:
            dict: ``{node: Katz centrality}``; empty for an empty graph.

        Raises:
            nx.NetworkXError: If ``beta`` is a dict that does not cover exactly
                the nodes of ``G``.
        """
        if len(G) == 0:
            return {}

        nnodes = len(G)

        if nstart is None:
            # choose starting vector with entries of 0
            x = {n: 0 for n in G}
        else:
            x = nstart

        try:
            b = dict.fromkeys(G, float(beta))
        except (TypeError, ValueError, AttributeError) as err:
            b = beta
            if set(beta) != set(G):
                raise nx.NetworkXError(
                    "beta dictionary must have a value for every node"
                ) from err

        # make up to max_iter iterations
        for _ in range(max_iter):
            xlast = x
            x = dict.fromkeys(xlast, 0)
            # do the multiplication y^T = Alpha * x^T A + Beta
            for n in x:
                for nbr in G[n]:
                    x[nbr] += xlast[n] * G[n][nbr].get(weight, 1)
            for n in x:
                x[n] = alpha * x[n] + b[n]

            # check convergence
            error = sum(abs(x[n] - xlast[n]) for n in x)
            if error < nnodes * tol:
                break

        if not normalized:
            return x

        # Normalize on both exits (converged or max_iter exhausted) so that
        # normalized=True always means the same thing. Skip the rescaling when
        # the norm is zero or not finite (e.g. a diverged iteration).
        norm = math.hypot(*x.values())
        if not (norm > 0 and math.isfinite(norm)):
            return x
        scale = 1.0 / norm
        # Build a new dict so a caller-supplied nstart is never mutated.
        return {n: value * scale for n, value in x.items()}