#include "ABF.h"

#include <iostream>
#include <set>

using namespace std;

// Create an empty framework: no symbols registered, no SAT variables
// allocated yet (count == 0), and symbols_as_strings cleared.
ABF::ABF()
	: symbols(0)
	, count(0)
	, symbols_as_strings(false)
{
}

// Add a rule given as 1-based atom ids: rule[0] is the head, rule[1..] the body.
// The ids are converted to 0-based in place, so the caller's vector is modified.
// Identical bodies are shared: every distinct body is stored once in `bodies`
// and found again through `body_index`; the rule refers to it by position.
// `symbols` is raised so that it covers every atom mentioned by the rule.
// An empty rule is ignored.
void ABF::add_rule(vector<uint32_t> & rule)
{
	if (rule.empty())
		return;

	// Shift to zero-based ids while tracking the largest (id + 1) seen.
	uint32_t largest = symbols;
	for (auto it = rule.begin(); it != rule.end(); ++it) {
		--*it;
		largest = max(largest, *it + 1);
	}

	const uint32_t head = rule.front();
	const vector<uint32_t> body(rule.begin() + 1, rule.end());

	auto known = body_index.find(body);
	const bool is_new_body = (known == body_index.end());
	const size_t body_id = is_new_body ? bodies.size() : known->second;

	rules.push_back(make_pair(head, body_id));
	bodies_with_head[head].push_back(body_id);
	if (is_new_body) {
		body_index[body] = body_id;
		bodies.push_back(body);
	}
	symbols = largest;
}

// Declare the 1-based atom `assum` as an assumption. Repeated declarations of
// the same atom are ignored; `symbols` is raised to cover the atom.
void ABF::add_assumption(uint32_t assum)
{
	const uint32_t id = assum - 1;
	if (is_assumption.find(id) == is_assumption.end()) {
		assumptions.push_back(id);
		is_assumption[id] = 1;
		symbols = max(symbols, assum);
	}
}

// Record the 1-based atom `cont` as the contrary of the 1-based assumption
// `assum`. Only the first contrary registered for an assumption is kept;
// later calls for the same assumption do nothing at all.
void ABF::add_contrary(uint32_t assum, uint32_t cont)
{
	const uint32_t key = assum - 1;
	if (contrary.find(key) != contrary.end())
		return;
	contrary[key] = cont - 1;
	symbols = max(symbols, max(assum, cont));
}

// Add a rule given as a comma-separated list of symbol names: the first name
// is the head, the remaining names form the body (e.g. "h,b1,b2").
// Unknown names are interned: they get the next 1-based id in
// symbol_to_int / int_to_symbol, and `symbols` grows accordingly.
// Identical bodies are shared through `body_index`, exactly as in the
// numeric add_rule overload, and the rule's body is also recorded in
// bodies_with_head so both overloads keep the same bookkeeping.
void ABF::add_rule(string rule)
{
	// Zero-based index of `name`, interning it on first sight.
	auto intern = [this](const string & name) -> uint32_t {
		auto found = symbol_to_int.find(name);
		if (found != symbol_to_int.end())
			return found->second - 1;
		int_to_symbol[++symbols] = name;
		symbol_to_int[name] = symbols;
		return symbols - 1;
	};

	// An explicit flag is needed: 0 is a valid zero-based atom index, so it
	// cannot double as a "head not seen yet" sentinel.
	bool have_head = false;
	uint32_t head = 0;
	vector<uint32_t> body;

	// Scan tokens by position instead of repeatedly erasing the prefix.
	size_t start = 0;
	while (true) {
		const size_t comma = rule.find(',', start);
		const size_t len = (comma == string::npos) ? string::npos : comma - start;
		const uint32_t atom = intern(rule.substr(start, len));
		if (!have_head) {
			head = atom;
			have_head = true;
		} else {
			body.push_back(atom);
		}
		if (comma == string::npos)
			break;
		start = comma + 1;
	}

	size_t body_id;
	auto it = body_index.find(body);
	if (it == body_index.end()) {
		body_id = bodies.size();
		body_index[body] = body_id;
		bodies.push_back(body);
	} else {
		body_id = it->second;
	}
	rules.push_back(make_pair(head, body_id));
	bodies_with_head[head].push_back(body_id);
}

// Declare the symbol named `assum` as an assumption, interning the name
// (next 1-based id in symbol_to_int / int_to_symbol) if it is new.
// Repeated declarations are ignored, matching add_assumption(uint32_t):
// a duplicate entry in `assumptions` would make initialize_vars() allocate
// SAT variables for the same assumption more than once.
void ABF::add_assumption(string assum)
{
	if (symbol_to_int.count(assum) == 0) {
		int_to_symbol[++symbols] = assum;
		symbol_to_int[assum] = symbols;
	}
	const uint32_t id = symbol_to_int[assum] - 1;
	if (is_assumption.count(id))
		return;
	assumptions.push_back(id);
	is_assumption[id] = 1;
}

// Record the symbol named `cont` as the contrary of the symbol named `assum`.
// Unknown names are interned (assumption first, then contrary) with the next
// 1-based ids. A later call for the same assumption overwrites its contrary.
void ABF::add_contrary(string assum, string cont)
{
	auto intern = [this](const string & name) {
		if (symbol_to_int.find(name) == symbol_to_int.end()) {
			int_to_symbol[++symbols] = name;
			symbol_to_int[name] = symbols;
		}
	};
	intern(assum);
	intern(cont);

	const auto assum_id = symbol_to_int[assum] - 1;
	const auto cont_id = symbol_to_int[cont] - 1;
	contrary[assum_id] = cont_id;
}

void ABF::strong_connect(uint32_t node, uint32_t & index,
	stack<uint32_t> & node_stack, vector<uint8_t> & on_stack,
	vector<uint32_t> & node_to_index, vector<uint32_t> & node_to_lowlink)
{
	node_to_index[node] = index;
	node_to_lowlink[node] = index;
	index++;
	node_stack.push(node);
	on_stack[node] = 1;;
	for (uint32_t i = 0; i < derivation_graph[node].size(); i++) {
		uint32_t successor = derivation_graph[node][i];
		if (node_to_index[successor] == 0) {
			// successor has not been visited, recurse
			strong_connect(successor, index, node_stack, on_stack, node_to_index, node_to_lowlink);
			node_to_lowlink[node] = min(node_to_lowlink[node], node_to_lowlink[successor]);
		} else if (on_stack[successor]) {
			// successor is on stack, and therefore in the current scc;
			// if not on stack, edge points to an already found scc
			node_to_lowlink[node] = min(node_to_lowlink[node], node_to_index[successor]);
			//node_to_lowlink[node] = min(node_to_lowlink[node], node_to_lowlink[successor]);
		}
	}
	// node is a root node
	if (node_to_lowlink[node] == node_to_index[node]) {
		// construct new scc from stack
		//cout << "c scc " << scc_count << " : ";
		while (true) {
			uint32_t next = node_stack.top();
			node_stack.pop();
			on_stack[next] = 0;
			scc_index[next] = scc_count;
			//cout << next << " ";
			if (node == next) break;
		}
		scc_count++;
		//cout << endl;
	}
}

void ABF::tarjan_scc()
{
	uint32_t index = 1;
	stack<uint32_t> node_stack;
	vector<uint8_t> on_stack(symbols+bodies.size(), 0);
	vector<uint32_t> node_to_index(symbols+bodies.size(), 0);
	vector<uint32_t> node_to_lowlink(symbols+bodies.size(), 0);
	for (uint32_t i = 0; i < symbols+bodies.size(); i++) {
		if (node_to_index[i] == 0) {
			strong_connect(i, index, node_stack, on_stack, node_to_index, node_to_lowlink);
		}
	}
}

void ABF::initialize_vars()
{
	set<pair<uint32_t,uint32_t>> derivation_edges_set;
	for (const auto& [head, index] : rules) {
		derivation_edges_set.insert(make_pair(symbols+index, head));
		//derivation_graph[symbols+index].push_back(head);
		for (uint32_t atom : bodies[index]) {
			derivation_edges_set.insert(make_pair(atom, symbols+index));
			//derivation_graph[atom].push_back(symbols + index);
		}
	}

	derivation_graph.resize(symbols+bodies.size());
	for (const auto& [s,t] : derivation_edges_set) {
		derivation_graph[s].push_back(t);
	}

	// compute scc decomposition
	scc_count = 0;
	scc_index.resize(symbols+bodies.size());
	tarjan_scc();

	derivation_graph.clear();
	derivation_graph.resize(symbols+bodies.size());
	for (const auto& [s,t] : derivation_edges_set) {
		if (s >= symbols || scc_index[s] == scc_index[t]) {
			derivation_graph[s].push_back(t);
		}
	}

#if defined(OUTPUT)
	cout << "c number of edges = " << derivation_edges.size() << endl;
	cout << "c number of sccs  = " << scc_count << endl;
#endif

	interpretation.push_back(make_pair(none, 0));

	in_var.resize(symbols, 0);
	for (uint32_t i = 0; i < assumptions.size(); i++) {
		in_var[assumptions[i]] = ++count;
		interpretation.push_back(make_pair(in, assumptions[i]));
#if defined(OUTPUT)
		cout << "c " << count << " = in(" << assumptions[i] << ")\n";
#endif
	}

	out_var.resize(symbols, 0);
	for (uint32_t i = 0; i < assumptions.size(); i++) {
		out_var[assumptions[i]] = ++count;
		interpretation.push_back(make_pair(out, assumptions[i]));
#if defined(OUTPUT)
		cout << "c " << count << " = out(" << assumptions[i] << ")\n";
#endif
	}
	//cout << endl;

	derived_from_in_var.resize(symbols, 0);
	for (uint32_t i = 0; i < symbols; i++) {
		if (is_assumption[i]) continue;
		derived_from_in_var[i] = ++count;
		interpretation.push_back(make_pair(derived_from_in, i));
#if defined(OUTPUT)
		cout << "c " << count << " = derived(" << i << ")\n";
#endif
	}

	body_is_true_in_var.resize(bodies.size(), 0);
	for (uint32_t i = 0; i < bodies.size(); i++) {
		body_is_true_in_var[i] = ++count;
		interpretation.push_back(make_pair(body_is_true_in, i));
#if defined(OUTPUT)
		cout << "c " << count << " = body_true(";
		for (uint32_t k = 0; k < rules[i].second.size(); k++) {
			cout << bodies[i].second[k] << " ";
		}
		cout << "\b)\n";
#endif
	}

	if (sem == ST) return;

	if (sem != STG) {
		attacked_by_undec_var.resize(symbols, 0);
		for (uint32_t i = 0; i < assumptions.size(); i++) {
			attacked_by_undec_var[assumptions[i]] = ++count;
			interpretation.push_back(make_pair(attacked_by_undec, assumptions[i]));
			//cout << "c " << count << " = assu " << assumptions[i] << " attacked by undec\n";
		}
		//cout << endl;

		derived_from_undec_var.resize(symbols, 0);
		for (uint32_t i = 0; i < symbols; i++) {
			if (is_assumption[i]) continue;
			derived_from_undec_var[i] = ++count;
			interpretation.push_back(make_pair(derived_from_undec, i));
		}
		//cout << endl;

		body_is_true_undec_var.resize(bodies.size(), 0);
		for (uint32_t i = 0; i < bodies.size(); i++) {
			body_is_true_undec_var[i] = ++count;
			interpretation.push_back(make_pair(body_is_true_undec, i));
			//cout << "c " << count << " = body of rule " << rules[i].first << " <- ";
			//for (uint32_t k = 0; k < rules[i].second.size(); k++) {
				//cout << rules[i].second[k] << " ";
			//}
			//cout << " true under undec\n";
		}
		//cout << endl;
	}

	if (sem != SST && sem != STG) return;

	range_var.resize(symbols, 0);
	for (uint32_t i = 0; i < assumptions.size(); i++) {
		range_var[assumptions[i]] = ++count;
		interpretation.push_back(make_pair(range, assumptions[i]));
		//cout << "c " << count << " = assu " << assumptions[i] << " attacked by undec\n";
	}

}

// Render a SAT literal as "[-]kind(id)", where kind and id come from the
// `interpretation` entry of the literal's variable (|lit|); a leading '-'
// marks a negative literal. Variables of an unlisted kind print as "[-]id)".
string ABF::lit_to_string(int lit) {
	const pair<var_type, int> entry = interpretation[abs(lit)];

	const char * label = "";
	switch (entry.first) {
	case in:                 label = "in(";                 break;
	case out:                label = "out(";                break;
	case attacked_by_undec:  label = "attacked_by_undec(";  break;
	case body_is_true_in:    label = "body_is_true_in(";    break;
	case body_is_true_undec: label = "body_is_true_undec("; break;
	case derived_from_in:    label = "derived_from_in(";    break;
	case derived_from_undec: label = "derived_from_undec("; break;
	case range:              label = "range(";              break;
	default:                                                break;
	}

	string text = (lit < 0) ? "-" : "";
	text += label;
	text += to_string(entry.second);
	text += ")";
	return text;
}