#include "Encodings.h"
#include "Util.h"
#include "SingleExtension.h"

using namespace std;

namespace SingleExtension {

// Searches for a single extension under the admissible encoding and prints it.
//
// The admissible encoding is always added first. In OUTPUT builds the function
// then returns false without invoking the solver. Otherwise the solver is run
// once: if it reports a model (result 10), every assumption whose in_var is
// true in that model is collected (stored as its index + 1) and handed to
// print_single_extension; for any other result nothing is printed.
// Returns true whenever the solver was run, whatever its result.
bool admissible(ABF & af, SAT_Solver * solver)
{
	Encodings::add_admissible(af, solver);
#if defined(OUTPUT)
	return false;
#endif
	if (solver->solve() != 10)
		return true;

	std::vector<uint32_t> chosen;
	for (const auto & a : af.assumptions) {
		if (solver->val(af.in_var[a]) > 0)
			chosen.push_back(a + 1);
	}
	print_single_extension(af, chosen, 0);
	return true;
}

// Searches for a single extension under the complete encoding and prints it.
//
// The complete encoding is always added first. In OUTPUT builds the function
// then returns false without invoking the solver. Otherwise the solver is run
// once: if it reports a model (result 10), every assumption whose in_var is
// true in that model is collected (stored as its index + 1) and handed to
// print_single_extension; for any other result nothing is printed.
// Returns true whenever the solver was run, whatever its result.
bool complete(ABF & af, SAT_Solver * solver)
{
	Encodings::add_complete(af, solver);
#if defined(OUTPUT)
	return false;
#endif
	if (solver->solve() != 10)
		return true;

	std::vector<uint32_t> chosen;
	for (const auto & a : af.assumptions) {
		if (solver->val(af.in_var[a]) > 0)
			chosen.push_back(a + 1);
	}
	print_single_extension(af, chosen, 0);
	return true;
}

// Finds one preferred extension by iterative maximisation over the complete
// encoding.
//
// Each time the solver reports a model, the assumptions whose in_var is true
// are recorded (as index + 1) and their in_var literals are added as unit
// clauses. A further clause over the remaining in_var literals then forces the
// next model to accept at least one more assumption. When the solver returns
// 20 (no strictly larger model), the last recorded set is printed.
//
// In OUTPUT builds the function returns false before any encoding is added;
// otherwise it returns true.
bool preferred(ABF & af, SAT_Solver * solver)
{
#if defined(OUTPUT)
	return false;
#endif
	Encodings::add_complete(af, solver);
	/*for (uint32_t i = 0; i < af.assumptions.size(); i++) {
		solver->phase(af.in_var[af.assumptions[i]]);
	}*/

	std::vector<uint32_t> best;
	for (;;) {
		if (solver->solve() == 20)
			break;
		best.clear();

		std::vector<int> grow_clause;
		grow_clause.reserve(af.assumptions.size());
		std::vector<int> fixed;
		fixed.reserve(af.assumptions.size());

		// A single pass both records the model and splits the in_var literals
		// into "keep true" units and the "accept one more" clause.
		for (const auto & a : af.assumptions) {
			const int lit = af.in_var[a];
			if (solver->val(lit) > 0) {
				best.push_back(a + 1);
				fixed.push_back(lit);
			} else {
				grow_clause.push_back(lit);
			}
		}

		for (const int lit : fixed) {
			solver->add(lit);
			solver->add(0);
		}
		add_clause(solver, grow_clause);
	}
	print_single_extension(af, best, 0);
	return true;
}

// Searches for a single stable extension.
//
// The stable encoding is always added first. In OUTPUT builds the function
// then returns false without invoking the solver. This is consistent with
// admissible() and complete(), the other tasks that need only one solver call.
//
// Otherwise the solver is run once:
//   * result 10: every assumption whose in_var is true in the model is
//     collected (as index + 1) and printed via print_single_extension;
//     returns true.
//   * any other result: print_no(false) is emitted followed by a newline;
//     returns false.
bool stable(ABF & af, SAT_Solver * solver)
{
	Encodings::add_stable(af, solver);
#if defined(OUTPUT)
	return false;
#endif
	int sat = solver->solve();
	if (sat == 10) {
		std::vector<uint32_t> extension;
		for (uint32_t i = 0; i < af.assumptions.size(); i++) {
			if (solver->val(af.in_var[af.assumptions[i]]) > 0) {
				extension.push_back(af.assumptions[i]+1);
			}
		}
		print_single_extension(af, extension, 0);
		return true;
	} else {
		print_no(false);
		std::cout << std::endl;
		return false;
	}
}

// Finds one semi-stable extension by maximising the set of true range_var
// literals over the complete encoding (plus the range encoding).
//
// Each time the solver reports a model, the assumptions whose in_var is true
// are recorded (as index + 1). Every range_var literal true in that model is
// added as a unit clause, and a further clause over the remaining range_var
// literals forces the next model to cover at least one more assumption. When
// the solver returns 20, the last recorded set is printed.
//
// In OUTPUT builds the function returns false before any encoding is added;
// otherwise it returns true.
bool semi_stable(ABF & af, SAT_Solver * solver)
{
#if defined(OUTPUT)
	return false;
#endif
	Encodings::add_complete(af, solver);
	Encodings::add_range(af, solver);
	/*for (uint32_t i = 0; i < af.assumptions.size(); i++) {
		solver->phase(af.in_var[af.assumptions[i]]);
	}*/

	std::vector<uint32_t> last_model;
	for (;;) {
		if (solver->solve() == 20)
			break;
		last_model.clear();

		std::vector<int> extend_range;
		extend_range.reserve(af.assumptions.size());
		std::vector<int> fixed_range;
		fixed_range.reserve(af.assumptions.size());

		// One pass: remember the accepted assumptions, and split the range
		// literals into "stay covered" units and the "cover one more" clause.
		for (const auto & a : af.assumptions) {
			if (solver->val(af.in_var[a]) > 0)
				last_model.push_back(a + 1);
			const int r = af.range_var[a];
			if (solver->val(r) > 0)
				fixed_range.push_back(r);
			else
				extend_range.push_back(r);
		}

		for (const int lit : fixed_range) {
			solver->add(lit);
			solver->add(0);
		}
		add_clause(solver, extend_range);
	}
	print_single_extension(af, last_model, 0);
	return true;
}

// Finds one stage extension by maximising the set of true range_var literals
// over the conflict-free encoding (plus the range encoding).
//
// Each time the solver reports a model, the assumptions whose in_var is true
// are recorded (as index + 1). Every range_var literal true in that model is
// added as a unit clause, and a further clause over the remaining range_var
// literals forces the next model to cover at least one more assumption. When
// the solver returns 20, the last recorded set is printed.
//
// In OUTPUT builds the function returns false before any encoding is added;
// otherwise it returns true.
bool stage(ABF & af, SAT_Solver * solver)
{
#if defined(OUTPUT)
	return false;
#endif
	Encodings::add_conflict_free(af, solver);
	Encodings::add_range(af, solver);
	/*for (uint32_t i = 0; i < af.assumptions.size(); i++) {
		solver->phase(af.in_var[af.assumptions[i]]);
	}*/

	std::vector<uint32_t> last_model;
	for (;;) {
		if (solver->solve() == 20)
			break;
		last_model.clear();

		std::vector<int> extend_range;
		extend_range.reserve(af.assumptions.size());
		std::vector<int> fixed_range;
		fixed_range.reserve(af.assumptions.size());

		// One pass: remember the accepted assumptions, and split the range
		// literals into "stay covered" units and the "cover one more" clause.
		for (const auto & a : af.assumptions) {
			if (solver->val(af.in_var[a]) > 0)
				last_model.push_back(a + 1);
			const int r = af.range_var[a];
			if (solver->val(r) > 0)
				fixed_range.push_back(r);
			else
				extend_range.push_back(r);
		}

		for (const int lit : fixed_range) {
			solver->add(lit);
			solver->add(0);
		}
		add_clause(solver, extend_range);
	}
	print_single_extension(af, last_model, 0);
	return true;
}

// Computes the ideal extension in two SAT phases over the complete encoding.
//
// Phase 1 enumerates models, collecting two unions keyed by position i in
// af.assumptions:
//   * union_of_accepted[i]: in_var of assumption i was true in some model.
//   * union_of_rejected[i]: out_var of assumption i was true in some model.
// The first solve is unconstrained. Every later solve is made under the
// activation literal `select`, and the two clauses added per model (each
// guarded by -select) require the next model to set in_var for some
// not-yet-accepted assumption AND out_var for some not-yet-rejected one.
// Phase 1 ends when the solver returns 20.
//
// Phase 2 forbids every assumption in union_of_rejected from being in, adds
// the unit clause -select (which satisfies all phase-1 guarded clauses), and
// then maximises the set of in assumptions. It uses the same
// unit-clause / complement-clause loop as preferred(). The last model found
// is printed.
//
// NOTE(review): presumably out_var marks assumptions attacked by the model's
// in-set (see Encodings); only union_of_rejected feeds phase 2.
//
// In OUTPUT builds the function returns false before any encoding is added;
// otherwise it returns true.
bool ideal(ABF & af, SAT_Solver * solver)
{
#if defined(OUTPUT)
	return false;
#endif
	Encodings::add_complete(af, solver);

	vector<int> accepted_clause;
	vector<int> rejected_clause;
	vector<uint8_t> union_of_accepted(af.assumptions.size(), 0);
	vector<uint8_t> union_of_rejected(af.assumptions.size(), 0);
	// Activation literal guarding the phase-1 blocking clauses; presumably a
	// fresh variable because af.count tracks the highest allocated one — confirm.
	int select = ++af.count;
	bool solved = false;

	// Phase 1: grow union_of_accepted / union_of_rejected.
	while (true) {
		// Re-assumed before every solve after the first. Presumably solver
		// assumptions only last for one solve() call, as in IPASIR.
		if (solved)
			solver->assume(select);
		int sat = solver->solve();
		if (sat == 20) {
			break;
		}
		solved = true;
		accepted_clause.clear();
		rejected_clause.clear();
		accepted_clause.push_back(-select);
		rejected_clause.push_back(-select);
		for (uint32_t i = 0; i < af.assumptions.size(); i++) {
			if (solver->val(af.in_var[af.assumptions[i]]) > 0) {
				union_of_accepted[i] = 1;
			} else if (!union_of_accepted[i]) {
				// Not yet accepted in any model: a candidate to accept next time.
				accepted_clause.push_back(af.in_var[af.assumptions[i]]);
			}
			if (solver->val(af.out_var[af.assumptions[i]]) > 0) {
				union_of_rejected[i] = 1;
			} else if (!union_of_rejected[i]) {
				// Not yet rejected in any model: a candidate to reject next time.
				rejected_clause.push_back(af.out_var[af.assumptions[i]]);
			}
		}
		add_clause(solver, accepted_clause);
		add_clause(solver, rejected_clause);
	}

	// Phase 2 setup: an assumption rejected in any phase-1 model may not be in.
	for (uint32_t i = 0; i < af.assumptions.size(); i++) {
		if (union_of_rejected[i]) {
			vector<int> clause = { -af.in_var[af.assumptions[i]] };
			add_clause(solver, clause);
		}
	}

	// Unit -select satisfies every select-guarded phase-1 clause, switching them off.
	vector<int> clause = { -select };
	add_clause(solver, clause);

	// Phase 2: maximise the set of in assumptions, keeping the last model.
	vector<uint32_t> extension;
	while (true) {
		int sat = solver->solve();
		if (sat == 20) break;
		extension.clear();
		for (uint32_t i = 0; i < af.assumptions.size(); i++) {
			if (solver->val(af.in_var[af.assumptions[i]]) > 0) {
				extension.push_back(af.assumptions[i]+1);
			}
		}
		vector<int> complement_clause;
		complement_clause.reserve(af.assumptions.size());
		vector<int> units;
		units.reserve(af.assumptions.size());
		for (uint32_t i = 0; i < af.assumptions.size(); i++) {
			if (solver->val(af.in_var[af.assumptions[i]]) > 0) {
				units.push_back(af.in_var[af.assumptions[i]]);
			} else {
				complement_clause.push_back(af.in_var[af.assumptions[i]]);
			}
		}
		// Keep everything accepted so far, and demand at least one more.
		for (int lit : units) {
			solver->add(lit);
			solver->add(0);
		}
		add_clause(solver, complement_clause);
	}
	print_single_extension(af, extension, 0);
	return true;
}

}
