#include <iostream>
#include <iomanip>
#include <cassert>
#include <utility>
#include <sys/resource.h>

#include "Propagator.h"

using namespace std;

// Returns the CPU time (user + system) consumed so far by this process, in seconds.
double get_time() {
	rusage usage;
	getrusage(RUSAGE_SELF, &usage);
	const double user_seconds = usage.ru_utime.tv_sec + 1e-6 * usage.ru_utime.tv_usec;
	const double system_seconds = usage.ru_stime.tv_sec + 1e-6 * usage.ru_stime.tv_usec;
	return user_seconds + system_seconds;
}

// Sets up empty bookkeeping for copy 0 and, unless the semantics is ST or STG,
// also for copy 1. All source pointers start out unset (UINT32_MAX).
UnfoundedSetPropagator::UnfoundedSetPropagator(ABF & _af, CaDiCaL::Solver * _solver)
	: af(_af), solver(_solver)
{
	// Decision level 0 gets its own (initially empty) trail entry.
	trail.emplace_back();
	trail_changed = true;
	//is_lazy = true;
	is_fixed.resize(af.count + 1, 0);
	propagate_time = 0.0;

	const bool needs_second_copy = (af.sem != ST && af.sem != STG);
	const uint8_t copies = needs_second_copy ? 2 : 1;
	for (uint8_t k = 0; k < copies; ++k) {
		// Per-atom state.
		atom_assignment[k].resize(af.symbols, 0);
		source_valid[k].resize(af.symbols, 0);
		source_pointer[k].resize(af.symbols, UINT32_MAX);
		is_unfounded[k].resize(af.symbols, 0);
		is_todo[k].resize(af.symbols, 0);

		// Per-body state.
		body_assignment[k].resize(af.bodies.size(), 0);
		body_count[k].resize(af.bodies.size(), 0);
		body_watches[k].resize(af.bodies.size(), 0);
	}

	//are_reasons_forgettable = true;
}

// Unregisters this object from the solver (external propagator first, then the
// fixed-literal listener). propagate_time is accumulated elsewhere but not reported here.
UnfoundedSetPropagator::~UnfoundedSetPropagator()
{
	this->solver->disconnect_external_propagator();
	this->solver->disconnect_fixed_listener();
}

// Writes the 1-based ids of every atom i (i < af.symbols) with v[i] != 0 to stdout,
// space-separated, terminated by "0" and a newline.
void UnfoundedSetPropagator::print_set(vector<uint8_t> & v) {
	for (uint32_t atom = 0; atom < af.symbols; ++atom) {
		if (!v[atom]) continue;
		cout << (atom + 1) << " ";
	}
	cout << "0" << endl;
}

void UnfoundedSetPropagator::init()
{
	for (uint32_t i = 0; i < af.bodies.size(); i++) {
		for (uint32_t atom : af.bodies[i]) {
			if (af.scc_index[atom] == af.scc_index[af.symbols + i]) {
				body_count[0][i]++;
			}
		}
		for (uint32_t head : af.derivation_graph[af.symbols + i]) {
			if (af.scc_index[head] != af.scc_index[af.symbols + i] || body_count[0][i] <= 0) {
				set_source(0, i, head);
			} 
		}
	}

	propagate_source_pointers(0);

	for (uint32_t i = 0; i < af.symbols; i++) {
		if (!af.is_assumption[i] && !source_valid[0][i]) {
			solver->add(-af.derived_from_in_var[i]); solver->add(0);
		}
	}

	if (af.sem != ST && af.sem != STG) {
		body_count[1] = body_count[0];
		body_watches[1] = body_watches[0];
		source_valid[1] = source_valid[0];
		source_pointer[1] = source_pointer[0];

		for (uint32_t i = 0; i < af.symbols; i++) {
			if (!af.is_assumption[i] && !source_valid[1][i]) {
				solver->add(-af.derived_from_undec_var[i]); solver->add(0);
			}
		}
	}

}

// Makes body `body_index` the source of `head` in copy `index`, unless `head` already
// has a valid source. Moves the watch from the previous source body (if any) to the
// new one and queues `head` for propagate_source_pointers().
void UnfoundedSetPropagator::set_source(uint8_t index, uint32_t body_index, uint32_t head)
{
	if (source_valid[index][head]) return;

	auto & pointer = source_pointer[index][head];
	if (pointer != UINT32_MAX) {
		body_watches[index][pointer]--;
	}
	pointer = body_index;
	source_valid[index][head] = 1;
	body_watches[index][body_index]++;
	source_queue[index].push(head);
}

void UnfoundedSetPropagator::propagate_source_pointers(uint8_t index)
{
	while (!source_queue[index].empty()) {
		uint32_t atom = source_queue[index].front();
		source_queue[index].pop();
		if (source_valid[index][atom]) {
			for (uint32_t body_node : af.derivation_graph[atom]) {
				uint32_t body_index = body_node - af.symbols;
				if (--body_count[index][body_index] == 0 && body_assignment[index][body_index] >= 0) {
					for (uint32_t head : af.derivation_graph[af.symbols + body_index]) {
						set_source(index, body_index, head);
					}
				}
			}
		} else {
			for (uint32_t body_node : af.derivation_graph[atom]) {
				uint32_t body_index = body_node - af.symbols;
				if (++body_count[index][body_index] == 1 && body_watches[index][body_index] != 0) {
					for (uint32_t head : af.derivation_graph[af.symbols + body_index]) {
						if (af.scc_index[head] != af.scc_index[af.symbols + body_index])
							continue;
						if (source_valid[index][head] && source_pointer[index][head] == body_index) {
							source_valid[index][head] = 0;
							source_queue[index].push(head);
						}
					}
				}
			}
		}
	}
}

// Searches for a new source pointer for `head` in assignment copy `index`.
//
// ufs_list[index] is used as a worklist seeded with `head`. For each unsourced atom,
// every body with that atom as head that is not false is tried. A body from another
// SCC, or one whose same-SCC atoms all have sources (body_count == 0), becomes the
// atom's new source, and the change is propagated. Otherwise the body's unsourced,
// not-false same-SCC atoms are added to the worklist, since sourcing them could make
// the body usable.
//
// Atoms still unsourced after trying their bodies are collected in invalid_queue.
// A final pass marks the ones that are still unsourced as is_unfounded and leaves
// them in ufs_list[index]. On return, ufs_list[index] therefore holds the unfounded
// set that was found (empty if every collected atom ended up sourced).
//
// NOTE(review): invalid_queue[index] is used here as a scratch queue of atom ids,
// while assign()/propagate() use it for body ids. This appears to rely on it being
// empty on entry (propagate() drains it before calling find_source); confirm if
// other callers are added.
void UnfoundedSetPropagator::find_source(uint8_t index, uint32_t head)
{
	// push_ufs is defined elsewhere. It presumably appends to ufs_list[index] and marks
	// the atom in is_unfounded[index]; confirm.
	push_ufs(index, head);
	// NOTE(review): counted but never read in this function.
	uint32_t new_sources = 0;
	while (!ufs_list[index].empty()) {
		head = ufs_list[index].front();
		ufs_list[index].pop_front();
		if (!source_valid[index][head]) {
			for (uint32_t body_index : af.bodies_with_head[head]) {
				// A false body cannot serve as a source.
				if (body_assignment[index][body_index] < 0) continue;
				// External body, or all same-SCC atoms already sourced: usable as source.
				if (af.scc_index[head] != af.scc_index[af.symbols + body_index] || body_count[index][body_index] == 0) {
					is_unfounded[index][head] = 0;
					set_source(index, body_index, head);
					propagate_source_pointers(index);
					new_sources++;
					break;
				}
				// Otherwise try to source the body's unsourced, not-false same-SCC atoms first.
				for (uint32_t atom : af.bodies[body_index]) {
					if (af.scc_index[atom] != af.scc_index[af.symbols + body_index])   continue;
					if (source_valid[index][atom] || atom_assignment[index][atom] < 0) continue;
					push_ufs(index, atom);
				}
			}
			// Still unsourced: re-checked below, because later propagation may source it.
			if (!source_valid[index][head]) {
				invalid_queue[index].push(head);
			}
		} else {
			// The atom was sourced by propagation in the meantime.
			is_unfounded[index][head] = 0;
			new_sources++;
		}
	}
	// Atoms that remain unsourced form the unfounded set left in ufs_list[index].
	while (!invalid_queue[index].empty()) {
		head = invalid_queue[index].front();
		invalid_queue[index].pop();
		if (source_valid[index][head]) {
			is_unfounded[index][head] = 0;
		} else {
			is_unfounded[index][head] = 1;
			ufs_list[index].push_back(head);
		}
	}
}

void UnfoundedSetPropagator::get_external_bodies(uint8_t index)
{
	external_bodies.clear();
	for (auto it = ufs_list[index].begin(); it != ufs_list[index].end(); ++it) {
		uint32_t atom = *it;
		for (uint32_t body_index : af.bodies_with_head[atom]) {
			if (af.scc_index[atom] != af.scc_index[af.symbols + body_index]) {
				external_bodies.insert(body_index);
				continue;
			}
			bool is_external = true;
			for (uint32_t body_atom : af.bodies[body_index]) {
				if (is_unfounded[index][body_atom]) {
					is_external = false;
					break;
				}
			}
			if (is_external) {
				external_bodies.insert(body_index);
			}
		}
	}
}

// Appends to `clauses` the loop formula for `head`:
//     (-head_var  OR  body_var(b1)  OR ... OR  body_var(bk))
// where b1..bk are the current external_bodies (see get_external_bodies()).
// Copy 0 uses derived_from_in_var / body_is_true_in_var; copy 1 uses the *_undec_var
// variables.
void UnfoundedSetPropagator::add_loop_formula(uint8_t index, uint32_t head)
{
	vector<int> clause;
	clause.reserve(external_bodies.size() + 1);
	if (index == 0) {
		clause.push_back(-af.derived_from_in_var[head]);
		for (uint32_t body_index : external_bodies) {
			clause.push_back(af.body_is_true_in_var[body_index]);
		}
	} else {
		clause.push_back(-af.derived_from_undec_var[head]);
		for (uint32_t body_index : external_bodies) {
			clause.push_back(af.body_is_true_undec_var[body_index]);
		}
	}
	// Hand the finished clause over without copying it.
	clauses.push_back(std::move(clause));
}

void UnfoundedSetPropagator::propagate(uint8_t index)
{
	double start = get_time();

	for (auto it = ufs_list[index].begin(); it != ufs_list[index].end();) {
		uint32_t atom = *it;
		if (atom_assignment[index][atom] < 0 || source_valid[index][atom]) {
			ufs_list[index].erase(it++);
			is_unfounded[index][atom] = 0;
		} else {
			++it;
		}
	}

	if (!ufs_list[index].empty()) {
		get_external_bodies(index);
		add_loop_formula(index, ufs_list[index].front());
		is_unfounded[index][ufs_list[index].front()] = 0;
		ufs_list[index].pop_front();

		double end = get_time();
		propagate_time += (end-start);
		return;
	}

	while (!invalid_queue[index].empty()) {
		uint32_t body_index = invalid_queue[index].front();
		invalid_queue[index].pop();
		if (body_assignment[index][body_index] >= 0) continue;

		for (uint32_t head : af.derivation_graph[af.symbols + body_index]) {
			if (source_pointer[index][head] == body_index) {
				if (source_valid[index][head]) {
					source_valid[index][head] = 0;
					source_queue[index].push(head);
				}
				push_todo(index, head);
			}
		}
		propagate_source_pointers(index);
	}

	while (!todo_queue[index].empty()) {
		uint32_t head = todo_queue[index].front();
		todo_queue[index].pop();
		is_todo[index][head] = 0;
		if (source_valid[index][head] || atom_assignment[index][head] < 0) {
			continue;
		}
		find_source(index, head);

		if (!ufs_list[index].empty()) {
			get_external_bodies(index);
			add_loop_formula(index, ufs_list[index].front());
			is_unfounded[index][ufs_list[index].front()] = 0;
			ufs_list[index].pop_front();

			double end = get_time();
			propagate_time += (end-start);
			return;
		}
	}
}

void UnfoundedSetPropagator::assign(int var, int8_t val)
{
	pair<var_type, uint32_t> tmp = af.interpretation[var];
	var_type type  = tmp.first;
	uint32_t index = tmp.second;
	if (type == derived_from_in) {
		atom_assignment[0][index] = val;
	} else if (type == derived_from_undec) {
		atom_assignment[1][index] = val;
	} else if (type == body_is_true_in) {
		body_assignment[0][index] = val;
		if (val < 0) {
			invalid_queue[0].push(index);
		}
	} else if (type == body_is_true_undec) {
		body_assignment[1][index] = val;
		if (val < 0) {
			invalid_queue[1].push(index);
		}
	} else {
		cerr << "c WARNING: non-observed variable in assign: " << var << " (" << (int)val << ")" << endl;
	}
}

int8_t UnfoundedSetPropagator::value(int var)
{
	pair<var_type, uint32_t> tmp = af.interpretation[var];
	var_type type  = tmp.first;
	uint32_t index = tmp.second;
	if (type == derived_from_in) {
		return atom_assignment[0][index];
	} else if (type == derived_from_undec) {
		return atom_assignment[1][index];
	} else if (type == body_is_true_in) {
		return body_assignment[0][index];
	} else if (type == body_is_true_undec) {
		return body_assignment[1][index];
	} else {
		cerr << "c WARNING: non-observed variable in value: " << var << endl;
		return 0;
	}
}

// Remembers that the variable of `lit` is fixed, so notify_backtrack() never unassigns it.
void UnfoundedSetPropagator::notify_fixed_assignment(int lit)
{
	const int var = abs(lit);
	is_fixed[var] = 1;
}

void UnfoundedSetPropagator::notify_assignment(const vector<int> & lits)
{
	for (int lit : lits) {
		trail_changed = true;
		int var = abs(lit);
		assign(var, (lit > 0) ? 1 : -1);
		trail.back().push_back(var);
	}
}

// Opens an empty trail entry for the new decision level.
void UnfoundedSetPropagator::notify_new_decision_level()
{
	trail.emplace_back();
}

void UnfoundedSetPropagator::notify_backtrack(size_t new_level)
{
	while (trail.size() > new_level + 1) {
		auto last = trail.back();
		for (int var : last) {
			if (!is_fixed[var]) {
				assign(var, 0);
			}
		}
		trail.pop_back();
	}
	queue<uint32_t>().swap(todo_queue[0]);
	list<uint32_t>().swap(ufs_list[0]);
	fill(is_todo[0].begin(), is_todo[0].end(), 0);
	fill(is_unfounded[0].begin(), is_unfounded[0].end(), 0);

	if (af.sem != ST && af.sem != STG) {
		queue<uint32_t>().swap(todo_queue[1]);
		list<uint32_t>().swap(ufs_list[1]);
		fill(is_todo[1].begin(), is_todo[1].end(), 0);
		fill(is_unfounded[1].begin(), is_unfounded[1].end(), 0);
	}
}

// Accepts the solver's model only if no loop formula is still waiting to be added.
// The model itself is not inspected.
bool UnfoundedSetPropagator::cb_check_found_model(const std::vector<int> &model)
{
	return clauses.empty();
}

int UnfoundedSetPropagator::cb_propagate()
{
	if (!trail_changed) return 0;

	trail_changed = false;
	if (clauses.empty()) propagate(0);
	if (af.sem != ST && af.sem != STG && clauses.empty()) {
		for (uint32_t i = 0; i < af.symbols; i++) {
			if (source_valid[0][i]) {
				set_source(1, source_pointer[0][i], i);
				propagate_source_pointers(1);
			}
		}
		propagate(1);
	}
	if (clauses.empty()) return 0;

	vector<int> clause = clauses.back();
	int unknown = 0; int unassigned = 0;
	for (auto lit : clause) {
		int var = abs(lit);
		int8_t truth_value = value(var);
		if (truth_value == 0) {
			unknown++;
			unassigned = lit;
		} else if (truth_value > 0 && lit > 0) {
			return 0;
		} else if (truth_value < 0 && lit < 0) {
			return 0;
		}
	}

	if (unknown == 1) {
		clauses.pop_back();
		reasons[unassigned] = clause;
		return unassigned;
	}

	return 0;
}

int UnfoundedSetPropagator::cb_add_reason_clause_lit(int propagated_lit)
{
	if (reasons[propagated_lit].empty()) {
		return 0;
	}
	int lit = reasons[propagated_lit].back();
	reasons[propagated_lit].pop_back();
	return lit;
}

// Reports whether a loop formula is pending. Pending clauses are always marked
// non-forgettable.
bool UnfoundedSetPropagator::cb_has_external_clause(bool &is_forgettable) {
	if (clauses.empty()) return false;
	is_forgettable = false;
	return true;
}

int UnfoundedSetPropagator::cb_add_external_clause_lit() {
	vector<int> & clause = clauses.back();
	if (clause.empty()) {
		clauses.pop_back();
		return 0;
	}
	int lit = clause.back();
	clause.pop_back();
	return lit;
}