#include "Encodings.h"

#include <iostream>
#include <algorithm>

using namespace std;

namespace Encodings {

/*
 * Defines body_is_true_in_var[b] for every body b as the conjunction of its
 * members under the "in" labelling. An assumption member contributes
 * in_var[atom]; any other member contributes derived_from_in_var[atom].
 *
 *   body_is_true_in_var[b] -> member        (one binary clause per member)
 *   all members             -> body_is_true_in_var[b]
 *
 * An empty body therefore produces the unit clause { body_is_true_in_var[b] }.
 */
void add_body_is_true_in_clauses(ABF & af, SAT_Solver * solver)
{
	for (uint32_t b = 0; b < af.bodies.size(); b++) {
		const int body_true = af.body_is_true_in_var[b];
		// Collected while the binary clauses are emitted; added last so the
		// clause order matches "all implications first, then the reverse".
		vector<int> members_imply_body = { body_true };
		for (uint32_t atom : af.bodies[b]) {
			const int member = af.is_assumption[atom] ? af.in_var[atom] : af.derived_from_in_var[atom];
			vector<int> body_implies_member = { -body_true, member };
			add_clause(solver, body_implies_member);
			members_imply_body.push_back(-member);
		}
		add_clause(solver, members_imply_body);
	}
}

/*
 * Defines body_is_true_undec_var[b] for every body b as the conjunction of its
 * members under the "undec" labelling. An assumption member contributes the
 * literal -out_var[atom] (it is not out); any other member contributes
 * derived_from_undec_var[atom].
 *
 *   body_is_true_undec_var[b] -> member     (one binary clause per member)
 *   all members                -> body_is_true_undec_var[b]
 *   body_is_true_in_var[b]     -> body_is_true_undec_var[b]
 */
void add_body_is_true_undec_clauses(ABF & af, SAT_Solver * solver)
{
	for (uint32_t b = 0; b < af.bodies.size(); b++) {
		const int body_undec = af.body_is_true_undec_var[b];
		vector<int> members_imply_body = { body_undec };
		for (uint32_t atom : af.bodies[b]) {
			const int member = af.is_assumption[atom] ? -af.out_var[atom] : af.derived_from_undec_var[atom];
			vector<int> body_implies_member = { -body_undec, member };
			add_clause(solver, body_implies_member);
			members_imply_body.push_back(-member);
		}
		add_clause(solver, members_imply_body);

		vector<int> in_implies_undec = { -af.body_is_true_in_var[b], body_undec };
		add_clause(solver, in_implies_undec);
	}
}

/*
 * Adds the clauses that define rule_is_justified_in_var and constrain the
 * derivation-edge variables derived_from_in_edge_var under the "in" labelling.
 *
 * A rule is "internal" when scc_index of its head equals scc_index of its body
 * node (stored at position af.symbols + body index), and "external" otherwise.
 *  - every derivation edge variable gets a phase hint with its negative literal;
 *  - internal rule i: rule_is_justified_in_var[i] <-> body_is_true_in_var[body]
 *    AND edge(e, head) for every body element e in the head's SCC;
 *  - edge(s, t) -> in_var[t] if t is an assumption, else derived_from_in_var[t];
 *  - external rule whose body is true: every derivation edge into its head is off;
 *  - internal rules i != j sharing a head: if rule j is justified, edges into
 *    the head from in-SCC elements of body(i) \ body(j) are off.
 *
 * NOTE(review): edge variables are fetched with operator[] on
 * af.derived_from_in_edge_var keyed by make_pair(body_elem, head); if that is a
 * map-like container a missing pair would silently yield a default literal —
 * confirm every such pair is present in af.derivation_edges.
 *
 * @param af      framework providing rules, bodies, SCC indices and variables
 * @param solver  receives the phase hints and clauses
 */
void add_rule_is_justified_in_clauses(ABF & af, SAT_Solver * solver)
{
	/*vector<int> edge_vars;
	for (const auto & [source, target] : af.derivation_edges) {
		edge_vars.push_back(af.derived_from_in_edge_var[make_pair(source, target)]);
		//solver->phase(-af.derived_from_in_edge_var[make_pair(source, target)]);
	}

	if (af.sem == ST) {
		int index = solver->add_graph(edge_vars, af.derivation_edges);
		solver->set_acyclic(index);
	}*/
	for (const auto & [source, target] : af.derivation_edges) {
		solver->phase(-af.derived_from_in_edge_var[make_pair(source, target)]);
	}

	// Internal rules: body true AND all in-SCC edges set -> rule justified.
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		uint32_t body_index = af.rules[i].second;
		if (af.scc_index[head] != af.scc_index[af.symbols + body_index]) continue;
		vector<int> clause = { af.rule_is_justified_in_var[i], -af.body_is_true_in_var[body_index] };
		for (uint32_t body_elem : af.bodies[body_index]) {
			if (af.scc_index[head] == af.scc_index[body_elem]) {
				clause.push_back(-af.derived_from_in_edge_var[make_pair(body_elem, head)]);
			}
		}
		add_clause(solver, clause);
	}

	// Internal rules: a justified rule needs its body true and each in-SCC edge set.
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		uint32_t body_index = af.rules[i].second;
		if (af.scc_index[head] != af.scc_index[af.symbols + body_index]) continue;
		vector<int> clause = { -af.rule_is_justified_in_var[i], af.body_is_true_in_var[body_index] };
		add_clause(solver, clause);
		for (uint32_t body_elem : af.bodies[body_index]) {
			if (af.scc_index[head] == af.scc_index[body_elem]) {
				clause = { -af.rule_is_justified_in_var[i], af.derived_from_in_edge_var[make_pair(body_elem, head)] };
				add_clause(solver, clause);
			}
		}
	}

	// An active edge requires its target to hold under the "in" labelling.
	for (const auto & [source, target] : af.derivation_edges) {
		if (af.is_assumption[target]) {
			vector<int> clause = { af.in_var[target], -af.derived_from_in_edge_var[make_pair(source, target)] };
			add_clause(solver, clause);
		} else {
			vector<int> clause = { af.derived_from_in_var[target], -af.derived_from_in_edge_var[make_pair(source, target)] };
			add_clause(solver, clause);
		}
	}

	// External rules: when the body is true, no derivation edge into the head is used.
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		uint32_t body_index = af.rules[i].second;
		if (af.scc_index[head] == af.scc_index[af.symbols + body_index]) continue;
		for (const auto & [source, target] : af.derivation_edges) {
			if (head != target) continue;
			vector<int> clause = { -af.body_is_true_in_var[body_index], -af.derived_from_in_edge_var[make_pair(source, target)] };
			add_clause(solver, clause);
		}
	}

	// Pairs of internal rules with the same head: a justified rule j switches off
	// edges from in-SCC elements that only the other rule's body contains.
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		if (af.scc_index[head] != af.scc_index[af.symbols + af.rules[i].second]) continue;
		// Bound by const reference: set_difference only reads the bodies, so the
		// per-pair vector copies the previous version made were pure overhead.
		const auto & body1 = af.bodies[af.rules[i].second];
		for (uint32_t rule_index : af.rules_with_head[head]) {
			if (rule_index == i) continue;
			uint32_t body_index = af.rules[rule_index].second;
			if (af.scc_index[af.symbols + body_index] != af.scc_index[head])
				continue;
			const auto & body2 = af.bodies[body_index];
			// NOTE(review): std::set_difference requires both ranges sorted
			// ascending — confirm af.bodies is stored sorted.
			vector<uint32_t> diff;
			set_difference(body1.begin(), body1.end(), body2.begin(), body2.end(), back_inserter(diff));
			for (uint32_t body_elem : diff) {
				if (af.scc_index[head] == af.scc_index[body_elem]) {
					vector<int> clause = { -af.rule_is_justified_in_var[rule_index], -af.derived_from_in_edge_var[make_pair(body_elem, head)] };
					add_clause(solver, clause);
				}
			}
		}
	}
}

/*
 * Adds the clauses that define rule_is_justified_undec_var and constrain the
 * derivation-edge variables derived_from_undec_edge_var under the "undec"
 * labelling (internal/external rules as in the "in" encoding: internal when
 * scc_index[head] == scc_index[af.symbols + body index]).
 *  - every undec edge variable gets a phase hint with its negative literal;
 *  - derived_from_in_edge_var -> derived_from_undec_edge_var for every edge;
 *  - internal rule i: rule_is_justified_undec_var[i] <->
 *    body_is_true_undec_var[body] AND undec edge(e, head) for every in-SCC e;
 *  - undec edge(s, t) -> -out_var[t] if t is an assumption,
 *    else derived_from_undec_var[t];
 *  - external rule whose body is undec-true: an undec edge into its head must
 *    also be an "in" edge;
 *  - internal rules i != j sharing a head: if rule j is undec-justified, undec
 *    edges from in-SCC elements of body(i) \ body(j) must also be "in" edges.
 *
 * @param af      framework providing rules, bodies, SCC indices and variables
 * @param solver  receives the phase hints and clauses
 */
void add_rule_is_justified_undec_clauses(ABF & af, SAT_Solver * solver)
{
	/*vector<int> edge_vars;
	for (const auto & [source, target] : af.derivation_edges) {
		edge_vars.push_back(af.derived_from_undec_edge_var[make_pair(source, target)]);
		//solver->phase(-af.derived_from_undec_edge_var[make_pair(source, target)]);
		vector<int> clause = { -af.derived_from_in_edge_var[make_pair(source, target)], af.derived_from_undec_edge_var[make_pair(source, target)] };
		add_clause(solver, clause);
	}

	int index = solver->add_graph(edge_vars, af.derivation_edges);
	solver->set_acyclic(index);*/

	for (const auto & [source, target] : af.derivation_edges) {
		solver->phase(-af.derived_from_undec_edge_var[make_pair(source, target)]);
	}

	// Every "in" edge is also an "undec" edge.
	for (const auto & [source, target] : af.derivation_edges) {
		vector<int> clause = { -af.derived_from_in_edge_var[make_pair(source, target)], af.derived_from_undec_edge_var[make_pair(source, target)] };
		add_clause(solver, clause);
	}

	// Internal rules: body undec-true AND all in-SCC undec edges -> rule justified.
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		uint32_t body_index = af.rules[i].second;
		if (af.scc_index[head] != af.scc_index[af.symbols + body_index]) continue;
		vector<int> clause = { af.rule_is_justified_undec_var[i], -af.body_is_true_undec_var[body_index] };
		for (uint32_t body_elem : af.bodies[body_index]) {
			if (af.scc_index[head] == af.scc_index[body_elem]) {
				clause.push_back(-af.derived_from_undec_edge_var[make_pair(body_elem, head)]);
			}
		}
		add_clause(solver, clause);
	}

	// Internal rules: a justified rule needs its body undec-true and each in-SCC undec edge.
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		uint32_t body_index = af.rules[i].second;
		if (af.scc_index[head] != af.scc_index[af.symbols + body_index]) continue;
		vector<int> clause = { -af.rule_is_justified_undec_var[i], af.body_is_true_undec_var[body_index] };
		add_clause(solver, clause);
		for (uint32_t body_elem : af.bodies[body_index]) {
			if (af.scc_index[head] == af.scc_index[body_elem]) {
				clause = { -af.rule_is_justified_undec_var[i], af.derived_from_undec_edge_var[make_pair(body_elem, head)] };
				add_clause(solver, clause);
			}
		}
	}

	// An active undec edge requires its target to hold under the "undec" labelling.
	for (const auto & [source, target] : af.derivation_edges) {
		if (af.is_assumption[target]) {
			vector<int> clause = { -af.out_var[target], -af.derived_from_undec_edge_var[make_pair(source, target)] };
			add_clause(solver, clause);
		} else {
			vector<int> clause = { af.derived_from_undec_var[target], -af.derived_from_undec_edge_var[make_pair(source, target)] };
			add_clause(solver, clause);
		}
	}

	// External rules: when the body is undec-true, an undec edge into the head
	// is only allowed if it is also an "in" edge.
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		uint32_t body_index = af.rules[i].second;
		if (af.scc_index[head] == af.scc_index[af.symbols + body_index]) continue;
		for (const auto & [source, target] : af.derivation_edges) {
			if (head != target) continue;
			vector<int> clause = { -af.body_is_true_undec_var[body_index], af.derived_from_in_edge_var[make_pair(source, target)], -af.derived_from_undec_edge_var[make_pair(source, target)] };
			add_clause(solver, clause);
		}
	}

	// Pairs of internal rules with the same head (see header comment).
	for (uint32_t i = 0; i < af.rules.size(); i++) {
		uint32_t head = af.rules[i].first;
		if (af.scc_index[head] != af.scc_index[af.symbols + af.rules[i].second]) continue;
		// Bound by const reference instead of copying both bodies per rule pair.
		const auto & body1 = af.bodies[af.rules[i].second];
		for (uint32_t rule_index : af.rules_with_head[head]) {
			if (rule_index == i) continue;
			uint32_t body_index = af.rules[rule_index].second;
			if (af.scc_index[af.symbols + body_index] != af.scc_index[head])
				continue;
			const auto & body2 = af.bodies[body_index];
			// NOTE(review): std::set_difference requires both ranges sorted
			// ascending — confirm af.bodies is stored sorted.
			vector<uint32_t> diff;
			set_difference(body1.begin(), body1.end(), body2.begin(), body2.end(), back_inserter(diff));
			for (uint32_t body_elem : diff) {
				if (af.scc_index[head] == af.scc_index[body_elem]) {
					vector<int> clause = { -af.rule_is_justified_undec_var[rule_index], af.derived_from_in_edge_var[make_pair(body_elem, head)], -af.derived_from_undec_edge_var[make_pair(body_elem, head)] };
					add_clause(solver, clause);
				}
			}
		}
	}
}

/*
 * Defines derived_from_in_var for every non-assumption symbol under the "in"
 * labelling, after emitting the rule-justification clauses:
 *  - any body with that head whose body_is_true_in_var holds forces the symbol
 *    derived;
 *  - a derived symbol needs at least one supporting rule: body_is_true_in_var
 *    for a rule whose head and body node lie in different SCCs, otherwise
 *    rule_is_justified_in_var.
 */
void add_derived_from_in_clauses(ABF & af, SAT_Solver * solver)
{
	add_rule_is_justified_in_clauses(af, solver);
	for (uint32_t sym = 0; sym < af.symbols; sym++) {
		if (af.is_assumption[sym]) continue;

		const int derived = af.derived_from_in_var[sym];
		for (uint32_t body_index : af.bodies_with_head[sym]) {
			vector<int> body_forces_head = { derived, -af.body_is_true_in_var[body_index] };
			add_clause(solver, body_forces_head);
		}

		vector<int> needs_support = { -derived };
		for (uint32_t rule_index : af.rules_with_head[sym]) {
			const auto & rule = af.rules[rule_index];
			const bool internal = af.scc_index[rule.first] == af.scc_index[af.symbols + rule.second];
			needs_support.push_back(internal ? af.rule_is_justified_in_var[rule_index] : af.body_is_true_in_var[rule.second]);
		}
		add_clause(solver, needs_support);
	}
}

/*
 * Defines derived_from_undec_var for every non-assumption symbol under the
 * "undec" labelling, after emitting the undec rule-justification clauses:
 *  - any body with that head whose body_is_true_undec_var holds forces the
 *    symbol derived;
 *  - a derived symbol needs at least one supporting rule: body_is_true_undec_var
 *    for a rule whose head and body node lie in different SCCs, otherwise
 *    rule_is_justified_undec_var;
 *  - derived_from_in_var implies derived_from_undec_var.
 */
void add_derived_from_undec_clauses(ABF & af, SAT_Solver * solver)
{
	add_rule_is_justified_undec_clauses(af, solver);
	for (uint32_t sym = 0; sym < af.symbols; sym++) {
		if (af.is_assumption[sym]) continue;

		const int derived = af.derived_from_undec_var[sym];
		for (uint32_t body_index : af.bodies_with_head[sym]) {
			vector<int> body_forces_head = { derived, -af.body_is_true_undec_var[body_index] };
			add_clause(solver, body_forces_head);
		}

		vector<int> needs_support = { -derived };
		for (uint32_t rule_index : af.rules_with_head[sym]) {
			const auto & rule = af.rules[rule_index];
			const bool internal = af.scc_index[rule.first] == af.scc_index[af.symbols + rule.second];
			needs_support.push_back(internal ? af.rule_is_justified_undec_var[rule_index] : af.body_is_true_undec_var[rule.second]);
		}
		add_clause(solver, needs_support);

		vector<int> in_implies_undec = { -af.derived_from_in_var[sym], derived };
		add_clause(solver, in_implies_undec);
	}
}

/*
 * Defines out_var for each assumption: it is out exactly when its contrary
 * holds under the "in" labelling, which is in_var of the contrary when the
 * contrary is itself an assumption and derived_from_in_var otherwise.
 * The body and derivation clauses for the "in" labelling are emitted first.
 */
void add_out_clauses(ABF & af, SAT_Solver * solver)
{
	add_body_is_true_in_clauses(af, solver);
	add_derived_from_in_clauses(af, solver);
	for (uint32_t k = 0; k < af.assumptions.size(); ++k) {
		const auto a = af.assumptions[k];
		const auto c = af.contrary[a];
		const int contrary_holds = af.is_assumption[c] ? af.in_var[c] : af.derived_from_in_var[c];
		const int out = af.out_var[a];

		// out <-> contrary_holds
		vector<int> clause = { out, -contrary_holds };
		add_clause(solver, clause);
		clause = { -out, contrary_holds };
		add_clause(solver, clause);
	}
}

/*
 * Defines attacked_by_undec_var for each assumption as equivalent to a literal
 * about its contrary: -out_var when the contrary is an assumption, and
 * derived_from_undec_var otherwise. The body and derivation clauses for the
 * "undec" labelling are emitted first.
 */
void add_attacked_by_undec_clauses(ABF & af, SAT_Solver * solver)
{
	add_body_is_true_undec_clauses(af, solver);
	add_derived_from_undec_clauses(af, solver);
	for (uint32_t k = 0; k < af.assumptions.size(); ++k) {
		const auto a = af.assumptions[k];
		const auto c = af.contrary[a];
		const int contrary_lit = af.is_assumption[c] ? -af.out_var[c] : af.derived_from_undec_var[c];
		const int attacked = af.attacked_by_undec_var[a];

		// attacked <-> contrary_lit
		vector<int> clause = { attacked, -contrary_lit };
		add_clause(solver, clause);
		clause = { -attacked, contrary_lit };
		add_clause(solver, clause);
	}
}

/*
 * Conflict-freeness: after defining out_var, forbid any assumption from being
 * both in and out.
 */
void add_conflict_free(ABF & af, SAT_Solver * solver)
{
	add_out_clauses(af, solver);
	for (uint32_t k = 0; k < af.assumptions.size(); ++k) {
		const auto a = af.assumptions[k];
		vector<int> not_both{ -af.in_var[a], -af.out_var[a] };
		add_clause(solver, not_both);
	}
}

/*
 * Admissibility: conflict-freeness plus the attacked_by_undec encoding, and an
 * assumption that is in must not be attacked_by_undec.
 */
void add_admissible(ABF & af, SAT_Solver * solver)
{
	add_conflict_free(af, solver);
	add_attacked_by_undec_clauses(af, solver);
	for (uint32_t k = 0; k < af.assumptions.size(); ++k) {
		const auto a = af.assumptions[k];
		vector<int> in_not_attacked{ -af.in_var[a], -af.attacked_by_undec_var[a] };
		add_clause(solver, in_not_attacked);
	}
}

/*
 * Completeness: admissibility, plus every assumption that is not
 * attacked_by_undec must be in.
 */
void add_complete(ABF & af, SAT_Solver * solver)
{
	add_admissible(af, solver);
	for (uint32_t k = 0; k < af.assumptions.size(); ++k) {
		const auto a = af.assumptions[k];
		vector<int> unattacked_is_in{ af.attacked_by_undec_var[a], af.in_var[a] };
		add_clause(solver, unattacked_is_in);
	}
}

/*
 * Stability: conflict-freeness, plus every assumption is either in or out.
 */
void add_stable(ABF & af, SAT_Solver * solver)
{
	add_conflict_free(af, solver);
	for (uint32_t k = 0; k < af.assumptions.size(); ++k) {
		const auto a = af.assumptions[k];
		vector<int> in_or_out{ af.in_var[a], af.out_var[a] };
		add_clause(solver, in_or_out);
	}
}

/*
 * Defines range_var for each assumption as range <-> (in OR out).
 * Only emits these definitional clauses; it does not add any labelling
 * constraints itself.
 */
void add_range(ABF & af, SAT_Solver * solver)
{
	for (uint32_t k = 0; k < af.assumptions.size(); ++k) {
		const auto a = af.assumptions[k];
		const int range = af.range_var[a];
		const int in = af.in_var[a];
		const int out = af.out_var[a];

		vector<int> clause = { -range, in, out };   // range -> in | out
		add_clause(solver, clause);
		clause = { range, -in };                    // in -> range
		add_clause(solver, clause);
		clause = { range, -out };                   // out -> range
		add_clause(solver, clause);
	}
}

}
