#!/usr/bin/env python3

"""
Copyright <2023-2025> <Tuomo Lehtonen, Aalto University, University of Helsinki>

Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to the following
conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT
OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
"""

import clingo
import sys, os
import argparse
import subprocess
import tempfile
import pathlib

class ASPforABA:
    """Command-line solver that reasons over a framework of assumptions,
    contraries and rules by translating it to ASP and calling clingo.

    Two back ends are used:

    * ``PR`` problems (``DS-PR``, ``SE-PR``) are solved in-process through the
      clingo Python API with an iterative refinement loop (``skept_pref`` /
      ``se_pref``).
    * ``CO`` and ``ST`` problems are solved by writing the instance to a
      temporary file and running an external clingo binary together with an
      encoding file from the ``encodings`` directory next to this script.

    Atom names from the input are prefixed with ``a`` when they are turned
    into ASP constants (input ``5`` becomes ``a5``).
    """

    # Problems accepted by the command line, in the order they are advertised.
    PROBLEMS = ("DC-CO", "DC-ST", "DS-PR", "DS-ST", "SE-PR", "SE-ST")

    # Encoding file used by ``credulous`` for each supported semantics.
    _CREDULOUS_ENCODINGS = {"ST": "stb-aba-cred.dl", "CO": "com-aba-cred.dl"}

    COMMON_ENCODING = """
        in(X) :- not out(X), assumption(X).
        out(X) :- not in(X), assumption(X).
        supported(X) :- assumption(X), in(X).
        supported(X) :- head(R,X), triggered_by_in(R).
        triggered_by_in(R) :- head(R,_), supported(X) : body(R,X).
        :- in(X), contrary(X,Y), supported(Y).
        defeated(X) :- supported(Y), contrary(X,Y).
        derived_from_undefeated(X) :- assumption(X), not defeated(X).
        derived_from_undefeated(X) :- head(R,X), triggered_by_undefeated(R).
        triggered_by_undefeated(R) :- head(R,_), derived_from_undefeated(X) : body(R,X).
        attacked_by_undefeated(X) :- contrary(X,Y), derived_from_undefeated(Y).
        :- in(X), attacked_by_undefeated(X).
        """

    COMPLETE_ENCODING = ":- out(X), not attacked_by_undefeated(X)."

    # Not referenced by any method of this class; kept for external users.
    STABLE_ENCODING = """
        defeated(X) :- supported(Y), contrary(X,Y).
        :- not defeated(X), out(X).
        """

    def __init__(self):
        """Create an empty solver; the instance is filled by ``_parse_input``."""
        self.new_rules = []
        self.ctl = None                  # clingo.Control, created in run() for PR
        self.assumptions = []            # assumption names as strings
        self.contraries = []             # (assumption, contrary) pairs
        self.rules = []                  # (head, [body atoms]) pairs
        self.solving_assumptions = []    # (symbol, truth value) pairs for solve()
        self.last_model = None           # shown symbols of the latest model
        self.query = None
        self.refinement_asmpts = []      # out(...) symbols of the latest model
        self.current_dir = pathlib.Path(__file__).parent.resolve()
        self.temp_dir = ""               # "" means: use the system temp directory

    # ------------------------------------------------------------------
    # clingo model callbacks
    # ------------------------------------------------------------------

    def _maximize(self, model):
        """``on_model`` callback used by the refinement loops.

        Splits the assumptions by their status in ``model``: the ``out(..)``
        symbols are stored in ``refinement_asmpts`` (replacing the previous
        content), and for every other assumption ``(in(..), True)`` is appended
        to ``solving_assumptions`` so that later solve calls keep it fixed.
        The shown symbols of the model are stored in ``last_model``.
        """
        self.refinement_asmpts = []

        for a in self.assumptions:
            constant = clingo.Function(f"a{a}")
            out_atom = clingo.Function("out", [constant])
            if model.contains(out_atom):
                self.refinement_asmpts.append(out_atom)
            else:
                in_atom = clingo.Function("in", [constant])
                self.solving_assumptions.append((in_atom, True))

        # Nothing in this class reads new_rules; the (empty) entry per model is
        # retained so the attribute evolves exactly as it did before.
        self.new_rules.append([])
        self.last_model = model.symbols(shown=True)

    def _record_model(self, model):
        """``on_model`` callback that only stores the shown symbols."""
        self.last_model = model.symbols(shown=True)

    # ------------------------------------------------------------------
    # External clingo binary (CO / ST problems)
    # ------------------------------------------------------------------

    def _run_clingo(self, clingo_loc, encoding, query=None, extra_args=()):
        """Write the instance to a temporary file, run clingo, return stdout.

        ``clingo_loc`` is the clingo executable, ``encoding`` the path of the
        encoding file passed after the instance file, ``query`` is forwarded
        to ``print_ASP`` and ``extra_args`` are inserted right after the
        executable. The temporary file is always removed afterwards.
        """
        if self.temp_dir:
            # The default ".tmp" directory is not guaranteed to exist yet.
            os.makedirs(self.temp_dir, exist_ok=True)

        # An empty temp_dir falls back to the system default instead of the cwd.
        tmp = tempfile.NamedTemporaryFile(
            mode="w", dir=self.temp_dir or None, delete=False
        )
        try:
            # Close (and flush) the file before clingo reads it, even when
            # writing the instance fails.
            with tmp:
                self.print_ASP(tmp, query)
            completed = subprocess.run(
                [clingo_loc, *extra_args, tmp.name, encoding],
                capture_output=True,
                text=True,
            )
        finally:
            os.unlink(tmp.name)
        return completed.stdout

    @staticmethod
    def _is_satisfiable(stdout):
        """Classify clingo's textual output.

        Returns ``False`` if it contains ``UNSATISFIABLE`` and ``True`` if it
        contains ``SATISFIABLE``; otherwise the program exits with
        ``INTERRUPTED!``. The order matters because the first word contains
        the second.
        """
        if "UNSATISFIABLE" in stdout:
            return False
        if "SATISFIABLE" in stdout:
            return True
        sys.exit("INTERRUPTED!")

    @staticmethod
    def _in_assumptions(stdout):
        """Return the names inside every ``in(a<name>)`` token of ``stdout``.

        Only complete ``in(a...)`` atoms are accepted, so unrelated tokens that
        merely start with ``in`` (for instance a relative file path echoed by
        clingo) are ignored rather than crashing the parser.
        """
        prefix = "in(a"
        return [
            token[len(prefix):-1]
            for token in stdout.split()
            if token.startswith(prefix) and token.endswith(")")
        ]

    @staticmethod
    def _format_witness(names):
        """Format assumption names as a ``w ...`` line, sorted numerically."""
        return f"w {' '.join(sorted(names, key=int))}"

    def find_one(self, clingo_loc):
        """Run the ``stb-aba-enum.dl`` encoding and report one answer.

        Returns ``"NO"`` when clingo reports unsatisfiability, otherwise a
        ``w ...`` line listing the assumptions that are ``in``. Exits with
        ``INTERRUPTED!`` if clingo reports neither result.
        """
        encoding = os.path.join(self.current_dir, "encodings", "stb-aba-enum.dl")
        stdout = self._run_clingo(clingo_loc, encoding)
        if not self._is_satisfiable(stdout):
            return "NO"
        return self._format_witness(self._in_assumptions(stdout))

    def skeptical(self, clingo_loc, query):
        """Decide skeptical acceptance of ``query`` with ``stb-aba-skept.dl``.

        Only stable skeptical supported! Returns ``True`` when clingo reports
        unsatisfiability and ``False`` when it reports a model (presumably the
        encoding searches for a counterexample -- confirm against the encoding
        file). Exits with ``INTERRUPTED!`` if clingo reports neither result.
        """
        encoding = os.path.join(self.current_dir, "encodings", "stb-aba-skept.dl")
        stdout = self._run_clingo(clingo_loc, encoding, query, extra_args=("-q",))
        return not self._is_satisfiable(stdout)

    def credulous(self, clingo_loc, semantics, query):
        """Decide credulous acceptance of ``query`` under ``semantics``.

        ``semantics`` must be ``"ST"`` or ``"CO"``. Returns ``True`` when
        clingo finds a model and ``False`` when it reports unsatisfiability.
        Exits with ``INTERRUPTED!`` if clingo reports neither result.

        Raises:
            ValueError: if ``semantics`` has no credulous encoding.
        """
        try:
            filename = self._CREDULOUS_ENCODINGS[semantics]
        except KeyError:
            # Fail before touching the file system or spawning clingo.
            raise ValueError(
                f"Unsupported semantics for credulous reasoning: {semantics!r}"
            ) from None

        encoding = os.path.join(self.current_dir, "encodings", filename)
        stdout = self._run_clingo(clingo_loc, encoding, query, extra_args=("-q",))
        return self._is_satisfiable(stdout)

    # ------------------------------------------------------------------
    # In-process clingo (PR problems)
    # ------------------------------------------------------------------

    def _complete_encoding(self):
        """Add the common and complete encodings plus show statements to ``ctl``."""
        self.ctl.add("base", [], self.COMMON_ENCODING)
        self.ctl.add("base", [], self.COMPLETE_ENCODING)
        self.ctl.add("base", [], "#show in/1. #show supported/1.")

    def _create_instance(self):
        """Add the parsed framework to ``ctl`` as ASP facts."""
        # 'a' denotes atom
        assumption_facts = {f"assumption(a{asm})." for asm in self.assumptions}
        contrary_facts = {f"contrary(a{a},a{c})." for a, c in self.contraries}
        head_facts = []
        body_facts = []
        # The rule index doubles as the rule identifier in head/2 and body/2.
        for index, (head, body) in enumerate(self.rules):
            head_facts.append(f"head({index},a{head}).")
            body_facts.extend(f"body({index},a{atom})." for atom in body)

        for facts in (assumption_facts, contrary_facts, head_facts, body_facts):
            self.ctl.add("base", [], " ".join(facts))

    def _forbid_current_out_set(self):
        """Add a constraint forbidding that all atoms in ``refinement_asmpts`` hold.

        Later models must therefore turn at least one currently ``out``
        assumption ``in``; with an empty list the constraint has an empty body
        and makes the program unsatisfiable.
        """
        with self.ctl.backend() as backend:
            body = [backend.add_atom(atom) for atom in self.refinement_asmpts]
            backend.add_rule(head=[], body=body)

    def se_pref(self):
        """Compute one maximal extension and return it as a ``w ...`` line.

        Starting from any model, the current ``in`` set is fixed through
        solving assumptions and strictly larger ``in`` sets are requested until
        clingo reports unsatisfiability; the last model found is the answer.
        Exits with an error message if the very first solve call fails.
        """
        self.ctl.add("base", [], "#show in/1. #show out/1.")
        self.ctl.ground([("base", [])], context=self)

        self.solving_assumptions = []
        if not self.ctl.solve(on_model=self._maximize).satisfiable:
            # A preferred extension always exists
            sys.exit("Error in finding preferred extension with clingo")

        while True:
            self._forbid_current_out_set()
            improved = self.ctl.solve(
                assumptions=self.solving_assumptions, on_model=self._maximize
            ).satisfiable
            if not improved:
                names = [
                    sym.arguments[0].name[1:]
                    for sym in self.last_model
                    if sym.name == "in"
                ]
                return self._format_witness(names)

    def skept_pref(self, query):
        """Decide skeptical acceptance of ``query`` via iterative refinement.

        Repeatedly looks for a model in which ``supported(a<query>)`` is false,
        grows its ``in`` set as far as possible while keeping the query
        unsupported, and then checks whether that set can be extended to a
        model supporting the query. Returns ``False`` as soon as such an
        extension is impossible, and ``True`` once no further candidate model
        without the query exists.
        """
        self.ctl.add("base", [], "#show in/1. #show out/1.")
        self.ctl.ground([("base", [])], context=self)

        qatom = clingo.Function("supported", [clingo.Function(f"a{query}")])
        query_assumption = [(qatom, False)]
        while True:
            # Index 0 always holds the query literal; _maximize appends the
            # in(..) literals of each model behind it.
            self.solving_assumptions = list(query_assumption)
            if not self.ctl.solve(
                assumptions=query_assumption, on_model=self._maximize
            ).satisfiable:
                break

            while True:
                self._forbid_current_out_set()
                if not self.ctl.solve(
                    assumptions=self.solving_assumptions, on_model=self._maximize
                ).satisfiable:
                    break

            # Flip the query literal: can the maximal candidate be extended to
            # a model that supports the query?
            self.solving_assumptions[0] = (qatom, True)
            if not self.ctl.solve(
                assumptions=self.solving_assumptions, on_model=self._record_model
            ).satisfiable:
                return False

        return True

    # ------------------------------------------------------------------
    # Input / output
    # ------------------------------------------------------------------

    def _parse_input(self, input_file):
        """Read a framework file into ``assumptions``, ``rules``, ``contraries``.

        The first line must start with ``p``. Afterwards ``a X`` declares an
        assumption, ``r H B1 B2 ...`` a rule with head ``H`` and a possibly
        empty body, and ``c A C`` states that ``C`` is the contrary of ``A``.
        Other lines are ignored. The program exits with a message if the
        p-line is missing or one of these lines lacks its required fields.
        """
        with open(input_file, "r") as handle:
            lines = handle.read().split("\n")

        if not lines[0].startswith("p"):
            sys.exit("Invalid file format (missing p-line).")

        for lineno, line in enumerate(lines, start=1):
            tokens = line.split()
            if line.startswith("a "):
                self._require_fields(tokens, 2, lineno, line)
                self.assumptions.append(tokens[1])
            elif line.startswith("r "):
                self._require_fields(tokens, 2, lineno, line)
                self.rules.append((tokens[1], tokens[2:]))
            elif line.startswith("c "):
                self._require_fields(tokens, 3, lineno, line)
                self.contraries.append((tokens[1], tokens[2]))

    @staticmethod
    def _require_fields(tokens, minimum, lineno, line):
        """Exit with a format error unless ``tokens`` has ``minimum`` entries."""
        if len(tokens) < minimum:
            sys.exit(f"Invalid file format (malformed line {lineno}: {line!r}).")

    def print_ASP(self, out, query=None):
        """Write the framework as ASP facts, one per line, to ``out``.

        ``out`` is any object with a ``write`` method. A ``query(a<query>).``
        fact is appended when ``query`` is truthy.
        """
        for asm in self.assumptions:
            out.write(f"assumption(a{asm}).\n")
        for asm, ctr in self.contraries:
            out.write(f"contrary(a{asm},a{ctr}).\n")
        for index, (head, body) in enumerate(self.rules):
            out.write(f"head({index},a{head}).\n")
            for atom in body:
                out.write(f"body({index},a{atom}).\n")
        if query:
            out.write(f"query(a{query}).\n")

    @staticmethod
    def _read_config(config_file):
        """Parse ``KEY=VALUE`` lines of ``config_file`` into a dict.

        Keys and values are stripped of surrounding whitespace, the value may
        itself contain ``=``, lines without ``=`` are skipped and the first
        occurrence of a key wins.
        """
        settings = {}
        with open(config_file, "r") as handle:
            for line in handle:
                key, separator, value = line.partition("=")
                if separator:
                    settings.setdefault(key.strip(), value.strip())
        return settings

    # ------------------------------------------------------------------
    # Command-line entry point
    # ------------------------------------------------------------------

    def run(self):
        """Parse the command line, solve the requested problem, print the answer.

        Without any argument the version banner is printed. ``--problems``
        lists the supported problems. Otherwise ``-f`` and ``-p`` are required,
        and ``-a`` as well for the ``DC``/``DS`` tasks; missing options print a
        hint and exit with status 1. Decision problems print ``YES``/``NO``,
        ``SE`` problems print a ``w ...`` line or ``NO``.
        """
        problem_list = f"[{','.join(self.PROBLEMS)}]"

        parser = argparse.ArgumentParser()
        parser.add_argument("-f", "--file")
        parser.add_argument("-p", "--problem", choices=list(self.PROBLEMS))
        parser.add_argument("-a", "--query", type=str)
        parser.add_argument("--problems", action="store_true")
        args = parser.parse_args()

        if len(sys.argv) == 1:
            print("ASPforABA v1\nTuomo Lehtonen, tuomo.lehtonen@aalto.fi")
            sys.exit(0)

        # Defaults relative to this script; a ".config" file may override them.
        clingo_path = os.path.join(self.current_dir, "clingo/bin/clingo")
        self.temp_dir = os.path.join(self.current_dir, ".tmp")

        config_file = os.path.join(self.current_dir, ".config")
        if os.path.isfile(config_file):
            settings = self._read_config(config_file)
            # Missing or empty entries keep the defaults instead of crashing.
            clingo_path = settings.get("CLINGO_PATH") or clingo_path
            self.temp_dir = settings.get("TEMP_PATH") or self.temp_dir

        if args.problems:
            print(problem_list)
            sys.exit(0)

        if not args.file:
            print("Please provide an input file.")
            sys.exit(1)

        if not args.problem:
            print(f"Please specify a problem: {problem_list}")
            sys.exit(1)

        task, semantics = args.problem.split("-")

        if task in ("DC", "DS") and not args.query:
            print("Please provide a query.")
            sys.exit(1)

        self._parse_input(args.file)

        if semantics == "PR":
            self.ctl = clingo.Control(["--warn=none"])
            self._create_instance()
            self._complete_encoding()
            if task == "DS":
                print("YES" if self.skept_pref(args.query) else "NO")
            elif task == "SE":
                print(self.se_pref())
        elif task == "DC":
            accepted = self.credulous(clingo_path, semantics, args.query)
            print("YES" if accepted else "NO")
        elif task == "DS":
            accepted = self.skeptical(clingo_path, args.query)
            print("YES" if accepted else "NO")
        elif task == "SE":
            print(self.find_one(clingo_path))

# Script entry point: build a solver and let it handle the command line.
if "__main__" == __name__:
    solver = ASPforABA()
    solver.run()
