import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import networkx as nx
import pandas as pd
import numpy as np
import random
import glob

from utils.GraphFeatureCalculator import GraphFeatureCalculator

# training_data = "Path/to/training/data"
# test_data = "Path/to/test/data"
# solution_data_train = "Path/to/training/data/solutions"
# solution_data_test = "Path/to/test/data/solutions"
# Feature names handed to GraphFeatureCalculator(features_to_calculate=...) wherever
# this module builds a feature calculator.
selected_features = ['in_degree', 'out_degree']
# Checkpoint file name passed to load_model() by predict() and predict_set().
model_save_path = "checkpoint_epoch_47_loss_0.1203.pt"
# Flags forwarded as calculate_features(standardize=..., normalize=...) at every
# call site in this module.
standardize_features = False
normalize_features = True

# Convert a NetworkX graph to PyTorch Geometric format
def networkx_to_pyg(nx_graph, feature_df, query_node_idx):
    """Convert a NetworkX graph and its node features into a PyG ``Data`` object.

    Nodes are numbered in ``nx_graph.nodes()`` iteration order. The feature
    matrix is extended by one extra column that is 1.0 for the query node and
    0.0 for every other node.

    Args:
        nx_graph: NetworkX graph to convert.
        feature_df: Object whose ``.values`` is a 2-D numeric array with one
            row per node (assumed to be in ``nx_graph.nodes()`` order -- TODO
            confirm against GraphFeatureCalculator).
        query_node_idx: Position of the query node in ``list(nx_graph.nodes())``.

    Returns:
        ``Data`` with ``x`` of shape ``(num_nodes, num_features + 1)`` and
        ``edge_index`` of shape ``(2, num_edges)``.
    """
    nodes = list(nx_graph.nodes())
    node_mapping = {node: idx for idx, node in enumerate(nodes)}

    edges = [(node_mapping[u], node_mapping[v]) for u, v in nx_graph.edges()]
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # torch.tensor([]) would be 1-D with shape (0,); graph convolutions
        # need a (2, 0) integer tensor for a graph without edges.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Map the query node to its integer index (also normalises negative positions).
    query_node_idx = node_mapping[nodes[query_node_idx]]

    # Convert features to a float tensor
    x = torch.tensor(feature_df.values, dtype=torch.float)

    # One-hot column marking the query node
    query_feature = torch.zeros(len(nodes), 1)
    query_feature[query_node_idx] = 1.0

    # Append the query marker as the last feature column
    x = torch.cat([x, query_feature], dim=1)

    return Data(x=x, edge_index=edge_index)

#Define model
class CoAdmNN(torch.nn.Module):
    """Two-layer GCN followed by a linear head that scores every node in [0, 1].

    The layers are named ``conv1`` and ``conv3`` (there is no active ``conv2``
    layer); the attribute names determine the state-dict keys that
    ``load_state_dict`` expects.
    """

    def __init__(self, num_features):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv3 = GCNConv(64, 16)
        self.out = torch.nn.Linear(16, 1)

    def forward(self, data):
        """Return one sigmoid score per node, shape ``(num_nodes, 1)``."""
        edge_index = data.edge_index

        # First GCN layer + ReLU; dropout is only active in training mode.
        hidden = F.relu(self.conv1(data.x, edge_index))
        hidden = F.dropout(hidden, p=0.2, training=self.training)

        # Second GCN layer + ReLU.
        hidden = F.relu(self.conv3(hidden, edge_index))

        # Linear head squashed to a probability-like score.
        return torch.sigmoid(self.out(hidden))


def predict_coadmissibility(nx_graph, feature_df, query_node_idx, model):
    """Score every node of ``nx_graph`` against one query node.

    Args:
        nx_graph: NetworkX graph to score.
        feature_df: Per-node feature table accepted by ``networkx_to_pyg``.
        query_node_idx: Position of the query node in ``list(nx_graph.nodes())``.
        model: Model to run; it is switched to evaluation mode by this call.

    Returns:
        1-D numpy array with one model output per node.
    """
    # Convert to PyG format
    data = networkx_to_pyg(nx_graph, feature_df, query_node_idx)

    # Evaluation mode disables dropout; no_grad skips autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        predictions = model(data)

    # .cpu() matches the other inference paths in this module and keeps the
    # numpy conversion valid when the output tensor is not on the CPU.
    return predictions.cpu().numpy().flatten()

def build_di_graph(file):
    """Build a directed graph from a file of ``arg(x).`` / ``att(x,y).`` lines.

    Every line containing ``arg(`` adds a node; every other line containing
    ``att(`` adds an edge from the first to the second name. Lines containing
    neither are ignored. Surrounding whitespace is stripped from names so that
    ``att(a, b).`` refers to the same node as ``arg(b).``.

    Args:
        file: Path of the file to read.

    Returns:
        ``networkx.DiGraph`` with one node per argument and one edge per attack.

    Raises:
        ValueError: If an ``att(`` line does not contain exactly two
            comma-separated names.
    """
    di_graph = nx.DiGraph()
    with open(file) as f:
        lines = f.read().splitlines()

    for line in lines:
        # NOTE(review): this removes every '.', not only the terminating one,
        # so an argument named "a.1" becomes "a1". Kept as-is because existing
        # label files may rely on it -- confirm before changing.
        statement = line.replace('.', '')
        if 'arg(' in statement:
            argument = statement.replace('arg(', '').replace(')', '').strip()
            di_graph.add_node(argument)
        elif 'att(' in statement:
            attack = statement.replace('att(', '').replace(')', '')
            endpoints = [name.strip() for name in attack.split(',')]
            if len(endpoints) != 2:
                raise ValueError(f"Malformed attack line in {file}: {line!r}")
            attacking_node, attacked_node = endpoints
            di_graph.add_edge(attacking_node, attacked_node)
    return di_graph


def build_di_graph_from_af(af_object):
    """Build a directed graph from an argumentation-framework object.

    Reads ``af_object.arguments``, a mapping from argument to a record whose
    ``'successors'`` entry lists the arguments it points to.

    Returns:
        ``networkx.DiGraph`` with every argument as a node and one edge per
        (argument, successor) pair.
    """
    arguments = af_object.arguments
    di_graph = nx.DiGraph()

    # Nodes first, so arguments without any attack are still present.
    di_graph.add_nodes_from(arguments.keys())

    # One edge per (argument, successor) pair.
    di_graph.add_edges_from(
        (source, target)
        for source, record in arguments.items()
        for target in record['successors']
    )

    return di_graph

def parse_co_adm_args(file_path):
    """Load the co-admissible argument sets stored next to a graph file.

    For a graph at ``<dir>/<name>`` the solutions are read from
    ``<dir>/co_adm_solutions/<name>_co_adm_args.txt``. Each non-empty line is
    expected to look like ``arg: co_arg1, co_arg2``; lines without a colon are
    reported and skipped.

    Returns:
        Dict mapping each argument to the set of names listed after the colon
        (possibly empty), or ``None`` when the solutions file does not exist.
    """
    graph_dir, graph_name = os.path.split(file_path)
    solutions_path = os.path.join(
        graph_dir, "co_adm_solutions", f"{graph_name}_co_adm_args.txt"
    )

    if not os.path.isfile(solutions_path):
        print(f"No co-adm args file found for {file_path}.")
        return None

    parsed = {}
    with open(solutions_path, "r") as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if entry == "":
                continue  # blank line

            head, colon, tail = entry.partition(":")
            if not colon:
                print(f"Skipping malformed line: {entry}")
                continue

            # Empty fragments (e.g. from "a,,b" or a trailing comma) are dropped.
            names = (piece.strip() for piece in tail.split(","))
            parsed[head.strip()] = {name for name in names if name}
    return parsed

def create_feature_vector(data_path, graph_list, feature_calculator):
    """Compute features for several graph files and stack them into one table.

    Args:
        data_path: Directory prefix; each graph is read from
            ``data_path + '/' + graph``.
        graph_list: Iterable of graph file names.
        feature_calculator: Object providing
            ``calculate_features(graph, standardize=..., normalize=...)``.

    Returns:
        The per-graph feature tables concatenated with a fresh index.
    """
    def features_for(graph):
        print('creating features for graph ' + graph)
        di_graph = build_di_graph(data_path + '/' + graph)
        return feature_calculator.calculate_features(
            di_graph,
            standardize=standardize_features,
            normalize=normalize_features,
        )

    per_graph_tables = [features_for(graph) for graph in graph_list]
    return pd.concat(per_graph_tables, ignore_index=True)


def train_model(train_graphs_path, train_labels_path, num_query_nodes=5, epochs=100, batch_size=32, lr=0.01, seed=42,
                model_save_path="coadm_model.pt", checkpoint_dir="checkpoints", resume_from=None):
    """Train a ``CoAdmNN`` on the ``*.apx`` graphs found in a directory.

    In every epoch, up to ``num_query_nodes`` query nodes are sampled per
    graph. Each (graph, query node) pair becomes one training instance whose
    per-node label is 1.0 for the nodes that ``parse_co_adm_args`` lists for
    the query node and 0.0 otherwise. Graphs are processed in chunks of
    ``2 * batch_size`` files so that only one chunk of instances is held in
    memory at a time. A checkpoint is written whenever the epoch's average
    loss improves, and the final model is written to ``model_save_path``.

    Args:
        train_graphs_path: Directory searched (non-recursively) for ``*.apx`` files.
        train_labels_path: Unused; labels are located relative to each graph
            file by ``parse_co_adm_args``. Kept for interface compatibility.
        num_query_nodes: Maximum number of query nodes sampled per graph and epoch.
        epochs: Total number of epochs (including epochs already completed
            when resuming).
        batch_size: Number of training instances per optimisation step.
        lr: Adam learning rate.
        seed: Seed for ``random`` and ``torch``.
        model_save_path: Path of the final model file.
        checkpoint_dir: Directory for per-improvement checkpoints (created if missing).
        resume_from: Optional checkpoint path to resume from; ignored if it
            does not exist.

    Returns:
        The trained model.

    Raises:
        ValueError: If no ``*.apx`` file is found, if an epoch yields no
            training instance, or if a label names a node absent from its graph.
    """
    # Set random seeds for reproducibility
    random.seed(seed)
    torch.manual_seed(seed)

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Initialize feature calculator
    feature_calc = GraphFeatureCalculator(features_to_calculate=selected_features)

    # glob returns files in arbitrary order; sorting makes the seeded sampling
    # below see the same sequence of graphs on every run.
    graph_files = sorted(glob.glob(os.path.join(train_graphs_path, "*.apx")))
    if not graph_files:
        raise ValueError(f"No .apx graph files found in {train_graphs_path!r}")

    # Calculate input feature dimension from a sample graph
    sample_graph = build_di_graph(graph_files[0])
    sample_features = feature_calc.calculate_features(sample_graph, standardize=standardize_features,
                                                      normalize=normalize_features)
    num_features = sample_features.shape[1] + 1  # +1 for query node feature

    # Initialize model and optimizer
    model = CoAdmNN(num_features)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCELoss()

    # Variables to track training progress. avg_loss is pre-initialised so the
    # final save still works when no epoch runs (e.g. resuming a finished run).
    start_epoch = 0
    best_loss = float('inf')
    avg_loss = float('inf')

    # Resume from checkpoint if specified
    if resume_from and os.path.exists(resume_from):
        print(f"Loading checkpoint from {resume_from}")
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
        checkpoint = torch.load(resume_from)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['best_loss']
        avg_loss = checkpoint.get('loss', best_loss)
        print(f"Resuming from epoch {start_epoch} with best loss: {best_loss:.4f}")

    # Process 2 batches worth of graphs at a time to avoid loading all at once
    chunk_size = min(batch_size * 2, len(graph_files))

    # Training loop
    model.train()
    for epoch in range(start_epoch, epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_samples = 0

        for chunk_start in range(0, len(graph_files), chunk_size):
            chunk_files = graph_files[chunk_start:chunk_start + chunk_size]

            # Build the training instances for every graph in the chunk
            data_list = []
            for graph_file in chunk_files:
                # Parse co-admissibility solutions; a graph without a
                # solutions file has no labels and cannot be trained on.
                coadm_solutions = parse_co_adm_args(graph_file)
                if coadm_solutions is None:
                    continue

                # Build networkx graph and its features
                nx_graph = build_di_graph(graph_file)
                features = feature_calc.calculate_features(nx_graph, standardize=standardize_features,
                                                           normalize=normalize_features)

                all_nodes = list(nx_graph.nodes())
                # O(1) node -> position lookup instead of repeated list.index
                node_index = {node: idx for idx, node in enumerate(all_nodes)}

                # Sample query nodes (all nodes if there are few enough)
                if len(all_nodes) <= num_query_nodes:
                    query_nodes = all_nodes
                else:
                    query_nodes = random.sample(all_nodes, num_query_nodes)

                # For each query node, create a training instance
                for query_node in query_nodes:
                    # Labels vector (1 for co-admissible nodes, 0 for others)
                    labels = np.zeros(len(all_nodes))
                    for co_adm_node in coadm_solutions.get(query_node, []):
                        try:
                            labels[node_index[co_adm_node]] = 1.0
                        except KeyError:
                            raise ValueError(
                                f"Label node {co_adm_node!r} for {graph_file} is not in the graph"
                            ) from None

                    # Convert to PyG data object
                    data = networkx_to_pyg(nx_graph, features, node_index[query_node])
                    data.y = torch.tensor(labels, dtype=torch.float).view(-1, 1)
                    data_list.append(data)

            # A shuffling DataLoader cannot be built over an empty dataset.
            if not data_list:
                continue

            loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

            # Train on this chunk
            for batch in loader:
                # Forward pass
                optimizer.zero_grad()
                out = model(batch)
                loss = criterion(out, batch.y)

                # Backward pass and optimize
                loss.backward()
                optimizer.step()

                # Track loss weighted by the number of instances in the batch
                total_loss += loss.item() * batch.num_graphs
                num_samples += batch.num_graphs

            # Explicitly release the chunk before building the next one
            del data_list, loader
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if num_samples == 0:
            raise ValueError(
                f"No training instances could be built from {train_graphs_path!r} "
                "(no graph has a co-adm solutions file)"
            )

        # Calculate average loss for the epoch
        avg_loss = total_loss / num_samples

        # Print progress
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}')

        # Save checkpoint if loss has decreased
        if avg_loss < best_loss:
            best_loss = avg_loss
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}_loss_{avg_loss:.4f}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'best_loss': best_loss,
                'num_features': num_features
            }, checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path} (loss improved to {avg_loss:.4f})")

    # Save the final trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'num_features': num_features,
        'epoch': epochs,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        'best_loss': best_loss
    }, model_save_path)

    print(f"Final model saved to {model_save_path}")

    return model


def load_model(checkpoint_name):
    """Load a ``CoAdmNN`` from ``<current working directory>/utils/<checkpoint_name>``.

    The checkpoint must be a dict with the keys ``'num_features'`` and
    ``'model_state_dict'``.

    Args:
        checkpoint_name: File name of the checkpoint. An absolute path is used
            as-is because ``os.path.join`` discards the preceding components.

    Returns:
        The model with loaded weights, on the CPU and in evaluation mode, so
        that dropout is disabled for callers that run inference directly.
    """
    checkpoint_path = os.path.join(os.getcwd(), 'utils', checkpoint_name)

    # map_location keeps the load working on CPU-only hosts even if the
    # checkpoint was written from another device.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Rebuild the architecture with the stored input width, then restore weights
    model = CoAdmNN(checkpoint['num_features'])
    model.load_state_dict(checkpoint['model_state_dict'])

    # A freshly constructed nn.Module is in training mode; switch to eval so
    # F.dropout(..., training=self.training) is a no-op during inference.
    model.eval()

    return model

def evaluate_model(model, test_graphs_path, test_labels_path, num_query_nodes=5, seed=42, threshold=0.7):
    """Evaluate a model on the ``*.apx`` graphs found in a directory.

    For each graph that has co-admissibility labels, up to ``num_query_nodes``
    query nodes are sampled. For every (graph, query node) instance the model
    output is binarised at ``threshold`` and compared with the labels.
    Accuracy, precision, recall and F1 are averaged over all instances.

    Args:
        model: Model to evaluate; it is switched to evaluation mode.
        test_graphs_path: Directory searched (non-recursively) for ``*.apx`` files.
        test_labels_path: Unused; labels are located relative to each graph
            file by ``parse_co_adm_args``. Kept for interface compatibility.
        num_query_nodes: Maximum number of query nodes sampled per graph.
        seed: Seed for the ``random`` module.
        threshold: Minimum score for a node to be predicted co-admissible.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` (instance
        averages), ``perfect_pred`` and ``useless_pred`` (instance counts with
        accuracy 1.0 / 0.0) and ``num_instances``. All metrics are 0 when no
        instance could be evaluated.

    Raises:
        ValueError: If a label names a node that is not in its graph.
    """
    # Set random seed for reproducibility
    random.seed(seed)

    # glob returns files in arbitrary order; sorting makes the seeded sampling
    # below see the same sequence of graphs on every run.
    graph_files = sorted(glob.glob(os.path.join(test_graphs_path, "*.apx")))

    # Set model to evaluation mode
    model.eval()

    # Metric accumulators
    total_accuracy = 0.0
    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    perfect_pred = 0
    useless_pred = 0
    total_instances = 0

    feature_calc = GraphFeatureCalculator(features_to_calculate=selected_features)

    # Process each test graph
    for graph_file in graph_files:
        # Graphs without (or with empty) co-admissibility solutions are skipped
        coadm_solutions = parse_co_adm_args(graph_file)
        if not coadm_solutions:
            continue

        # Build networkx graph and its features
        nx_graph = build_di_graph(graph_file)
        features = feature_calc.calculate_features(nx_graph, standardize=standardize_features,
                                                   normalize=normalize_features)

        all_nodes = list(nx_graph.nodes())
        # O(1) node -> position lookup instead of repeated list.index
        node_index = {node: idx for idx, node in enumerate(all_nodes)}

        # Sample query nodes (all nodes if there are few enough)
        if len(all_nodes) <= num_query_nodes:
            query_nodes = all_nodes
        else:
            query_nodes = random.sample(all_nodes, num_query_nodes)

        # Evaluate each query node
        for query_node in query_nodes:
            data = networkx_to_pyg(nx_graph, features, node_index[query_node])

            # Ground truth labels (1 for co-admissible nodes, 0 for others)
            true_labels = np.zeros(len(all_nodes))
            for co_adm_node in coadm_solutions.get(query_node, []):
                try:
                    true_labels[node_index[co_adm_node]] = 1.0
                except KeyError:
                    raise ValueError(
                        f"Label node {co_adm_node!r} for {graph_file} is not in the graph"
                    ) from None

            # Make predictions
            with torch.no_grad():
                predictions = model(data).cpu().numpy().flatten()

            # Convert to binary predictions
            binary_preds = (predictions >= threshold).astype(np.float32)

            # Calculate metrics
            accuracy = np.mean(binary_preds == true_labels)
            if accuracy == 1.0:
                perfect_pred += 1
            if accuracy == 0.0:
                useless_pred += 1

            true_positives = np.sum((binary_preds == 1) & (true_labels == 1))
            predicted_positives = np.sum(binary_preds)
            actual_positives = np.sum(true_labels)

            # Each ratio falls back to 0 when its denominator is 0
            precision = true_positives / predicted_positives if predicted_positives > 0 else 0
            recall = true_positives / actual_positives if actual_positives > 0 else 0
            if precision + recall > 0:
                f1 = 2 * precision * recall / (precision + recall)
            else:
                f1 = 0

            # Update totals
            total_accuracy += accuracy
            total_precision += precision
            total_recall += recall
            total_f1 += f1
            total_instances += 1

    if total_instances == 0:
        # Nothing to average: report it instead of dividing by zero.
        print(f"Evaluation on 0 instances: no labelled graph found in {test_graphs_path}.")
        return {
            "accuracy": 0.0,
            "precision": 0.0,
            "perfect_pred": 0,
            "useless_pred": 0,
            "recall": 0.0,
            "f1": 0.0,
            "num_instances": 0
        }

    # Calculate averages
    avg_accuracy = total_accuracy / total_instances
    avg_precision = total_precision / total_instances
    avg_recall = total_recall / total_instances
    avg_f1 = total_f1 / total_instances
    perfect_pred_percent = perfect_pred / total_instances
    useless_pred_percent = useless_pred / total_instances

    # Print results
    print(f"Evaluation on {total_instances} instances:")
    print(f"Accuracy: {avg_accuracy:.4f}")
    print(f"% of perfect predictions: {perfect_pred_percent}")
    print(f"perfect predictions: {perfect_pred}")
    print(f"% of useless predictions: {useless_pred_percent}")
    print(f"useless predictions: {useless_pred}")
    print(f"Precision: {avg_precision:.4f}")
    print(f"Recall: {avg_recall:.4f}")
    print(f"F1 Score: {avg_f1:.4f}")

    return {
        "accuracy": avg_accuracy,
        "precision": avg_precision,
        "perfect_pred": perfect_pred,
        "useless_pred": useless_pred,
        "recall": avg_recall,
        "f1": avg_f1,
        "num_instances": total_instances
    }


def predict(af, query_node):
    """Predict, for every node, whether it is co-admissible with ``query_node``.

    The model is loaded from ``model_save_path`` via ``load_model`` and run
    once on ``af.di_graph``.

    Args:
        af: Object exposing the graph as ``af.di_graph``.
        query_node: Node of ``af.di_graph`` to query.

    Returns:
        Dict keyed by node (inserted in ascending node order). Each value is a
        dict with ``'coAdm'`` (1.0 if score >= 0.7), ``'notCoAdm'`` (1.0 if
        score <= 0.3), ``'pred_value'`` (raw score) and ``'pred_order'``
        (1-based rank by descending score, ties broken by node name).

    Raises:
        ValueError: If ``query_node`` is not a node of the graph.
    """
    model = load_model(model_save_path)
    # A freshly loaded nn.Module may be in training mode, which keeps dropout
    # active and makes the scores random; force evaluation mode.
    model.eval()

    nx_graph = af.di_graph
    feature_calc = GraphFeatureCalculator(features_to_calculate=selected_features)

    # Calculate features
    features = feature_calc.calculate_features(nx_graph, standardize=standardize_features, normalize=normalize_features)

    # Locate the query node
    all_nodes = list(nx_graph.nodes())
    query_idx = all_nodes.index(query_node)

    # Create PyG data object
    data = networkx_to_pyg(nx_graph, features, query_idx)

    # Make predictions
    with torch.no_grad():
        predictions = model(data).cpu().numpy().flatten()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    coAdm_preds = (predictions >= 0.7).astype(np.float32)
    notCoAdm_pred = (predictions <= 0.3).astype(np.float32)

    # Rank nodes by score (descending), then by node name
    ranked = sorted(zip(all_nodes, predictions), key=lambda item: (-item[1], item[0]))
    pred_order_map = {node: rank for rank, (node, _) in enumerate(ranked, start=1)}

    # Assemble the result in ascending node order
    rows = sorted(zip(all_nodes, coAdm_preds, notCoAdm_pred, predictions), key=lambda row: row[0])
    predict_dict = {}
    for node, co_pred, not_co_pred, pred_val in rows:
        predict_dict[node] = {
            'coAdm': co_pred,
            'notCoAdm': not_co_pred,
            'pred_value': pred_val,
            'pred_order': pred_order_map[node]
        }

    return predict_dict


def predict_set(af, lab):
    """Run the model once per node of ``af`` and collect the predictions.

    The model is loaded from ``model_save_path`` via ``load_model``; the graph
    is built with ``build_di_graph_from_af``.

    Args:
        af: Argumentation-framework object accepted by ``build_di_graph_from_af``.
        lab: Unused; kept for interface compatibility.

    Returns:
        Dict with
        ``'all_predictions'``: query node -> numpy array of raw scores,
        ``'individual_co_admissible'``: query node -> set of nodes scored >= 0.7,
        ``'individual_not_co_admissible'``: query node -> set of nodes scored <= 0.3,
        ``'co_admissible_intersection'`` / ``'not_co_admissible_intersection'``:
        intersection of the respective sets over all query nodes (empty set
        when the graph has no nodes).
    """
    model = load_model(model_save_path)
    # A freshly loaded nn.Module may be in training mode, which keeps dropout
    # active and makes the scores random; force evaluation mode.
    model.eval()

    nx_graph = build_di_graph_from_af(af)
    feature_calc = GraphFeatureCalculator(features_to_calculate=selected_features)

    # Calculate features (identical for every query node)
    features = feature_calc.calculate_features(nx_graph, standardize=standardize_features, normalize=normalize_features)

    # Get list of all nodes
    all_nodes = list(nx_graph.nodes())

    # Initialize result dictionary
    result = {
        'individual_co_admissible': {},
        'individual_not_co_admissible': {},
        'all_predictions': {},
        'co_admissible_intersection': set(),
        'not_co_admissible_intersection': set()
    }

    # enumerate gives each node's position directly, avoiding list.index per node
    for query_idx, query_node in enumerate(all_nodes):
        # Create PyG data object
        data = networkx_to_pyg(nx_graph, features, query_idx)

        # Make predictions
        with torch.no_grad():
            predictions = model(data).cpu().numpy().flatten()
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Store the raw predictions
        result['all_predictions'][query_node] = predictions

        # Nodes predicted co-admissible with the query node
        result['individual_co_admissible'][query_node] = {
            node for node, score in zip(all_nodes, predictions) if score >= 0.7
        }

        # Nodes predicted not co-admissible with the query node
        result['individual_not_co_admissible'][query_node] = {
            node for node, score in zip(all_nodes, predictions) if score <= 0.3
        }

    # Calculate intersection of co-admissible sets
    if result['individual_co_admissible']:
        result['co_admissible_intersection'] = set.intersection(*result['individual_co_admissible'].values())

    # Calculate intersection of not co-admissible sets
    if result['individual_not_co_admissible']:
        result['not_co_admissible_intersection'] = set.intersection(*result['individual_not_co_admissible'].values())

    return result


########################################################################################################################
# CODE FOR TRAINING AND EVALUATION
########################################################################################################################


# # Check if model checkpoint exists
# if not os.path.exists(model_save_path):
#     # Train the model if no checkpoint exists
#     print("No existing model found. Training new model...")
#     model = train_model(
#         train_graphs_path=training_data,
#         train_labels_path=solution_data_train,
#         num_query_nodes=100,
#         epochs=10,
#         batch_size=32,
#         lr=0.02,
#         seed=42,
#         model_save_path=model_save_path,
#  )
#
# else:
#     # Load existing model if checkpoint exists
#     print("Existing model found. Loading checkpoint...")
#     model = load_model(model_save_path)
#
#     # Resume training from the existing checkpoint
#     print("Resume training model...")
#     model = train_model(
#         train_graphs_path=training_data,
#         train_labels_path=solution_data_train,
#         num_query_nodes=100,
#         epochs=100,
#         batch_size=32,
#         lr=0.02,
#         seed=42,
#         model_save_path=model_save_path,
#         checkpoint_dir="checkpoints",
#         resume_from=model_save_path
#     )
# # Evaluate the model
# metrics = evaluate_model(
#     model=model,
#     test_graphs_path=test_data,
#     test_labels_path=solution_data_test,
#     num_query_nodes=200,
#     seed=56
# )
#
# print("Evaluation complete!")