use crate::{
    aa::AAFramework,
    encodings::ConstraintsEncoder,
    sat::{Literal, SatSolver},
    utils::LabelType,
};

/// Returns `true` iff `solver` has a model in which exactly the arguments listed in `args`
/// are accepted and every other argument of `af` is rejected.
///
/// Argument literals are obtained from `constraints_encoder`. What "extension" means is
/// decided by the clauses already loaded into `solver`. This is presumably the encoding
/// produced by `constraints_encoder` for some semantics — verify at call sites. This check
/// only issues a solver query and adds no clause to `solver`.
pub(crate) fn is_extension<T: LabelType>(
    af: &AAFramework<T>,
    args: &[&T],
    solver: &mut dyn SatSolver,
    constraints_encoder: &dyn ConstraintsEncoder<T>,
) -> bool {
    let check_is_maximal = false;
    is_extension_internal(af, args, solver, constraints_encoder, check_is_maximal)
}

/// Like [`is_extension`], but additionally requires `args` to be maximal.
///
/// Returns `true` only if `args` passes the [`is_extension`] check and no model of `solver`
/// accepts every argument in `args` together with at least one further argument of `af`.
///
/// Side effect: when the first check succeeds, a clause guarded by a fresh selector variable
/// is added to `solver` for the superset query. A unit clause then permanently satisfies it.
pub(crate) fn is_maximal_extension<T: LabelType>(
    af: &AAFramework<T>,
    args: &[&T],
    solver: &mut dyn SatSolver,
    constraints_encoder: &dyn ConstraintsEncoder<T>,
) -> bool {
    let check_is_maximal = true;
    is_extension_internal(af, args, solver, constraints_encoder, check_is_maximal)
}

/// Shared implementation of [`is_extension`] and [`is_maximal_extension`].
///
/// Queries `solver` under the assumption that exactly the arguments in `args` are accepted
/// (their literals assumed true) and every other argument of `af` is rejected (its literal
/// assumed false). If that query has no model, `false` is returned.
///
/// When `check_is_maximal` is set, a second query asks whether some model keeps all of `args`
/// accepted while accepting at least one additional argument. `args` is reported as maximal
/// (`true`) only if no such model exists.
///
/// # Side effects
/// Only the maximality check modifies `solver`, in two steps:
/// - It adds the clause `out_1 ∨ … ∨ out_k ∨ s` over a new selector variable `s`.
/// - After the query it adds the unit clause `s`, so the first clause is permanently
///   satisfied and does not constrain later queries.
///
/// # Panics
/// Panics if `split_args_literals` does, e.g. when an element of `args` cannot be found in
/// `af`'s argument set.
fn is_extension_internal<T>(
    af: &AAFramework<T>,
    args: &[&T],
    solver: &mut dyn SatSolver,
    constraints_encoder: &dyn ConstraintsEncoder<T>,
    check_is_maximal: bool,
) -> bool
where
    T: LabelType,
{
    let (mut args_in, mut args_out) = split_args_literals(af, args, constraints_encoder);
    // Fix the whole assignment: selected arguments accepted, all others rejected.
    // Chaining avoids building a temporary Vec just to append it.
    let assumptions: Vec<_> = args_in
        .iter()
        .copied()
        .chain(args_out.iter().map(|l| l.negate()))
        .collect();
    if solver
        .solve_under_assumptions(&assumptions)
        .unwrap_model()
        .is_none()
    {
        return false;
    }
    if !check_is_maximal {
        return true;
    }
    // Selector variable numbered one past `n_vars()`. This assumes `n_vars()` is the highest
    // variable index currently in use — TODO confirm for every `SatSolver` implementation.
    let selector = Literal::from(1 + solver.n_vars() as isize);
    // "Some currently-rejected argument is accepted", active only while `selector` is false.
    args_out.push(selector);
    solver.add_clause(args_out);
    // Keep every selected argument accepted and activate the clause above.
    args_in.push(selector.negate());
    let has_strict_superset = solver
        .solve_under_assumptions(&args_in)
        .unwrap_model()
        .is_some();
    // Permanently satisfy the guarded clause so later queries are unaffected by it.
    solver.add_clause(vec![selector]);
    !has_strict_superset
}

/// Maps `args`, and the remaining arguments of `af`, to their solver literals.
///
/// Returns `(args_in, args_out)`:
/// - `args_in` holds the literal of every element of `args`, in the given order. Duplicates
///   are kept.
/// - `args_out` holds the literal of every argument of `af` that `args` does not mention, in
///   increasing id order.
///
/// NOTE(review): a `Vec` of length `af.n_arguments()` is indexed by `arg.id()`, and every id in
/// `0..af.n_arguments()` is looked up. Argument ids are therefore assumed to be dense in that
/// range — confirm against `AAFramework`.
///
/// # Panics
/// Panics if looking up an element of `args` in `af`'s argument set fails.
fn split_args_literals<T>(
    af: &AAFramework<T>,
    args: &[&T],
    constraints_encoder: &dyn ConstraintsEncoder<T>,
) -> (Vec<Literal>, Vec<Literal>)
where
    T: LabelType,
{
    let n_arguments = af.n_arguments();
    let mut is_in = vec![false; n_arguments];
    let mut args_in = Vec::with_capacity(args.len());
    for a in args {
        let arg = af
            .argument_set()
            .get_argument(a)
            .expect("every argument to check must belong to the framework");
        args_in.push(constraints_encoder.arg_to_lit(arg));
        is_in[arg.id()] = true;
    }
    // `args` may repeat arguments, so `args.len()` can exceed `n_arguments`. A plain
    // subtraction would then underflow: a panic in debug builds, a capacity-overflow panic
    // in release builds.
    let mut args_out = Vec::with_capacity(n_arguments.saturating_sub(args.len()));
    args_out.extend(
        is_in
            .iter()
            .enumerate()
            .filter(|&(_, &selected)| !selected)
            .map(|(id, _)| {
                let arg = af.argument_set().get_argument_by_id(id);
                constraints_encoder.arg_to_lit(arg)
            }),
    );
    (args_in, args_out)
}
