#!/usr/bin/env python3

"""
Copyright <2023-2025> <Tuomo Lehtonen, Aalto University, University of Helsinki>

Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT
OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
"""

import subprocess
import sys, argparse, os
import tempfile
import pathlib
from collections import deque

class AF:
    """
    Argumentation framework (AF) built from an ABA framework.

    Arguments are identified by consecutive non-negative integers; ids
    ``0 .. highest_arg_index - 1`` have been allocated.
    """

    def __init__(self, args = None, attackers = None):
        # ``None`` defaults instead of ``set()``/``dict()``: default containers
        # are evaluated once and would otherwise be shared (and mutated) by
        # every AF() instance.
        self.args = set() if args is None else args
        # argument id -> set of ids of the arguments attacking it
        self.attackers = dict() if attackers is None else attackers
        # claim (sentence) -> list of ids of arguments concluding it
        self.claim_to_args = dict()
        # number of allocated argument ids, i.e. the next free id
        self.highest_arg_index = len(self.args)
        # ids of the arguments passed to the solver as queries
        self.queried = []
        # argument id -> assumption, for arguments standing for an assumption
        self.arg_to_asmpt = dict()

    def create_from_ABAF(self, ABAF):
        """
        Allocate arguments for the rules and assumptions of ``ABAF``.

        Ids ``0 .. len(ABAF.heads) - 1`` are reserved for rules and
        ``claim_to_args`` starts as a (shallow) copy of ``ABAF.rules_deriving``;
        this presumes rule indices are exactly those integers (as produced by
        Solver.break_cycles) -- confirm for other callers. Every assumption
        then gets one fresh argument id, recorded in ``arg_to_asmpt``.
        """
        self.highest_arg_index = len(ABAF.heads)
        self.claim_to_args = ABAF.rules_deriving.copy()

        for asmpt in ABAF.assumptions:
            self.claim_to_args[asmpt] = [self.highest_arg_index]
            self.arg_to_asmpt[self.highest_arg_index] = asmpt
            self.highest_arg_index += 1

    def add_args_for_dummy(self, claim, not_derived=False):
        """
        Allocate a single fresh argument concluding the dummy claim ``claim``.

        ``not_derived`` is accepted for callers' readability but not used.
        """
        self.claim_to_args[claim] = [self.highest_arg_index]
        self.highest_arg_index += 1

class ABAF:
    """
    Assumption-based argumentation framework (ABAF).

    Attributes:
        assumptions: set of assumption names.
        contrary: assumption -> list holding its (single) contrary sentence.
        rules_deriving: sentence -> list of indices of the rules with that head.
        queries: list of query sentences; create_from_file adds them to
            ``sentences``.
        sentences: every sentence occurring in the framework.
        heads: rule index -> head sentence.
        bodies: rule index -> set of body sentences.
    """

    def __init__(self, a = None, c = None, rd = None,
            q = None, s = None, h = None, b = None):
        # ``None`` defaults instead of ``set()``/``dict()``/``list()``: default
        # containers are created once and were shared -- and mutated by
        # create_from_file -- across every ABAF() instance.
        self.assumptions = set() if a is None else a
        self.contrary = dict() if c is None else c
        self.rules_deriving = dict() if rd is None else rd
        self.queries = list() if q is None else q
        self.sentences = set() if s is None else s
        self.heads = dict() if h is None else h
        self.bodies = dict() if b is None else b

    # INPUT RESTRICTIONS:
    # - a single contrary per assumption
    def create_from_file(self, framework_filename):
        """
        Populate this framework from the text file ``framework_filename``.

        Recognised lines (whitespace-separated tokens; all other lines, such
        as a ``p`` header, are ignored):
            ``a X``        -- X is an assumption,
            ``c X Y``      -- Y is the contrary of assumption X,
            ``r H B1 ...`` -- rule with head H and body B1, ...
        Rules are numbered in file order with string indices "0", "1", ...
        Every sentence that SCC detection may visit gets an entry in
        ``rules_deriving`` (an empty list when no rule derives it).
        """
        with open(framework_filename, "r") as f:
            text = f.read().split("\n")
        for line in text:
            if line.startswith("a "):
                self.assumptions.add(str(line.split()[1]))
            if line.startswith("c "):
                components = line.split()
                # only one contrary: a later "c" line for X replaces an earlier one
                self.contrary[str(components[1])] = [components[2]]

        # only one contrary! Non-assumption contraries start without rules.
        self.rules_deriving = {ctr_list[0] : [] for ctr_list in self.contrary.values() if ctr_list[0] not in self.assumptions}
        # Assumptions have empty set, used for SCC detection
        for asmpt in self.assumptions:
            self.rules_deriving[asmpt] = list()
        # Query sentences are nodes of the dependency graph as well; without an
        # entry SCC detection would raise KeyError for an underivable query.
        for query in self.queries:
            self.rules_deriving.setdefault(query, [])

        self.sentences.update(self.assumptions)
        self.sentences.update(self.queries)
        for ctr_list in self.contrary.values():
            self.sentences.add(ctr_list[0])

        self.rule_indices = []
        self.heads = dict()
        self.bodies = dict()
        rule_index = 0
        for line in text:
            if line.startswith("r "):
                components = line.split()[1:]
                head, body = str(components[0]), components[1:]
                self.rule_indices.append(str(rule_index))
                self.heads[str(rule_index)] = head
                if head in self.rules_deriving:
                    self.rules_deriving[head].append(str(rule_index))
                else:
                    self.rules_deriving[head] = [str(rule_index)]

                self.bodies[str(rule_index)] = {str(b) for b in body}
                self.sentences.add(head)
                self.sentences.update(set(body))
                for b in body:
                    # body sentences no rule derives (yet) still need an entry
                    if b not in self.assumptions and b not in self.rules_deriving:
                        self.rules_deriving[b] = []

                rule_index += 1

class SCCDet:
    """
    Strongly connected components of the sentence dependency graph of an ABAF.

    The graph has an edge from sentence ``s`` to every body element of every
    rule deriving ``s`` (``fw.rules_deriving[s]`` -> ``fw.bodies[r]``).
    Components are computed with Tarjan's algorithm.
    """

    def __init__(self, fw):
        self.framework = fw
        self.SCCs = set()      # completed components, each a frozenset of sentences
        self.S = deque()       # Tarjan stack of visited, not yet assigned sentences
        self.on_stack = set()  # members of self.S, for O(1) membership tests
        self.i = 0             # next DFS discovery index
        self.index = dict()    # sentence -> discovery index
        self.low = dict()      # sentence -> low-link value

    def tarjan(self):
        """Return the set of all SCCs (frozensets) over ``framework.sentences``."""
        for s in self.framework.sentences:
            if s not in self.index:
                self.scc(s)

        return self.SCCs

    def scc(self, s):
        """
        Run Tarjan's depth-first search from ``s``.

        Every component completed during the search is added to ``self.SCCs``.
        The search is iterative (an explicit stack of
        ``(sentence, successor iterator)`` frames), so long dependency chains
        do not hit Python's recursion limit.
        """
        rules_deriving = self.framework.rules_deriving
        bodies = self.framework.bodies

        def successors(v):
            for r in rules_deriving[v]:
                yield from bodies[r]

        work = []

        def enter(v):
            # Discover v: assign its index and push it on both stacks.
            self.index[v] = self.i
            self.low[v] = self.i
            self.i += 1
            self.S.append(v)
            self.on_stack.add(v)
            work.append((v, successors(v)))

        enter(s)
        while work:
            v, children = work[-1]
            descended = False
            # Resume v's successor scan where it stopped last time.
            for child in children:
                if child not in self.index:
                    enter(child)
                    descended = True
                    break
                if child in self.on_stack:
                    self.low[v] = min(self.low[v], self.index[child])
            if descended:
                continue

            # All successors of v handled: finish v.
            work.pop()
            if self.low[v] == self.index[v]:
                component = set()
                while True:
                    w = self.S.pop()
                    self.on_stack.remove(w)
                    component.add(w)
                    if w == v:
                        break
                self.SCCs.add(frozenset(component))
            if work:
                # Equivalent of the post-recursion low-link update in the parent.
                parent = work[-1][0]
                self.low[parent] = min(self.low[parent], self.low[v])

class Solver:
    """
    Command-line front end.

    Reads an ABA framework, translates it into an argumentation framework
    (AF) and answers the requested problem by running the mu-toksia AF
    solver on that AF.
    """

    def __init__(self):
        # Directory containing this script; default solver and temp-file
        # locations are resolved relative to it in run().
        self.current_dir = pathlib.Path(__file__).parent.resolve()
        # Directory for the temporary AF files handed to mu-toksia.
        self.temp_dir = ""

    def print_ASP(self, asmpts, contraries, rules, out_filename, query=None):
        """
        Print the given framework in ASP format.

        Writes ``assumption/1``, ``contrary/2``, ``head/2``, ``body/2`` and
        optionally ``query/1`` facts to ``out_filename``. ``rules`` is a
        sequence of ``(head, body)`` pairs numbered by position.
        NOTE(review): ``contraries.get(ctr)`` is concatenated to a string, so
        this expects one contrary *string* per assumption, not the list form
        stored in ``ABAF.contrary``; it is not called elsewhere in this file.
        """
        with open(out_filename, 'w') as out:
            for asm in asmpts:
                out.write("assumption(" + asm + ").\n")
            for ctr in contraries:
                out.write("contrary(" + ctr + "," + contraries.get(ctr) + ").\n")
            for i, rule in enumerate(rules):
                out.write("head(" + str(i) + "," + rule[0] + ").\n")
                if rule[1]:
                    for body in rule[1]:
                        out.write("body(" + str(i) + "," + body + ").\n")
            if query:
                out.write("query(" + query + ").\n")

    def create_dummy_asmpts(self, framework, sentence):
        """
        Add the two dummy assumptions standing in for a non-assumption sentence.

        ``<sentence>_derived`` has contrary ``<sentence>_notderived``, which
        in turn has contrary ``sentence`` itself. Mutates
        ``framework.assumptions`` and ``framework.contrary``; returns
        ``(s_derived, s_notderived)``.
        """
        s_derived = sentence + "_derived"
        s_notderived = sentence + "_notderived"
        framework.assumptions.add(s_derived)
        framework.assumptions.add(s_notderived)

        framework.contrary[s_derived] = [s_notderived]
        framework.contrary[s_notderived] = [sentence]

        return s_derived, s_notderived

    def break_cycles(self, fw):
        """
        Return an acyclic copy of ``fw`` obtained by unrolling every SCC.

        For an SCC of size k, each rule deriving a member ``s`` is copied k
        times: at level i (0-based) the copy has head ``s_{i+1}`` and every
        body element ``b`` from the same SCC is renamed ``b_{i}``; at the last
        level the head keeps its original name ``s``. Rule indices of the
        result are consecutive integers from 0. Assumptions and sentences are
        copied; ``contrary`` and ``queries`` are shared with ``fw``.
        """
        assumptions = fw.assumptions.copy()
        contrary = fw.contrary
        queries = fw.queries
        sentences = fw.sentences.copy()
        rules_deriving = dict()
        heads = dict()
        bodies = dict()

        sccs = SCCDet(fw).tarjan()

        # Tarjan places every sentence in some (possibly singleton) SCC, so
        # this only happens for a framework without any sentences.
        if not sccs:
            return fw

        for asmp in fw.assumptions:
            rules_deriving[str(asmp)] = fw.rules_deriving[asmp]

        rule_index = 0
        for scc in sccs:
            size = len(scc)
            for level in range(size):
                for sentence in scc:
                    # s_k := s -- the last layer derives the original sentence
                    new_head = sentence if level == size - 1 else f'{sentence}_{level+1}'
                    for j in fw.rules_deriving[sentence]:
                        heads[rule_index] = new_head
                        sentences.add(new_head)
                        new_body = set()
                        for b_elem in fw.bodies[j]:
                            if b_elem in scc:
                                # in-SCC dependency: refer to the previous layer
                                b_elem = f"{b_elem}_{level}"
                                sentences.add(b_elem)
                            new_body.add(b_elem)
                        bodies[rule_index] = new_body

                        rules_deriving.setdefault(new_head, []).append(rule_index)
                        rule_index += 1

        return ABAF(assumptions, contrary, rules_deriving, queries, sentences, heads, bodies)

    def create_af(self, fw, fw2, query):
        """
        Build the AF for the acyclic framework ``fw``.

        Arguments: one per rule of ``fw`` (its rule index is the argument id),
        one per assumption, and one per dummy assumption created here. A
        non-assumption body element ``e`` is represented by the dummy
        assumption ``e_derived`` (see create_dummy_asmpts). Dummies and their
        contraries are recorded in ``fw2``, which is expected to hold ``fw``'s
        assumptions and contraries (run() passes copies). A rule argument is
        attacked by every argument concluding the contrary of one of its (real
        or dummy) body assumptions; an assumption argument by every argument
        concluding its contrary. ``query`` is unused.
        """
        af = AF()
        af.create_from_ABAF(fw)

        # non-assumption sentence -> its "_derived" dummy assumption
        elem_to_dummy = dict()
        for s in sorted(fw.rules_deriving):
            indices = fw.rules_deriving[s]
            for i in indices:
                for elem in fw.bodies[i]:
                    if elem in fw.assumptions:
                        vulnerable_asmpt = elem
                    else:
                        if elem not in elem_to_dummy:
                            # create dummy assumption when used in body of rule
                            s_derived, s_notderived = self.create_dummy_asmpts(fw2, elem)
                            af.add_args_for_dummy(s_derived)
                            af.add_args_for_dummy(s_notderived, True)
                            elem_to_dummy[elem] = s_derived

                            args_concluding_elem = set()
                            for contr in fw2.contrary[s_notderived]:
                                if contr in af.claim_to_args:
                                    args_concluding_elem.update(af.claim_to_args[contr])

                            af.attackers[af.claim_to_args[s_notderived][0]] = args_concluding_elem
                            # attacks from arg concluding s_notderived to arg concluding s_derived
                            af.attackers[af.claim_to_args[s_derived][0]] = set(af.claim_to_args[s_notderived])

                        vulnerable_asmpt = elem_to_dummy[elem]
                    if i not in af.attackers:
                        af.attackers[i] = set()

                    if vulnerable_asmpt not in fw2.contrary: # needed when no contrary for some asmpt
                        continue
                    for contr in fw2.contrary[vulnerable_asmpt]:
                        if contr in af.claim_to_args:
                            af.attackers[i].update(af.claim_to_args[contr])

        # create attacks to singleton assumptions (this also recomputes the
        # attackers of the dummy arguments from fw2.contrary)
        for asmpt in fw2.assumptions:
            for arg in af.claim_to_args[asmpt]:
                af.attackers[arg] = set()
                if asmpt not in fw2.contrary:    # needed when no contrary for some asmpt
                    continue
                for contr in fw2.contrary[asmpt]:
                    if contr in af.claim_to_args:
                        af.attackers[arg].update(af.claim_to_args[contr])

        return af

    def print_AF(self, af):
        """Print ``af`` to stdout as ``arg/1``, ``att/2`` and ``query/1`` facts."""
        for i in range(af.highest_arg_index):
            print("arg("+str(i)+").")
        for arg, attackers in af.attackers.items():
            for attacker in attackers:
                print("att("+str(attacker)+","+str(arg)+").")
        for q in af.queried:
            print("query("+str(q)+").")

    def write_AF(self, out, af):
        """Write ``af`` to the text stream ``out`` in the same fact format as print_AF."""
        for i in range(af.highest_arg_index):
            out.write("arg("+str(i)+").\n")
        for arg, attackers in af.attackers.items():
            for attacker in attackers:
                out.write("att("+str(attacker)+","+str(arg)+").\n")
        for q in af.queried:
            out.write("query("+str(q)+").\n")

    def _run_mutoksia(self, mutoksia_loc, af, problem):
        """
        Write ``af`` to a temporary file in ``self.temp_dir``, run mu-toksia on
        it for ``problem`` and return its stripped stdout.

        The temporary file is removed even if writing or the subprocess fails.
        """
        tmp = tempfile.NamedTemporaryFile(mode="w", dir=self.temp_dir, delete=False)
        try:
            with tmp:
                self.write_AF(tmp, af)
            output = subprocess.run([mutoksia_loc, "-fo", "apx", "-f", tmp.name, "-p", problem], capture_output=True, text=True)
        finally:
            os.unlink(tmp.name)
        return output.stdout.strip()

    def findone(self, mutoksia_loc, af, abaf, problem):
        """
        Ask mu-toksia for one extension of ``af`` (``problem`` is SE-PR/SE-ST).

        Returns "NO" when the solver reports none, otherwise ``"w "`` followed
        by the space-separated assumptions whose arguments are in the
        extension, sorted numerically (assumption names must be integers).
        Exits with "INTERRUPTED!" on any other solver output. ``abaf`` is unused.
        """
        stdout = self._run_mutoksia(mutoksia_loc, af, problem)
        if "NO" in stdout:
            return "NO"
        if not stdout.startswith("["):
            sys.exit("INTERRUPTED!")

        inner = stdout.split("[")[1].split("]")[0]
        # Skip empty tokens: an empty extension is printed as "[]".
        witness = {w.strip() for w in inner.split(",") if w.strip()}
        asmpts = [af.arg_to_asmpt[int(w)] for w in witness if int(w) in af.arg_to_asmpt]
        asmpts.sort(key=int)
        return f"w {' '.join(asmpts)}"

    def acceptance(self, mutoksia_loc, af, problem):
        """
        Run mu-toksia on ``af`` (with ``af.queried`` set) for a DC/DS problem.

        Returns the solver's stdout when it contains YES or NO; exits with
        "INTERRUPTED!" otherwise.
        """
        stdout = self._run_mutoksia(mutoksia_loc, af, problem)
        if "YES" in stdout or "NO" in stdout:
            return stdout
        sys.exit("INTERRUPTED!")

    @staticmethod
    def _config_value(lines, key, default):
        """Return the value of the first line starting with ``key`` (``KEY=value``), else ``default``."""
        for line in lines:
            if line.startswith(key):
                # split once so paths containing "=" survive intact
                return line.split("=", 1)[1].strip()
        return default

    def run(self):
        """
        Parse the command line, solve the requested problem and print the answer.

        DC-CO, DC-ST, DS-PR and DS-ST print YES/NO for the query given with
        ``-a``; SE-PR and SE-ST print ``w`` plus one extension's assumptions,
        or NO.
        """
        parser = argparse.ArgumentParser()

        parser.add_argument("-f", "--file")
        parser.add_argument("-p", "--problem", choices=["DC-CO", "DC-ST", "DS-PR", "DS-ST", "SE-PR", "SE-ST"])
        parser.add_argument("-a", "--query", type=str)
        parser.add_argument("--problems", action="store_true")

        args = parser.parse_args()

        # No arguments at all: print version/author information.
        if len(sys.argv) == 1:
            print("ASPforABA v1\nTuomo Lehtonen, tuomo.lehtonen@helsinki.fi")
            sys.exit(0)

        # Defaults relative to this script's directory.
        mutoksia_path = os.path.join(self.current_dir, "mu-toksia-multiquery/mu-toksia")
        self.temp_dir = os.path.join(self.current_dir, ".tmp")

        # If manual config: a .config file with KEY=value lines overrides the defaults.
        config_file = os.path.join(self.current_dir, ".config")
        if os.path.isfile(config_file):
            with open(config_file, "r") as config:
                solver_config = config.read().split("\n")
            mutoksia_path = self._config_value(solver_config, "MUTOKSIA_PATH", mutoksia_path)
            self.temp_dir = self._config_value(solver_config, "TEMP_PATH", self.temp_dir)

        if args.problems:
            print("[DC-CO,DC-ST,DS-PR,DS-ST,SE-PR,SE-ST]")
            sys.exit(0)

        if not args.file:
            sys.exit("Please provide an input file.")

        if not args.problem:
            sys.exit("Please specify a problem: [DC-CO,DC-ST,DS-PR,DS-ST,SE-PR,SE-ST]")

        framework_filename = args.file
        task, semantics = args.problem.split("-")

        # Acceptance problems need a query; without this check a missing -a
        # silently produced an answer for the query ``None``.
        if task in ("DC", "DS") and args.query is None:
            sys.exit("Please provide a query with -a/--query.")

        # tempfile does not create the directory for the AF files itself.
        if self.temp_dir:
            os.makedirs(self.temp_dir, exist_ok=True)

        fw = ABAF()
        fw.create_from_file(framework_filename)

        D = self.break_cycles(fw)

        # D2 collects D's assumptions/contraries plus the dummies create_af adds.
        D2 = ABAF(a=D.assumptions.copy(), c=D.contrary.copy())
        F = self.create_af(D, D2, args.query)

        if task == "SE":
            ans = self.findone(mutoksia_path, F, D2, args.problem)
            print(ans)
        else:
            if task == "DC":
                # no argument concludes the query: it cannot be accepted
                if args.query not in F.claim_to_args:
                    print("NO")
                    sys.exit(0)
            elif task == "DS":
                if args.query not in F.claim_to_args:
                    if semantics == "PR": # preferred extension exists
                        print("NO")
                        sys.exit(0)
                    elif semantics == "ST":
                        # skeptically accepted only vacuously, i.e. when no stable extension exists
                        ans = self.findone(mutoksia_path, F, D2, "SE-ST")
                        if ans == "NO":
                            print("YES")
                        elif ans.startswith("w"):
                            print("NO")
                        sys.exit(0)
            F.queried.extend(F.claim_to_args[args.query])
            ans = self.acceptance(mutoksia_path, F, args.problem)
            print(ans)

if __name__ == "__main__":
    Solver().run()
