import time
from collections import deque
from utils.Labelling import Labelling


class PreferredExtensionFinder:
    """Decide credulous acceptance of a query argument by labelling search.

    The search operates on a ``Labelling`` whose ``label_dict`` maps every
    argument to a dict holding a ``'label'`` string plus ``'BLANK'`` and
    ``'UNDEC'`` counters.  Judging from how the transitions below update them,
    ``BLANK[x]`` / ``UNDEC[x]`` count the attackers of ``x`` that are currently
    BLANK-like / UNDEC (they are initialised by ``Labelling`` -- confirm there).
    """

    # Labels whose holders are counted in their targets' BLANK counter.
    _BLANK_LIKE_LABELS = frozenset({'blank', 'must_in', 'must_undec'})
    # A BLANK argument whose attackers all carry one of these labels can be IN.
    _DEFEATED_LABELS = frozenset({'out', 'must_out'})

    def __init__(self, af):
        """Create a finder for ``af``.

        ``af`` must expose ``arguments`` (argument -> dict with
        ``'predecessors'`` and ``'successors'``), ``grounded_ext`` and
        ``self_attackers``.
        """
        self.af = af
        self.extensions = set()
        # Set of frozensets; receives the IN-set of the witness labelling.
        self.extension = set()
        self.heuristic = None
        # Cached CoAdmNN prediction: argument -> dict with a 'coAdm' score.
        self.prediction = None
        # Query the cached prediction was computed for; None when the
        # prediction was assigned from outside (it is then used as-is).
        self._prediction_query = None
        # True until a computed prediction marks some argument with coAdm > 0.
        self.empty_Pred = True
        self.found_preferred = False
        # MUST_OUT arguments that were seen with BLANK == 0 during the search;
        # the fallback heuristic prefers branching on their BLANK attackers.
        self.must_out_stack = deque()

    def get_preferred_extension(self, query, analytics_dict=None):
        """Return an admissible set of arguments containing ``query``, or None.

        Implementation of Algorithm 4 with optimizations including:
        - Iterative right transitions instead of recursion
        - More frequent hopeless labeling checks
        - Initial labeling propagation

        Every call starts from fresh search state, so one finder can be
        reused for several queries.

        Args:
            query: The argument whose acceptance is decided.
            analytics_dict: Optional dict. When given, a ``'backtracks'``
                counter is maintained in it, and ``'runtime_prediction'`` is
                written whenever a CoAdmNN prediction is computed.

        Returns:
            A frozenset of the IN arguments of the first admissible (no
            MUST_OUT left) terminal labelling found, or None if the search
            finds none.  Maximality of that set is not checked here.
        """
        self._reset_search_state()
        lab = self._create_initial_justification_labelling(query)
        for arg in self.af.arguments:
            lab.create_BLANK_mapping(self.af, arg)
            lab.create_UNDEC_mapping(self.af, arg)
        if not self._decide_preferred_credulous(lab, query, analytics_dict):
            return None
        return next(iter(self.extension), None)

    def _reset_search_state(self):
        """Clear the per-search state so a previous call cannot leak into this one."""
        self.extension = set()
        self.found_preferred = False
        self.must_out_stack = deque()

    def _decide_preferred_credulous(self, lab, query, analytics_dict):
        """Depth-first search for an admissible labelling; True on success.

        The left branch (``next_arg`` IN) is explored recursively on a copy of
        ``lab``; the right branch (``next_arg`` UNDEC) is applied to ``lab``
        in place and handled by the next loop iteration.  On success the
        IN-set is added to ``self.extension`` and ``self.found_preferred`` is
        set so that every pending frame unwinds immediately.
        """
        self._propagate_labeling(lab)

        if self.found_preferred:
            return True
        if self._is_hopeless_labeling(lab) or self._is_non_maximal_labeling(lab):
            return False

        while not self.is_terminal_justification_labeling(lab, query):
            next_arg = self._select_influential_coAdm_argument(lab, query, analytics_dict)
            # Compare with None: argument ids such as 0 or '' are valid choices.
            if next_arg is None:
                break

            # Left branch: tentatively accept next_arg on a copy of the labelling.
            lab_left = lab.copy()
            self._apply_left_transition(lab_left, next_arg)
            if not (self._is_hopeless_labeling(lab_left) or self._is_non_maximal_labeling(lab_left)):
                self._decide_preferred_credulous(lab_left, query, analytics_dict)
            if self.found_preferred:
                return True

            # The left branch failed: record the backtrack.
            if analytics_dict is not None:
                analytics_dict['backtracks'] = analytics_dict.get('backtracks', 0) + 1

            # Right branch (iterative): label next_arg UNDEC in place and continue.
            self._apply_right_transition(lab, next_arg)
            if self._is_hopeless_labeling(lab) or self._is_non_maximal_labeling(lab):
                return False

        # Terminal state: succeed only if no MUST_OUT argument remains.
        if self._is_admissible_labeling(lab):
            in_set = frozenset(
                arg for arg, data in lab.label_dict.items() if data['label'] == 'in'
            )
            self.extension.add(in_set)
            self.found_preferred = True
            return True
        return False

    # Helper methods

    def _create_initial_justification_labelling(self, query):
        """Create the initial labelling for justifying ``query``.

        Labels ``query`` IN; every grounded argument IN with its attackers and
        targets OUT; the query's targets OUT; the query's attackers MUST_OUT
        unless already OUT; and self-attackers unrelated to the query UNDEC.
        Because the grounded pass runs after the query is set IN, a query that
        is attacked by a grounded argument ends up OUT.  All other arguments
        keep the label ``Labelling`` gives them (presumably 'blank' -- confirm).
        """
        lab = Labelling(self.af)
        labels = lab.label_dict

        # Label the query as "in"
        labels[query]['label'] = 'in'

        # Process all arguments in the grounded extension.
        for arg in self.af.grounded_ext:
            labels[arg]['label'] = 'in'
            # Label all predecessors as "out"
            for attacker in self.af.arguments[arg]['predecessors']:
                labels[attacker]['label'] = 'out'
            # Label all successors as "out"
            for attacked in self.af.arguments[arg]['successors']:
                labels[attacked]['label'] = 'out'

        # Cache the query's argument details.
        query_data = self.af.arguments[query]

        # Label successors of the query as "out".
        for arg in query_data['successors']:
            labels[arg]['label'] = 'out'

        # For each predecessor of the query, mark as "must_out" if not already "out".
        for arg in query_data['predecessors']:
            if labels[arg]['label'] != 'out':
                labels[arg]['label'] = 'must_out'

        # Combine predecessors and successors into one set for quick membership testing.
        query_related = set(query_data['predecessors']).union(query_data['successors'])

        # For self-attackers not related to the query, label as "undec".
        for arg in self.af.self_attackers:
            if arg != query and arg not in query_related:
                labels[arg]['label'] = 'undec'

        return lab

    def _attacks_must_out(self, lab, arg):
        """Return True if ``arg`` attacks at least one MUST_OUT argument."""
        labels = lab.label_dict
        return any(
            labels[attacked]['label'] == 'must_out'
            for attacked in self.af.arguments[arg]['successors']
        )

    def _degree_key(self, arg):
        """Sort key: total in+out degree, then the argument itself as a deterministic tie-break."""
        data = self.af.arguments[arg]
        return len(data['predecessors']) + len(data['successors']), arg

    def is_terminal_justification_labeling(self, lab, query):
        """Return True when no BLANK argument attacks a MUST_OUT argument.

        ``query`` is unused; it is kept for interface compatibility.
        """
        return not any(
            data['label'] == 'blank' and self._attacks_must_out(lab, arg)
            for arg, data in lab.label_dict.items()
        )

    def _is_admissible_labeling(self, lab):
        """Check if labeling is admissible (no MUST_OUT arguments)"""
        return not any(data['label'] == 'must_out'
                       for data in lab.label_dict.values())

    def _is_non_maximal_labeling(self, lab):
        """Return True if some UNDEC argument has BLANK == 0 and UNDEC == 0.

        Presumably such an argument has no undecided attacker left, so it
        could be IN and the labelling cannot lead to a maximal one.
        """
        return any(
            data['label'] == 'undec' and data['BLANK'] == 0 and data['UNDEC'] == 0
            for data in lab.label_dict.values()
        )

    def _is_hopeless_labeling(self, lab):
        """
        Check if there exists a MUST_OUT argument that cannot be attacked
        by any remaining BLANK arguments (its BLANK counter is 0).

        Side effect: such arguments are appended to ``self.must_out_stack``
        (once each) to guide later branching.
        """
        hopeless = [
            arg for arg, data in lab.label_dict.items()
            if data['label'] == 'must_out' and data['BLANK'] == 0
        ]
        if not hopeless:
            return False

        # Snapshot as a set so membership tests are O(1) instead of scanning the deque.
        already_recorded = set(self.must_out_stack)
        self.must_out_stack.extend(arg for arg in hopeless if arg not in already_recorded)
        return True

    def _select_influential_argument_for_justification(self, lab):
        """
        Pick the BLANK argument to branch on without using the prediction.

        Prefers BLANK attackers of MUST_OUT arguments recorded in
        ``must_out_stack``; otherwise any BLANK argument that attacks a
        MUST_OUT argument.  Among candidates the highest total degree wins
        (ties broken by the argument itself).  Returns None if none exist.
        """
        labels = lab.label_dict
        af_args = self.af.arguments

        blank_attackers = {
            attacker
            for arg in self.must_out_stack
            if labels[arg]['label'] == 'must_out'
            for attacker in af_args[arg]['predecessors']
            if labels[attacker]['label'] == 'blank'
        }
        if blank_attackers:
            return max(blank_attackers, key=self._degree_key)

        candidates = [
            arg for arg in af_args
            if labels[arg]['label'] == 'blank' and self._attacks_must_out(lab, arg)
        ]
        return max(candidates, key=self._degree_key, default=None)

    def _select_influential_coAdm_argument(self, lab, query, analytics_dict):
        """Pick the argument to branch on, guided by the CoAdmNN prediction.

        The prediction is computed lazily (and recomputed if ``query``
        changed).  Candidates are BLANK arguments predicted with
        ``coAdm == 1.0`` that attack a MUST_OUT argument; the highest total
        degree wins, ties broken by the argument itself.  Falls back to
        ``_select_influential_argument_for_justification`` when the prediction
        is empty or yields no candidate.
        """
        stale = self._prediction_query is not None and self._prediction_query != query
        if self.prediction is None or stale:
            start_pred = time.perf_counter()
            from utils import CoAdmNN
            self.prediction = CoAdmNN.predict(self.af, query)
            self._prediction_query = query
            self.empty_Pred = not any(data['coAdm'] > 0 for data in self.prediction.values())
            elapsed = time.perf_counter() - start_pred
            if analytics_dict is not None:
                analytics_dict['runtime_prediction'] = elapsed

        if self.empty_Pred:
            return self._select_influential_argument_for_justification(lab)

        labels = lab.label_dict
        candidates = [
            arg for arg, data in self.prediction.items()
            if data['coAdm'] == 1.0
            and labels[arg]['label'] == 'blank'
            and self._attacks_must_out(lab, arg)
        ]
        if not candidates:
            return self._select_influential_argument_for_justification(lab)

        return max(candidates, key=self._degree_key)

    def _retract_counts(self, lab_dict, arg):
        """Remove ``arg``'s contribution from its targets' counters before relabelling it.

        BLANK-like labels contribute to BLANK, 'undec' to UNDEC; other labels
        contribute nothing.
        """
        label = lab_dict[arg]['label']
        if label in self._BLANK_LIKE_LABELS:
            counter = 'BLANK'
        elif label == 'undec':
            counter = 'UNDEC'
        else:
            return
        for attacked in self.af.arguments[arg]['successors']:
            lab_dict[attacked][counter] -= 1

    def _set_in_and_defeat_successors(self, lab_dict, arg):
        """Label BLANK ``arg`` IN and every successor not already OUT as OUT."""
        successors = self.af.arguments[arg]['successors']
        lab_dict[arg]['label'] = 'in'
        # arg was BLANK before, so it no longer counts as a BLANK attacker.
        for attacked in successors:
            lab_dict[attacked]['BLANK'] -= 1
        for succ in successors:
            if lab_dict[succ]['label'] == 'out':
                continue
            self._retract_counts(lab_dict, succ)
            lab_dict[succ]['label'] = 'out'

    def _apply_left_transition(self, lab, arg):
        """Label argument IN, attacked arguments OUT and attackers MUST_OUT (if not already OUT)."""
        lab_dict = lab.label_dict
        self._set_in_and_defeat_successors(lab_dict, arg)
        for pred in self.af.arguments[arg]['predecessors']:
            if lab_dict[pred]['label'] == 'out':
                continue
            self._retract_counts(lab_dict, pred)
            lab_dict[pred]['label'] = 'must_out'

    def _apply_right_transition(self, lab, arg):
        """Label BLANK argument UNDEC, moving its weight from BLANK to UNDEC in its targets' counters."""
        lab_dict = lab.label_dict
        lab_dict[arg]['label'] = 'undec'
        for attacked in self.af.arguments[arg]['successors']:
            lab_dict[attacked]['BLANK'] -= 1
            lab_dict[attacked]['UNDEC'] += 1

    def _propagate_labeling(self, lab):
        """
        Propagate labelings by finding blank arguments that must be IN
        (i.e., all attackers are OUT or MUST_OUT), repeating until a pass
        changes nothing.  Mutates and returns ``lab``.
        """
        lab_dict = lab.label_dict
        af_args = self.af.arguments

        changed = True
        while changed:
            changed = False
            blank_args = [arg for arg, data in lab_dict.items() if data['label'] == 'blank']
            for arg in blank_args:
                if all(lab_dict[att]['label'] in self._DEFEATED_LABELS
                       for att in af_args[arg]['predecessors']):
                    self._set_in_and_defeat_successors(lab_dict, arg)
                    changed = True

        return lab
