#!/usr/bin/env python3
"""
Test harness for the ICCMA solver wrapper script (e.g., run_solver.sh or just solver).

Iterates over the .af files in a directory (each paired with a matching
.af.arg file holding the query argument), runs the solver on each for a
specified task, compares its YES/NO answer with the ground truth from a CSV
file, and calculates performance metrics (Accuracy, Precision, NPV, MCC).
"""

import argparse
import csv # Added for CSV reading
import math
import shlex
import shutil
import subprocess
import sys
from collections import defaultdict
from pathlib import Path

def calculate_metrics(tp: int, tn: int, fp: int, fn: int) -> dict:
    """
    Calculates Accuracy, Precision, NPV, and MCC from confusion-matrix counts.

    Args:
        tp: True Positives count (predicted YES, expected YES).
        tn: True Negatives count (predicted NO, expected NO).
        fp: False Positives count (predicted YES, expected NO).
        fn: False Negatives count (predicted NO, expected YES).

    Returns:
        A dictionary with the keys "Accuracy", "Precision", "NPV" and "MCC",
        in that order. Degenerate inputs never raise ZeroDivisionError:
          * Accuracy is 0.0 when all counts are zero.
          * Precision is 1.0 when the solver never predicted YES (tp + fp == 0):
            there is no positive prediction that could have been wrong.
          * NPV is 1.0 when the solver never predicted NO (tn + fn == 0), for
            the same reason.
          * MCC is 0.0 whenever any row or column of the confusion matrix is
            empty (the MCC formula is undefined there; 0.0 is the usual
            convention).
    """
    total = tp + tn + fp + fn
    predicted_pos = tp + fp
    predicted_neg = tn + fn
    actual_pos = tp + fn
    actual_neg = tn + fp

    accuracy = (tp + tn) / total if total > 0 else 0.0

    # Precision (Positive Predictive Value): how often a YES prediction is right.
    precision = tp / predicted_pos if predicted_pos > 0 else 1.0

    # NPV (Negative Predictive Value): how often a NO prediction is right.
    npv = tn / predicted_neg if predicted_neg > 0 else 1.0

    # MCC (Matthews Correlation Coefficient): -1 total disagreement, 0 chance,
    # +1 perfect agreement. The denominator is zero exactly when one of the four
    # margins is zero; perfect (or perfectly inverted) predictions with both
    # classes present always have a positive denominator, so no special-casing
    # is needed for them.
    denominator_sq = predicted_pos * actual_pos * actual_neg * predicted_neg
    if denominator_sq > 0:
        mcc = (tp * tn - fp * fn) / math.sqrt(denominator_sq)
    else:
        mcc = 0.0

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "NPV": npv,
        "MCC": mcc,
    }

def load_ground_truth(csv_path: Path, target_task: str) -> dict:
    """
    Loads ground truth answers from a CSV file for a specific task.

    The CSV does not store the true answer directly: each row records what
    some solver answered (ANSWER) and whether that answer was right (CORRECT).
    The true answer is therefore YES exactly when the recorded answer and its
    correctness agree (YES+true or NO+false), and NO otherwise.

    Args:
        csv_path: Path to the CSV file. The first row is treated as a header.
                  Expected columns: TASK,INSTANCE,SOLVER,ANSWER,CORRECT
                  (further columns are ignored).
        target_task: Only rows whose TASK equals this value (e.g. "DC-PR")
                     are loaded.

    Returns:
        A dictionary mapping the instance *filename* (any directory part of
        INSTANCE is dropped) to the inferred true answer (True = YES,
        False = NO). Returns an empty dictionary if the file is empty or
        contains no valid rows for ``target_task``.

    Exits with status 1 if the file is missing, cannot be read or decoded,
    is not valid CSV, or its header has fewer than 5 columns. Malformed rows
    are skipped with a warning on stderr. If several rows (e.g. from different
    solvers) imply different answers for the same instance, a warning is
    printed and the later row wins.
    """
    ground_truth = {}
    try:
        with open(csv_path, 'r', newline='') as csvfile:
            reader = csv.reader(csvfile)
            header = next(reader, None)  # Header row; None means the file is empty
            if header is None:
                print(f"Warning: Ground truth CSV file '{csv_path}' is empty.", file=sys.stderr)
                return ground_truth

            # Ensure header has at least 5 columns
            if len(header) < 5:
                print(f"Error: CSV header in '{csv_path}' has fewer than 5 columns. Expected format: TASK,INSTANCE,SOLVER,ANSWER,CORRECT", file=sys.stderr)
                sys.exit(1)

            for row in reader:
                # Physical line number of the row just read (1-based), for messages.
                row_no = reader.line_num

                if len(row) < 5:
                    print(f"Warning: Skipping malformed row {row_no} in '{csv_path}' (needs at least 5 columns): {row}", file=sys.stderr)
                    continue

                task = row[0].strip()
                instance = row[1].strip()
                answer_raw, correct_raw = row[3], row[4]

                # Only load entries for the target task
                if task != target_task:
                    continue

                answer_str = answer_raw.strip().upper()
                correct_str = correct_raw.strip().lower()  # 'true'/'false', case-insensitive

                if answer_str not in ("YES", "NO"):
                    print(f"Warning: Skipping row {row_no} in '{csv_path}'. Invalid ANSWER '{answer_raw}' for instance '{instance}'. Expected YES or NO.", file=sys.stderr)
                    continue
                if correct_str not in ("true", "false"):
                    print(f"Warning: Skipping row {row_no} in '{csv_path}'. Invalid CORRECT value '{correct_raw}' for instance '{instance}'. Expected true or false.", file=sys.stderr)
                    continue

                # True answer is YES iff the recorded answer and its correctness agree.
                actual_ground_truth = (answer_str == "YES") == (correct_str == "true")

                instance_filename = Path(instance).name  # Use only the filename part
                previous = ground_truth.get(instance_filename)
                if previous is not None and previous != actual_ground_truth:
                    print(f"Warning: Conflicting ground truth for instance '{instance_filename}' at row {row_no} in '{csv_path}'; the later row overrides the earlier one.", file=sys.stderr)
                ground_truth[instance_filename] = actual_ground_truth

    except FileNotFoundError:
        print(f"Error: Ground truth CSV file not found at '{csv_path}'", file=sys.stderr)
        sys.exit(1)
    except (OSError, csv.Error, UnicodeDecodeError) as e:
        print(f"Error reading ground truth CSV file '{csv_path}': {e}", file=sys.stderr)
        sys.exit(1)

    if not ground_truth:
        print(f"Warning: No ground truth data loaded for task '{target_task}' from '{csv_path}'. Check task name and CSV content.", file=sys.stderr)

    return ground_truth


def _resolve_solver_command(solver_script_path):
    """Return the program to execute for the solver, or None if it cannot be found.

    An existing file (relative to the current directory, or absolute) is
    returned as an absolute path. This matters because Path("./solver.sh")
    normalises to "solver.sh", which an exec-style launch would otherwise look
    up on PATH instead of running the local file. Anything else is resolved
    via PATH (e.g. a bare "solver").
    """
    candidate = Path(solver_script_path)
    if candidate.is_file():
        return str(candidate.absolute())
    return shutil.which(str(candidate))


def run_evaluation(task: str, test_dir: Path, solver_script_path: Path, ground_truth_data: dict,
                   timeout: float = 300):
    """
    Runs the solver on every test instance in a directory and reports metrics.

    For each ``<name>.af`` in ``test_dir`` (in sorted order), the argument is
    read from ``<name>.af.arg``. The solver is then executed directly (no shell)
    as ``<solver> -f <af file> -p <task> -a <argument>``. Its stdout, stripped
    and upper-cased, must be exactly ``YES`` or ``NO``. That answer is compared
    with ``ground_truth_data`` and tallied into a confusion matrix.

    Printed output:
      * stdout: one result line per file, a summary of skipped files, and the
        metrics from ``calculate_metrics``.
      * stderr: details of any solver failure.

    Args:
        task: The task name (e.g., "DC-PR") to pass to the solver.
        test_dir: The directory containing .af and .af.arg files.
        solver_script_path: The solver executable. An existing file path is
            run directly; otherwise the name is looked up on PATH.
        ground_truth_data: Dictionary mapping instance filename to the expected
            answer (True = YES, False = NO).
        timeout: Seconds to wait for a single solver run before counting it as
            a solver error (default 300, i.e. 5 minutes; None waits forever).

    Exits with status 1 if ``test_dir`` is not a directory or the solver
    cannot be found.
    """
    if not test_dir.is_dir():
        print(f"Error: Test directory not found at '{test_dir}'", file=sys.stderr)
        sys.exit(1)

    solver_cmd = _resolve_solver_command(solver_script_path)
    if solver_cmd is None:
        print(f"Error: Solver '{solver_script_path}' not found (not an existing file and not on PATH).", file=sys.stderr)
        sys.exit(1)

    af_files = sorted(test_dir.glob("*.af"))
    counts = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
    processed_files = 0
    # One counter per skip reason, so the summary never has to derive one by subtraction.
    missing_arg_files = 0
    empty_arg_files = 0
    missing_gt_files = 0
    solver_errors = 0
    unexpected_output_files = 0

    print(f"Starting evaluation for task '{task}' in directory '{test_dir}'...")
    print("Using ground truth from CSV.")
    print("-" * 60)

    for af_file in af_files:
        arg_file = af_file.with_name(f"{af_file.name}.arg")  # x.af -> x.af.arg

        if not arg_file.is_file():
            print(f"Warning: Skipping '{af_file.name}', missing corresponding '{arg_file.name}'")
            missing_arg_files += 1
            continue

        # Ground truth is keyed by the .af filename
        instance_filename = af_file.name
        if instance_filename not in ground_truth_data:
            print(f"Warning: Skipping '{instance_filename}'. No ground truth found in CSV for task '{task}'.")
            missing_gt_files += 1
            continue

        expected_sol = ground_truth_data[instance_filename]
        expected_sol_raw = "YES" if expected_sol else "NO"

        argument = arg_file.read_text().strip()
        if not argument:
            print(f"Warning: Skipping '{af_file.name}', argument file '{arg_file.name}' is empty.")
            empty_arg_files += 1
            continue

        # Argument list + no shell: file names and arguments need no quoting
        # and cannot be interpreted by a shell.
        command = [solver_cmd, "-f", str(af_file), "-p", task, "-a", argument]
        try:
            result = subprocess.run(command, capture_output=True, text=True, check=False, timeout=timeout)
        except subprocess.TimeoutExpired:
            print(f"Error: Solver timed out for '{af_file.name}'. Skipping.")
            solver_errors += 1
            continue
        except OSError as e:  # e.g. solver not executable
            print(f"Error: Could not execute solver for '{af_file.name}': {e}", file=sys.stderr)
            print(f"  Command: {shlex.join(command)}", file=sys.stderr)
            solver_errors += 1
            continue

        if result.returncode != 0:
            print(f"Error running solver for '{af_file.name}':", file=sys.stderr)
            print(f"  Command: {shlex.join(command)}", file=sys.stderr)
            stderr_output = result.stderr.strip()
            print(f"  Stderr: {stderr_output if stderr_output else '(empty)'}", file=sys.stderr)
            print(f" -> Solver Error (Return Code: {result.returncode}), Skipping metrics for this file.")
            solver_errors += 1
            continue

        predicted_sol_raw = result.stdout.strip().upper()
        if predicted_sol_raw not in ("YES", "NO"):
            print(f"Warning: Skipping '{af_file.name}', unexpected solver output: '{predicted_sol_raw}'")
            unexpected_output_files += 1
            continue

        predicted_sol = (predicted_sol_raw == "YES")
        if predicted_sol:
            outcome = "TP" if expected_sol else "FP"
        else:
            outcome = "FN" if expected_sol else "TN"
        counts[outcome] += 1
        processed_files += 1

        correct = (predicted_sol == expected_sol)
        print(f"File: {af_file.name}, Arg: {argument}, Expected: {expected_sol_raw}, Predicted: {predicted_sol_raw} -> {'CORRECT' if correct else 'WRONG'}")

    skipped_files = (missing_arg_files + empty_arg_files + missing_gt_files
                     + solver_errors + unexpected_output_files)

    print("-" * 60)
    print("Evaluation Summary:")
    print(f"  Total .af files found : {len(af_files)}")
    print(f"  Processed instances   : {processed_files}")
    print(f"  Skipped instances     : {skipped_files}")
    print(f"    (Missing .arg file) : {missing_arg_files}")
    print(f"    (Empty .arg file)   : {empty_arg_files}")
    print(f"    (Missing ground truth): {missing_gt_files}")
    print(f"    (Solver error/timeout): {solver_errors}")
    print(f"    (Unexpected output) : {unexpected_output_files}")

    print("\nConfusion Matrix:")
    print(f"  TP: {counts['TP']} (Predicted YES, Expected YES)")
    print(f"  TN: {counts['TN']} (Predicted NO, Expected NO)")
    print(f"  FP: {counts['FP']} (Predicted YES, Expected NO) - Type I Error")
    print(f"  FN: {counts['FN']} (Predicted NO, Expected YES) - Type II Error")

    if processed_files > 0:
        metrics = calculate_metrics(counts['TP'], counts['TN'], counts['FP'], counts['FN'])
        print("\nPerformance Metrics:")
        for name, value in metrics.items():
            print(f"  {name:<10}: {value:.4f}")
    else:
        print("\nNo files were successfully processed. Cannot calculate metrics.")

    print("-" * 60)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate ICCMA solver performance against CSV ground truth.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,  # Show default values in help
    )
    parser.add_argument("-t", "--task", required=True, help="The task name (e.g., DC-PR, DS-ST) to evaluate. Must match TASK column in CSV.")
    parser.add_argument("-d", "--directory", required=True, type=Path, help="Directory containing .af and .arg test files.")
    parser.add_argument("-g", "--ground-truth-csv", required=True, type=Path, help="Path to the CSV file containing ground truth. Format: TASK,INSTANCE,SOLVER,ANSWER,CORRECT")
    # Kept as a plain string: type=Path would normalise "./solver.sh" to
    # "solver.sh", losing the explicit "run the local file" prefix.
    parser.add_argument("--solver", default="./solver.sh", help="Path to the executable solver wrapper script (e.g., 'solver' if in PATH, or './run_solver.sh').")

    args = parser.parse_args()

    # Load ground truth first. A missing or unreadable CSV already terminates
    # inside load_ground_truth, so an empty result here means the file was
    # readable but held no usable rows for this task.
    ground_truth = load_ground_truth(args.ground_truth_csv, args.task)
    if not ground_truth:
        print(f"Exiting because no ground truth data could be loaded for task '{args.task}'.", file=sys.stderr)
        sys.exit(1)

    # Run the evaluation
    run_evaluation(args.task, args.directory, args.solver, ground_truth)

