#include <iostream>
#include <cassert>
#include <sys/resource.h>

#include "Propagator.h"

using namespace std;

// Returns the CPU time consumed so far by this process (user + system), in seconds,
// as reported by getrusage(RUSAGE_SELF).
double get_time() {
	rusage usage;
	getrusage(RUSAGE_SELF, &usage);
	const double user_seconds = usage.ru_utime.tv_sec + 1e-6 * usage.ru_utime.tv_usec;
	const double system_seconds = usage.ru_stime.tv_sec + 1e-6 * usage.ru_stime.tv_usec;
	return user_seconds + system_seconds;
}

// Builds the graph from two parallel lists: arc_vars[i] is the SAT variable of the
// arc arcs[i] = (source, target). Every arc starts out in the POSSIBLE state.
// Note: the parameter `arcs` shadows the member of the same name inside this constructor.
Graph::Graph(vector<int> & arc_vars, vector<pair<int,int>> & arcs)
{
	assert(arc_vars.size() == arcs.size());
	n_nodes = 0;

	// Pass 1: register every arc; add_arc grows n_nodes to cover all node ids.
	for (size_t idx = 0; idx < arc_vars.size(); ++idx) {
		const pair<int,int> & arc = arcs[idx];
		add_arc(arc_vars[idx], arc.first, arc.second);
	}

	// Dense n_nodes x n_nodes tables, indexed via key(source, target).
	// They can only be sized once n_nodes is final, hence the second pass.
	arc_to_status.resize(n_nodes * n_nodes);
	arc_to_var.resize(n_nodes * n_nodes);

	// Pass 2: every arc is initially undecided; remember which variable encodes it.
	for (size_t idx = 0; idx < arcs.size(); ++idx) {
		const auto slot = key(arcs[idx].first, arcs[idx].second);
		arc_to_status[slot] = POSSIBLE;
		arc_to_var[slot] = arc_vars[idx];
	}
}

// Registers the arc source -> target, encoded by SAT variable arc_var.
// Grows n_nodes and the (reverse) adjacency lists so the new node ids are valid indices.
// arc_to_status / arc_to_var are not touched here.
void Graph::add_arc(int arc_var, int source, int target)
{
	n_nodes = std::max(std::max(n_nodes, source + 1), target + 1);

	var_to_arc[arc_var] = make_pair(source, target);
	nodes.insert(source);
	nodes.insert(target);
	arcs.insert(make_pair(source, target));

	// Adjacency lists are indexed by node id, so make sure both cover every id seen so far.
	if (adjacency_list.size() < n_nodes) adjacency_list.resize(n_nodes);
	if (reverse_adjacency_list.size() < n_nodes) reverse_adjacency_list.resize(n_nodes);

	adjacency_list[source].push_back(target);
	reverse_adjacency_list[target].push_back(source);
}

void Graph::bfs(int source, bool reverse, vector<uint8_t> & reachable, vector<int> & path)
{
	queue<int> queue;
	queue.push(source);

	while (!queue.empty()) {
		int v = queue.front();
		queue.pop();

		if (!reachable[v])
			reachable[v] = 1;

		if (!reverse) {
			for (auto w : adjacency_list[v]) {
				if (reachable[w]) continue;
				if (arc_to_status[key(v,w)] == ENABLED) {
					path[w] = v;
					queue.push(w);
				}
			}
		} else {
			for (auto w : reverse_adjacency_list[v]) {
				if (reachable[w]) continue;
				if (arc_to_status[key(w,v)] == ENABLED) {
					path[w] = v;
					queue.push(w);
				}
			}
		}
	}
}

void Graph::dfs(int source, bool reverse, vector<uint8_t> & reachable, vector<int> & path)
{
	stack<int> stack;
	stack.push(source);

	while (!stack.empty()) {
		int v = stack.top();
		stack.pop();

		if (!reachable[v])
			reachable[v] = 1;

		if (!reverse) {
			for (auto w : adjacency_list[v]) {
				if (reachable[w]) continue;
				if (arc_to_status[key(v,w)] == ENABLED) {
					path[w] = v;
					stack.push(w);
				}
			}
		} else {
			for (auto w : reverse_adjacency_list[v]) {
				if (reachable[w]) continue;
				if (arc_to_status[key(w,v)] == ENABLED) {
					path[w] = v;
					stack.push(w);
				}
			}
		}
	}
}

// Follows the links in `path` starting at `source` until `target` is reached, appending
// the negated variable of every traversed arc to `clause`.
// reverse == false: path[x] precedes x, so the traversed arc is path[x] -> x.
// reverse == true:  path[x] follows x,  so the traversed arc is x -> path[x].
// NOTE(review): `source` must have been reached from `target` by the search that filled
// `path`; otherwise a -1 link is followed.
void Graph::append_path_to_clause(int source, int target, bool reverse, vector<int> & path, vector<int> & clause)
{
	for (int node = source; node != target; node = path[node]) {
		const int step = path[node];
		const int from = reverse ? node : step;
		const int to = reverse ? step : node;
		clause.push_back(-arc_to_var[key(from, to)]);
	}
}

// Creates the propagator over the arc variables of `_af` and the arc graph `_graph`.
AcyclicityPropagator::AcyclicityPropagator(ABF & _af, Graph & _graph)
	: af(_af), graph(_graph)
{
	// Trail segment for decision level 0 (initially empty).
	trail.emplace_back();
	// Ensures the first cb_propagate() call does not return early.
	trail_changed = true;
	propagate_time = 0.0;
	// One "fixed" flag per variable index 0..af.count.
	is_fixed.resize(af.count + 1, 0);
	//are_reasons_forgettable = true;
	//is_lazy = true;
}

// Prints propagate_time as a "c ..." line when the propagator is destroyed.
AcyclicityPropagator::~AcyclicityPropagator()
{
	std::cout << "c propagate time: "
	          << propagate_time
	          << " seconds"
	          << std::endl;
}

void AcyclicityPropagator::propagate()
{
	while (!assignments.empty()) {
		int lit = assignments.front().first;
		assignments.pop_front();
		assert(lit > 0);

		int source = graph.var_to_arc[lit].first;
		int target = graph.var_to_arc[lit].second;

		if (graph.arc_to_status[graph.key(source, target)] != ENABLED)
			continue;

		vector<uint8_t> reachable_from_target(graph.n_nodes+1, 0);
		vector<int> path_from_target(graph.n_nodes+1, -1);
		graph.dfs(target, false, reachable_from_target, path_from_target);

		if (reachable_from_target[source]) {
			vector<int> cycle_clause;
			cycle_clause.push_back(-graph.arc_to_var[graph.key(source, target)]);
			graph.append_path_to_clause(source, target, false, path_from_target, cycle_clause);
			clauses.push_back(cycle_clause);
			return;
		}

		vector<uint8_t> reachable_from_source(graph.n_nodes+1, 0);
		vector<int> path_from_source(graph.n_nodes+1, -1);
		graph.dfs(source, true, reachable_from_source, path_from_source);

		for (auto v : graph.nodes) {
			if (!reachable_from_target[v]) continue;
			for (auto w : graph.adjacency_list[v]) {
				if (!reachable_from_source[w]) continue;
				assert(graph.arc_to_status[graph.key(v,w)] != ENABLED);
				if (graph.arc_to_status[graph.key(v,w)] == DISABLED)
					continue;

				vector<int> cycle_clause;
				cycle_clause.push_back(-graph.arc_to_var[graph.key(source,target)]);
				cycle_clause.push_back(-graph.arc_to_var[graph.key(v,w)]);
				graph.append_path_to_clause(v, target, false, path_from_target, cycle_clause);
				graph.append_path_to_clause(w, source, true, path_from_source, cycle_clause);
				clauses.push_back(cycle_clause);

			}
		}
	}
}

// Flags the variable of `lit` as fixed. notify_backtrack skips fixed variables when it
// resets arc statuses.
void AcyclicityPropagator::notify_fixed_assignment(int lit)
{
	const int var = lit < 0 ? -lit : lit;
	is_fixed[var] = 1;
}

void AcyclicityPropagator::notify_assignment(const vector<int> & lits)
{
	for (int lit : lits) {
		int var = abs(lit);
		if (var < af.min_edge_var || var > af.max_edge_var) {
			cout << "c WARNING: non-observed variable in notify: " << var << endl;
			continue;
		}
		trail_changed = true;
		pair<int,int> arc = graph.var_to_arc[var];
		int source = arc.first;
		int target = arc.second;
		if (lit > 0) {
			graph.arc_to_status[graph.key(source, target)] = ENABLED;
			assignments.push_back(make_pair(lit, trail.size()));
		} else {
			graph.arc_to_status[graph.key(source, target)] = DISABLED;
		}
		trail.back().push_back(var);
	}
}

// Opens a new, empty trail segment for the decision level just entered.
void AcyclicityPropagator::notify_new_decision_level()
{
	trail.emplace_back();
}

// Solver callback: undo every decision level above `new_level`.
// Arcs whose variables were assigned on a removed level go back to POSSIBLE, except
// variables flagged through notify_fixed_assignment. Trail segment i belongs to level i,
// so new_level + 1 segments remain afterwards.
void AcyclicityPropagator::notify_backtrack(size_t new_level)
{
	while (trail.size() > new_level + 1) {
		// Iterate the level in place; copying it allocated a vector per popped level.
		for (int var : trail.back()) {
			if (is_fixed[var]) continue;
			const auto & arc = graph.var_to_arc[var];
			graph.arc_to_status[graph.key(arc.first, arc.second)] = POSSIBLE;
		}
		trail.pop_back();
	}
}

// Solver callback: validate a complete model.
// Returns false while clauses are still queued. When is_lazy is set, the model's arc
// assignment is checked for a cycle; if one is found, a clause forbidding it is queued
// and false is returned. Returns true otherwise.
bool AcyclicityPropagator::cb_check_found_model(const std::vector<int> &model)
{
	if (!clauses.empty()) return false;
	// NOTE(review): the non-lazy path accepts the model unchecked, presumably because
	// cb_propagate has already enforced acyclicity — confirm.
	if (!is_lazy) return true;

	// Apply the model to a copy so the live arc statuses are left untouched.
	Graph model_graph = graph;
	for (int lit : model) {
		int var = abs(lit);
		if (!graph.var_to_arc.contains(var)) continue;  // not an arc variable
		pair<int,int> arc = graph.var_to_arc[var];
		int source = arc.first;
		int target = arc.second;
		if (lit > 0) {
			model_graph.arc_to_status[model_graph.key(source, target)] = ENABLED;
		} else {
			model_graph.arc_to_status[model_graph.key(source, target)] = DISABLED;
		}
	}

	// `reachable` is shared by all searches: nodes reached earlier are not re-explored.
	vector<uint8_t> reachable(model_graph.n_nodes+1, 0);
	for (auto & target : model_graph.nodes) {
		if (reachable[target]) continue;
		vector<int> path_from_target(model_graph.n_nodes+1, -1);
		model_graph.bfs(target, false, reachable, path_from_target);
		for (auto & source : model_graph.reverse_adjacency_list[target]) {
			if (model_graph.arc_to_status[model_graph.key(source, target)] == DISABLED)
				continue;
			// reachable[source] may come from an earlier search, in which case
			// path_from_target has no link for it (-1). Following that link would index
			// out of bounds. Only a source reached by *this* search closes a cycle.
			const bool reached_from_target =
				(source == target) || path_from_target[source] != -1;
			if (!reached_from_target) continue;

			vector<int> cycle_clause;
			cycle_clause.push_back(-model_graph.arc_to_var[model_graph.key(source, target)]);
			model_graph.append_path_to_clause(source, target, false, path_from_target, cycle_clause);
			clauses.push_back(cycle_clause);
			return false;
		}
	}
	return true;
}

// Solver callback: return a literal to propagate, or 0 if there is none.
// Does nothing unless new assignments were notified since the last call. If no clause is
// queued, propagate() is run first. The most recently queued clause is then inspected:
// if none of its literals is satisfied (DISABLED arc) and exactly one is unassigned
// (POSSIBLE arc), the clause is dequeued, stored as that literal's reason, and the
// literal is returned. Otherwise the clause stays queued.
int AcyclicityPropagator::cb_propagate()
{
	if (!trail_changed) return 0;
	trail_changed = false;

	if (clauses.empty()) {
		propagate();
		if (clauses.empty()) return 0;
	}

	vector<int> clause = clauses.back();
	int open_count = 0;
	int open_lit = 0;
	for (int lit : clause) {
		assert(lit < 0);
		const auto & arc = graph.var_to_arc[abs(lit)];
		const auto status = graph.arc_to_status[graph.key(arc.first, arc.second)];
		if (status == DISABLED) return 0;  // clause already satisfied
		if (status == POSSIBLE) {
			++open_count;
			open_lit = lit;
		}
	}

	if (open_count != 1) return 0;

	clauses.pop_back();
	reasons[open_lit] = clause;
	return open_lit;
}

int AcyclicityPropagator::cb_add_reason_clause_lit(int propagated_lit)
{
	if (reasons[propagated_lit].empty()) {
		return 0;
	}
	int lit = reasons[propagated_lit].back();
	reasons[propagated_lit].pop_back();
	return lit;
}

// Solver callback: report whether a queued clause is waiting to be added.
// Clauses handed out this way are marked as not forgettable.
bool AcyclicityPropagator::cb_has_external_clause(bool &is_forgettable) {
	if (clauses.empty()) return false;
	is_forgettable = false;
	return true;
}

int AcyclicityPropagator::cb_add_external_clause_lit() {
	vector<int> & clause = clauses.back();
	if (clause.empty()) {
		clauses.pop_back();
		return 0;
	}
	int lit = clause.back();
	clause.pop_back();
	return lit;
}