#include "Encodings.h"

#include <iostream>
#include <algorithm>

using namespace std;

namespace Encodings {

// For every rule body b, adds clauses making body_is_true_in_var[b] equivalent
// to the conjunction of one literal per body atom: in_var[atom] when the atom
// is an assumption, derived_from_in_var[atom] otherwise.
void add_body_is_true_in_clauses(ABF & af, SAT_Solver * solver)
{
	for (uint32_t body = 0; body < af.bodies.size(); body++) {
		const int body_lit = af.body_is_true_in_var[body];
		// Long clause (all atom literals together imply the body literal);
		// filled while the binary clauses are emitted, added after them.
		vector<int> atoms_imply_body = { body_lit };
		for (uint32_t member : af.bodies[body]) {
			const int member_lit = af.is_assumption[member]
				? af.in_var[member]
				: af.derived_from_in_var[member];
			// body literal implies each atom literal
			vector<int> body_implies_member = { -body_lit, member_lit };
			add_clause(solver, body_implies_member);
			atoms_imply_body.push_back(-member_lit);
		}
		add_clause(solver, atoms_imply_body);
	}
}

// For every rule body b, adds clauses making body_is_true_undec_var[b]
// equivalent to the conjunction of one literal per body atom: the negation of
// out_var[atom] when the atom is an assumption, derived_from_undec_var[atom]
// otherwise. Also adds body_is_true_in_var[b] -> body_is_true_undec_var[b].
void add_body_is_true_undec_clauses(ABF & af, SAT_Solver * solver)
{
	for (uint32_t body = 0; body < af.bodies.size(); body++) {
		const int undec_lit = af.body_is_true_undec_var[body];
		// Long clause (all required literals together imply the body literal);
		// filled while the binary clauses are emitted, added after them.
		vector<int> requirements_imply_body = { undec_lit };
		for (uint32_t member : af.bodies[body]) {
			// Literal that has to hold for this atom when the body literal is true.
			const int required = af.is_assumption[member]
				? -af.out_var[member]
				: af.derived_from_undec_var[member];
			vector<int> body_implies_required = { -undec_lit, required };
			add_clause(solver, body_implies_required);
			requirements_imply_body.push_back(-required);
		}
		add_clause(solver, requirements_imply_body);

		// The "in" body literal implies the "undec" body literal.
		vector<int> in_implies_undec = { -af.body_is_true_in_var[body], undec_lit };
		add_clause(solver, in_implies_undec);
	}
}

// For every non-assumption atom, adds clauses making derived_from_in_var[atom]
// equivalent to the disjunction of body_is_true_in_var over the bodies listed
// in bodies_with_head[atom]. With no such body the result is the unit clause
// forcing derived_from_in_var[atom] to false.
void add_derived_from_in_clauses(ABF & af, SAT_Solver * solver)
{
	for (uint32_t atom = 0; atom < af.symbols; atom++) {
		if (af.is_assumption[atom]) {
			continue;
		}
		const int derived_lit = af.derived_from_in_var[atom];
		// Long clause (derived literal implies some body literal); added last.
		vector<int> derived_needs_body = { -derived_lit };
		for (uint32_t rule : af.bodies_with_head[atom]) {
			const int body_lit = af.body_is_true_in_var[rule];
			// each body literal implies the derived literal
			vector<int> body_implies_derived = { derived_lit, -body_lit };
			add_clause(solver, body_implies_derived);
			derived_needs_body.push_back(body_lit);
		}
		add_clause(solver, derived_needs_body);
	}
}

// For every non-assumption atom, adds clauses making
// derived_from_undec_var[atom] equivalent to the disjunction of
// body_is_true_undec_var over the bodies listed in bodies_with_head[atom].
// Also adds derived_from_in_var[atom] -> derived_from_undec_var[atom].
void add_derived_from_undec_clauses(ABF & af, SAT_Solver * solver)
{
	for (uint32_t atom = 0; atom < af.symbols; atom++) {
		if (af.is_assumption[atom]) {
			continue;
		}
		const int undec_lit = af.derived_from_undec_var[atom];
		// Long clause (derived literal implies some body literal); added after
		// the binary clauses.
		vector<int> derived_needs_body = { -undec_lit };
		for (uint32_t rule : af.bodies_with_head[atom]) {
			const int body_lit = af.body_is_true_undec_var[rule];
			// each body literal implies the derived literal
			vector<int> body_implies_derived = { undec_lit, -body_lit };
			add_clause(solver, body_implies_derived);
			derived_needs_body.push_back(body_lit);
		}
		add_clause(solver, derived_needs_body);

		// The "in" derived literal implies the "undec" derived literal.
		vector<int> in_implies_undec = { -af.derived_from_in_var[atom], undec_lit };
		add_clause(solver, in_implies_undec);
	}
}

// Adds the body/derivation clauses for the "in" variables, then for every
// assumption a makes out_var[a] equivalent to a literal of a's contrary c:
// in_var[c] when c is itself an assumption, derived_from_in_var[c] otherwise.
void add_out_clauses(ABF & af, SAT_Solver * solver)
{
	add_body_is_true_in_clauses(af, solver);
	add_derived_from_in_clauses(af, solver);
	for (const auto assumption : af.assumptions) {
		const auto contrary = af.contrary[assumption];
		const int out_lit = af.out_var[assumption];
		const int contrary_lit = af.is_assumption[contrary]
			? af.in_var[contrary]
			: af.derived_from_in_var[contrary];

		vector<int> contrary_implies_out = { out_lit, -contrary_lit };
		add_clause(solver, contrary_implies_out);

		vector<int> out_implies_contrary = { -out_lit, contrary_lit };
		add_clause(solver, out_implies_contrary);
	}
}

// Adds the body/derivation clauses for the "undec" variables, then for every
// assumption a makes attacked_by_undec_var[a] equivalent to a literal of a's
// contrary c: the negation of out_var[c] when c is itself an assumption,
// derived_from_undec_var[c] otherwise.
void add_attacked_by_undec_clauses(ABF & af, SAT_Solver * solver)
{
	add_body_is_true_undec_clauses(af, solver);
	add_derived_from_undec_clauses(af, solver);
	for (const auto assumption : af.assumptions) {
		const auto contrary = af.contrary[assumption];
		const int attacked_lit = af.attacked_by_undec_var[assumption];
		const int contrary_lit = af.is_assumption[contrary]
			? -af.out_var[contrary]
			: af.derived_from_undec_var[contrary];

		vector<int> contrary_implies_attacked = { attacked_lit, -contrary_lit };
		add_clause(solver, contrary_implies_attacked);

		vector<int> attacked_implies_contrary = { -attacked_lit, contrary_lit };
		add_clause(solver, attacked_implies_contrary);
	}
}

// Adds the out-variable clauses, then for every assumption a a binary clause
// forbidding in_var[a] and out_var[a] from both being true.
void add_conflict_free(ABF & af, SAT_Solver * solver)
{
	add_out_clauses(af, solver);
	for (const auto assumption : af.assumptions) {
		vector<int> not_in_and_out = { -af.in_var[assumption], -af.out_var[assumption] };
		add_clause(solver, not_in_and_out);
	}
}

// Adds the conflict-free and attacked-by-undec clauses, then for every
// assumption a a binary clause forbidding in_var[a] and
// attacked_by_undec_var[a] from both being true.
void add_admissible(ABF & af, SAT_Solver * solver)
{
	add_conflict_free(af, solver);
	add_attacked_by_undec_clauses(af, solver);
	for (const auto assumption : af.assumptions) {
		vector<int> not_in_and_attacked = {
			-af.in_var[assumption],
			-af.attacked_by_undec_var[assumption]
		};
		add_clause(solver, not_in_and_attacked);
	}
}

// Adds the admissibility clauses, then for every assumption a a binary clause
// requiring attacked_by_undec_var[a] or in_var[a] to be true.
void add_complete(ABF & af, SAT_Solver * solver)
{
	add_admissible(af, solver);
	for (const auto assumption : af.assumptions) {
		vector<int> attacked_or_in = {
			af.attacked_by_undec_var[assumption],
			af.in_var[assumption]
		};
		add_clause(solver, attacked_or_in);
	}
}

// Adds the conflict-free clauses, then for every assumption a a binary clause
// requiring in_var[a] or out_var[a] to be true.
void add_stable(ABF & af, SAT_Solver * solver)
{
	add_conflict_free(af, solver);
	for (const auto assumption : af.assumptions) {
		vector<int> in_or_out = { af.in_var[assumption], af.out_var[assumption] };
		add_clause(solver, in_or_out);
	}
}

// For every assumption a, adds three clauses making range_var[a] equivalent to
// (in_var[a] or out_var[a]).
void add_range(ABF & af, SAT_Solver * solver)
{
	for (const auto assumption : af.assumptions) {
		const int range_lit = af.range_var[assumption];
		const int in_lit = af.in_var[assumption];
		const int out_lit = af.out_var[assumption];

		vector<int> range_implies_in_or_out = { -range_lit, in_lit, out_lit };
		add_clause(solver, range_implies_in_or_out);

		vector<int> in_implies_range = { range_lit, -in_lit };
		add_clause(solver, in_implies_range);

		vector<int> out_implies_range = { range_lit, -out_lit };
		add_clause(solver, out_implies_range);
	}
}

}
