#include <iostream>
#include <vector>
#include <utility>
#include <cstdint>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/graph_traits.hpp>

#include "VE.h"
#include "Util.h"

// Directed graph used by set_acyclic's vertex elimination.
//  - boost::listS (out-edge list): parallel edges are NOT suppressed by the
//    graph itself, so duplicates must be filtered by the code that adds edges.
//  - boost::vecS (vertex list): vertex descriptors are integer indices
//    0..num_vertices-1 (stable as long as no vertex is removed).
//  - boost::bidirectionalS: in_edges()/in_degree() are available in addition
//    to out_edges()/out_degree().
using Graph = boost::adjacency_list<
    boost::listS,
    boost::vecS,
    boost::bidirectionalS
>;

/// Packs the ordered pair (u, v) into a single 64-bit key:
/// u occupies the high 32 bits and v the low 32 bits, so (u, v) and (v, u)
/// map to different keys.
static inline uint64_t make_key(uint32_t u, uint32_t v) {
    const uint64_t high_half = static_cast<uint64_t>(u) << 32;
    const uint64_t low_half = static_cast<uint64_t>(v);
    return high_half | low_half;
}

/// Adds clauses to `solver` that enforce acyclicity of the selected edges
/// using vertex elimination.
///
/// Phase 1: every distinct input edge (u, v) gets a fresh variable x_uv, and
/// each input entry's variable implies it (old_var -> x_uv). Self-loops also
/// get the unit clause (not x_uu).
/// Phase 2: vertices are eliminated greedily by smallest positive
/// in_degree * out_degree. Eliminating b adds, for every in-neighbour u and
/// out-neighbour v, the clause (x_ub AND x_bv -> x_uv). If the edge u -> v
/// does not exist yet, x_uv and the edge are created first, and (not x_uu)
/// is added when u == v. All edges incident to b are then removed.
///
/// @param solver    Solver that receives every clause through add_clause().
/// @param edges     Input edges (u, v). Edges created in phase 2 are appended.
/// @param edge_vars On entry, edge_vars[i] is the caller's variable for
///                  edges[i]; it is overwritten with that edge's x_uv.
///                  Variables of appended edges are pushed back in step with
///                  `edges`, so presumably it must have the same length as
///                  `edges` on entry.
/// @param top       Highest SAT variable index in use; incremented once per
///                  newly allocated variable.
void set_acyclic(SAT_Solver * solver,
                 std::vector<std::pair<uint32_t,uint32_t>>& edges,
                 std::vector<uint32_t>& edge_vars,
                 uint32_t& top)
{
    using Traits = boost::graph_traits<Graph>;
    constexpr bool kDebugPrint = false;

    // Positive SAT literal for a variable index.
    // NOTE(review): assumes variable indices stay <= INT_MAX.
    const auto lit = [](uint32_t var) { return static_cast<int>(var); };

    Graph g;

    // Node id -> vertex descriptor, created in first-seen order (this order
    // decides ties when picking the vertex to eliminate).
    std::unordered_map<uint32_t, Traits::vertex_descriptor> id2vd;
    for (const auto& [u, v] : edges) {
        if (id2vd.find(u) == id2vd.end()) id2vd.emplace(u, add_vertex(g));
        if (id2vd.find(v) == id2vd.end()) id2vd.emplace(v, add_vertex(g));
    }

    // Reverse map. With vecS as the vertex list, descriptors are the indices
    // 0 .. num_vertices(g)-1, so a plain vector suffices.
    std::vector<uint32_t> vd2id(num_vertices(g));
    for (const auto& [id, vd] : id2vd) vd2id[vd] = id;

    // (u, v) key -> variable x_uv of every edge currently in g. Kept in sync
    // with g: each add_edge inserts a key and each removed edge erases it.
    std::unordered_map<uint64_t, uint32_t> edge_map;
    edge_map.reserve(edges.size() * 2);

    std::vector<int> clause;

    // Phase 1: link each input entry to its per-edge variable.
    for (size_t i = 0; i < edges.size(); ++i) {
        const auto [u, v] = edges[i];
        const uint32_t old_var = edge_vars[i];
        const auto [it, is_new] = edge_map.try_emplace(make_key(u, v), top + 1);
        if (is_new) ++top;
        const uint32_t edge_var = it->second;

        // Emitted for repeated (u, v) entries too; otherwise a duplicate's
        // variable would escape every acyclicity constraint.
        clause = { -lit(old_var), lit(edge_var) };
        add_clause(solver, clause);
        edge_vars[i] = edge_var;

        if (is_new) {
            add_edge(id2vd.at(u), id2vd.at(v), g);
            if (u == v) {
                // A self-loop is a cycle on its own.
                clause = { -lit(edge_var) };
                add_clause(solver, clause);
            }
        }
    }

    // Phase 2: vertex elimination. Each round removes all edges of one vertex,
    // so at most num_vertices rounds are needed.
    const size_t num_nodes = num_vertices(g);
    for (size_t iter = 0; iter < num_nodes; ++iter) {
        // Pick the vertex with the smallest positive in*out degree product
        // (fewest fill-in pairs). Vertices without in- or out-edges (score 0)
        // are skipped: no cycle can pass through them.
        Traits::vertex_descriptor best_v = Traits::null_vertex();
        size_t best_score = std::numeric_limits<size_t>::max();
        Traits::vertex_iterator vi, vi_end;
        for (boost::tie(vi, vi_end) = vertices(g); vi != vi_end; ++vi) {
            const size_t score = in_degree(*vi, g) * out_degree(*vi, g);
            if (score > 0 && score < best_score) {
                best_score = score;
                best_v = *vi;
            }
        }
        if (best_v == Traits::null_vertex()) break;

        const uint32_t n_best = vd2id[best_v];

        // Fill-in: for every path u -> best_v -> v add x_ub AND x_bv -> x_uv.
        // A new edge never touches best_v (its incident keys already exist),
        // so best_v's edge lists are not modified while being iterated.
        Traits::in_edge_iterator in_it, in_end;
        Traits::out_edge_iterator out_it, out_end;
        for (boost::tie(in_it, in_end) = in_edges(best_v, g); in_it != in_end; ++in_it) {
            const auto i_vd = source(*in_it, g);
            const uint32_t u_id = vd2id[i_vd];
            for (boost::tie(out_it, out_end) = out_edges(best_v, g);
                 out_it != out_end; ++out_it)
            {
                const auto j_vd = target(*out_it, g);
                const uint32_t v_id = vd2id[j_vd];

                const auto [it, is_new] =
                    edge_map.try_emplace(make_key(u_id, v_id), top + 1);
                const uint32_t var_new = it->second;
                if (is_new) {
                    ++top;
                    edge_vars.push_back(var_new);
                    add_edge(i_vd, j_vd, g);
                    edges.emplace_back(u_id, v_id);
                    if (u_id == v_id) {
                        // u -> best_v -> u closes a cycle.
                        clause = { -lit(var_new) };
                        add_clause(solver, clause);
                    }
                }

                // If either edge is a self-loop on best_v, x_uv coincides with
                // one of the premises and the clause is a tautology.
                if (u_id != n_best && v_id != n_best) {
                    // .at(): both premises are edges of g, so a missing key
                    // would mean edge_map and g are out of sync.
                    clause = { -lit(edge_map.at(make_key(u_id, n_best))),
                               -lit(edge_map.at(make_key(n_best, v_id))),
                               lit(var_new) };
                    add_clause(solver, clause);
                }
            }
        }

        // Eliminate best_v: drop every incident edge from g and edge_map.
        std::vector<std::pair<uint32_t, uint32_t>> to_rem;
        for (boost::tie(in_it, in_end) = in_edges(best_v, g); in_it != in_end; ++in_it) {
            to_rem.emplace_back(vd2id[source(*in_it, g)], n_best);
        }
        for (boost::tie(out_it, out_end) = out_edges(best_v, g); out_it != out_end; ++out_it) {
            to_rem.emplace_back(n_best, vd2id[target(*out_it, g)]);
        }
        for (const auto& [u_id, v_id] : to_rem) {
            remove_edge(id2vd.at(u_id), id2vd.at(v_id), g);
            edge_map.erase(make_key(u_id, v_id));
        }

        if (kDebugPrint) {
            std::cout << "\nGraph after iteration " << iter << ":\n";
            Traits::edge_iterator ei, ei_end;
            for (boost::tie(ei, ei_end) = boost::edges(g); ei != ei_end; ++ei) {
                std::cout
                  << vd2id[source(*ei, g)]
                  << " -> "
                  << vd2id[target(*ei, g)]
                  << "\n";
            }
        }
    }
}
