# ICCMA'23 dynamic-track benchmark driver: builds an initial sub-framework of an
# AF around a query argument, loads it into a solver that is then modified in
# place (arguments added/deleted), and re-queries it over several iterations.
import sys
from crustabri import Crustabri as AFSolver

import random
# Constant seed so every random choice made by this script (fixed part, added
# part, query sample, per-iteration changes) is reproducible for the same inputs.
random.seed(811543731122527)

from collections import deque
import time
import argparse

# Reference point for the "[   x.xxs]" timestamps printed by the script.
start_time = time.monotonic()
def get_time():
	"""Return the seconds elapsed since the script started, rounded to 2 decimals.

	A monotonic clock is used so the reported timings cannot jump backwards or
	forwards when the system wall clock is adjusted during a run.
	"""
	return round(time.monotonic() - start_time, 2)

parser = argparse.ArgumentParser(description="ICCMA'23 Dynamic Track")
parser.add_argument("task", type=str, help="Acceptance problem. Choices: DC-CO, DC-ST, DS-ST, DS-PR.")
parser.add_argument("file", type=str, help="Filename of AF in ICCMA'23 format.")
parser.add_argument("query", type=str, help="Filename of query argument.")
parser.add_argument("--fixed-part", type=float, default=0.333, help="Proportion of fixed arguments.")
parser.add_argument("--added-part", type=float, default=0.333, help="Proportion of added arguments.")
parser.add_argument("--iterations", type=int, default=64, help="Number of iterations.")
parser.add_argument("--changes", type=int, default=32, help="Number of changes.")
parser.add_argument("--queries", type=int, default=16, help="Number of query arguments.")
args = parser.parse_args()

task = args.task
af_file = args.file
query_file = args.query

# Validate through parser.error (usage message + exit status 2) instead of
# assert: asserts are skipped under `python -O`, and a task such as "DC-CO-X"
# would otherwise crash with a traceback while unpacking split("-").
problem, _, semantics = task.partition("-")
if problem not in ["DC", "DS"] or semantics not in ["CO", "ST", "PR"]:
	parser.error("invalid task %r: expected <DC|DS>-<CO|ST|PR>" % task)
# Out-of-range values would later make random.sample raise or divide by zero.
if not 0 <= args.fixed_part <= 1:
	parser.error("--fixed-part must be between 0 and 1")
if args.added_part < 0:
	parser.error("--added-part must be non-negative")
if args.queries < 1:
	parser.error("--queries must be at least 1")

FIXED_PART = args.fixed_part
ADDED_PART = args.added_part
ITERATIONS = args.iterations
CHANGES = args.changes
QUERIES = args.queries

# Parse the AF in ICCMA'23 format: a header "p af <n>" followed by one
# "<attacker> <target>" line per attack; lines starting with "#" are comments.
with open(af_file) as fh:
	# Strip before filtering so whitespace-only lines and indented comments are
	# skipped instead of being parsed as (invalid) attack lines.
	af_file_contents = [line.strip() for line in fh]
af_file_contents = [line for line in af_file_contents if line and not line.startswith("#")]
if not af_file_contents:
	sys.exit("Empty AF file: %s" % af_file)
p_line = af_file_contents[0]
header = p_line.split()
if len(header) != 3 or header[0] != "p" or header[1] != "af":
	sys.exit("Invalid AF header line: %r" % p_line)
attack_lines = af_file_contents[1:]
# Arguments are the integers 1..n.
arguments = range(1, int(header[2])+1)
n = len(arguments)
attacks = [tuple(map(int, line.split())) for line in attack_lines]
attackers = { a : [] for a in arguments }      # b -> arguments attacking b
attack_range = { a : [] for a in arguments }   # a -> arguments attacked by a
adjacency_list = { a : set() for a in arguments }  # undirected neighbourhood
for a,b in attacks:
	attackers[b].append(a)
	attack_range[a].append(b)
	adjacency_list[a].add(b)
	adjacency_list[b].add(a)
with open(query_file) as fh:
	query = int(fh.read().strip())
if query not in adjacency_list:
	sys.exit("Query argument %d is not an argument of the AF" % query)

print("[%8.2fs]" % get_time(), "Original AF parsed")

# Breadth-first search from the query over the undirected attack graph.
# visited[a] is the BFS layer of a: the query is 1, its neighbours 2, and so on;
# 0 means a is not reachable from the query.
visited = [0]*(n+1)
visited[query] = 1
queue = deque([query])
while queue:
	a = queue.popleft()
	for b in adjacency_list[a]:
		if visited[b] == 0:
			# A node's layer is its parent's layer plus one (not a running
			# count of dequeued nodes).
			visited[b] = visited[a]+1
			queue.append(b)

print("[%8.2fs]" % get_time(), "BFS completed")

# Group arguments by their visited[] value (0 = unreachable from the query).
args_at_depth = { d : [] for d in range(0, max(visited)+1) }
for a in arguments:
	args_at_depth[visited[a]].append(a)

# Fixed part: the query plus whole visited[] groups starting at 1, so
# unreachable arguments are never fixed. The group that reaches
# round(FIXED_PART*n) contributes only a random sample. The count is tracked
# incrementally, and already-fixed arguments (the query) are excluded from
# candidates, so the query is not counted twice or re-sampled.
arg_fixed = [False]*(n+1)
arg_fixed[query] = True
fixed_count = 1
fixed_target = round(FIXED_PART*n)

max_depth = 0
for d in range(1, max(visited)+1):
	candidates = [a for a in args_at_depth[d] if not arg_fixed[a]]
	if fixed_count + len(candidates) >= fixed_target:
		# max(0, ...) guards tiny targets (e.g. 0) that the query alone exceeds.
		sample = random.sample(candidates, max(0, fixed_target - fixed_count))
		for a in sample:
			arg_fixed[a] = True
		max_depth = d
		break
	for a in candidates:
		arg_fixed[a] = True
	fixed_count += len(candidates)

print("[%8.2fs]" % get_time(), "Fixed part determined")

# Initial framework: the fixed part plus further visited[] groups, starting at
# the group where the fixed part stopped (max_depth), until
# round((FIXED_PART+ADDED_PART)*n) arguments exist. The group that reaches the
# target contributes only a random sample.
arg_exists = arg_fixed.copy()
n_existing = sum(arg_exists)
exist_target = round((FIXED_PART+ADDED_PART)*n)

for d in range(max_depth, max(visited)+1):
	candidate_args = [a for a in args_at_depth[d] if not arg_exists[a]]
	if n_existing + len(candidate_args) >= exist_target:
		# max(0, ...) guards targets already exceeded by the fixed part.
		sample = random.sample(candidate_args, max(0, exist_target - n_existing))
		for a in sample:
			arg_exists[a] = True
		max_depth = d
		# Target reached: any further group would only get an empty sample.
		break
	for a in candidate_args:
		arg_exists[a] = True
	n_existing += len(candidate_args)

print("[%8.2fs]" % get_time(), "Existing arguments determined")

# Load the initial framework into the solver: every existing argument first,
# then every attack whose attacker and target both exist.
solver = AFSolver(semantics)
for arg in filter(lambda x: arg_exists[x], arguments):
	solver.add_argument(arg)
for attacker, target in attacks:
	if not (arg_exists[attacker] and arg_exists[target]):
		continue
	solver.add_attack(attacker, target)

print("[%8.2fs]" % get_time(), "Solver initialized")

# Queries come from the fixed part, which is never deleted. The main query comes
# first; the rest are distinct fixed arguments. The main query is excluded from
# the pool so it cannot appear twice, and the sample is capped by the pool size.
query_pool = [a for a in arguments if arg_fixed[a] and a != query]
queries = [query] + random.sample(query_pool, max(0, min(QUERIES-1, len(query_pool))))

# Dynamic loop: each iteration solves every query on the current framework,
# prints the statuses on a "v" line, then applies up to CHANGES random
# additions/deletions of non-fixed arguments to the solver.
fixed_count = sum(arg_fixed)
# Arguments that can appear or disappear. When this is 0, every argument is
# fixed: no change is possible, and the ratio p below would divide by zero.
mutable_count = len(arguments) - fixed_count
changes_per_iteration = CHANGES if mutable_count > 0 else 0
# Kept in sync with arg_exists on every addition/deletion.
exist_count = sum(arg_exists)

for i in range(1, ITERATIONS+1):
	print("[%8.2fs]" % get_time(), "ITERATION %d" % i + ",", exist_count, "arguments")
	status = []
	for q in queries:
		if problem == "DC":
			status.append(solver.solve_cred([q]))
		elif problem == "DS":
			status.append(solver.solve_skept([q]))
	print("[%8.2fs]" % get_time(), "Acceptance status determined")
	print("v", "".join(map(str, map(int, status))))
	additions = 0
	deletions = 0
	for _ in range(changes_per_iteration):
		# p = fraction of non-fixed arguments currently present. Add with
		# probability 1-p and delete with probability p, so the framework
		# drifts towards neither empty nor full.
		p = (exist_count - fixed_count)/mutable_count
		addition = random.random() >= p
		if addition:
			candidate_args = [a for a in arguments if not arg_exists[a]]
			if len(candidate_args) == 0:
				continue
			to_add = random.choice(candidate_args)
			arg_exists[to_add] = True
			exist_count += 1
			solver.add_argument(to_add)
			for a in attackers[to_add]:
				if arg_exists[a]:
					solver.add_attack(a, to_add)
			for a in attack_range[to_add]:
				# A self-attack was already added by the loop above.
				if arg_exists[a] and a != to_add:
					solver.add_attack(to_add, a)
			additions += 1
		else:
			candidate_args = [a for a in arguments if arg_exists[a] and not arg_fixed[a]]
			if len(candidate_args) == 0:
				continue
			to_delete = random.choice(candidate_args)
			arg_exists[to_delete] = False
			exist_count -= 1
			solver.del_argument(to_delete)
			deletions += 1
	print("[%8.2fs]" % get_time(), additions, "additions,", deletions, "deletions applied")
