#include "regalloc.hpp"

#include "Function.h"
#include "Instruction.h"
#include "liverange.hpp"

#include <algorithm>

using std::for_each;

using namespace RA;

// Debug-only sanity check: the instruction `cmpinst` must have at most one
// user. Expands to a no-op when NDEBUG is defined. Despite its name,
// no_reg_alloca() also applies it to zext instructions. The argument is
// parenthesized so that expression arguments (e.g. `p + 1`) expand correctly.
#define ASSERT_CMPINST_USED_ONCE(cmpinst)                                      \
    (assert((cmpinst)->get_use_list().size() <= 1))

/// Return the 1-based position of `arg` in its parent function's argument
/// list. If `arg` does not occur in that list, the result is one past the last
/// position (list size + 1).
int
get_arg_id(Argument *arg) {
    // Bind by const reference: if get_args() returns the container by
    // reference this avoids copying it on every call, and if it returns by
    // value the temporary's lifetime is extended, so both cases are safe.
    const auto &args = arg->get_parent()->get_args();
    int id = 1;
    for (auto a : args) {
        if (a == arg)
            return id;
        ++id;
    }
    return id;
}

/// Decide whether the value `v` should be kept out of registers.
///
/// - alloca, icmp and fcmp instructions: never get a register.
/// - zext: gets a register unless its single user is an `icmp ne` whose single
///   user is a branch (see the pattern below). Users that are not
///   Instructions count as a real use, so the zext gets a register.
/// - any other instruction: always gets a register.
/// - Argument: only arguments at 1-based position <= ARG_MAX_R get one.
///
/// @return true if no register should be allocated for `v`.
bool
RegAllocator::no_reg_alloca(Value *v) {
    if (auto instr = dynamic_cast<Instruction *>(v)) {
        // never allocate register
        if (instr->is_alloca() or instr->is_cmp() or instr->is_fcmp())
            return true;
        // every non-zext instruction always gets a register
        if (not instr->is_zext())
            return false;

        // zext: only allocate for a real use
        ASSERT_CMPINST_USED_ONCE(instr);
        if (instr->get_use_list().size() == 0)
            return false;
        auto use_ins =
            dynamic_cast<Instruction *>(instr->get_use_list().begin()->val_);
        // A user that is not an Instruction, or is not an `icmp ne`, is a real
        // use of the zext result, so it needs a register.
        if (use_ins == nullptr or not use_ins->is_cmp() or
            static_cast<CmpInst *>(use_ins)->get_cmp_op() != CmpInst::NE)
            return false;

        // this case:
        // %op0 = icmp slt i32 1, 2
        // %op1 = zext i1 %op0 to i32
        // %op2 = icmp ne i32 %op1, 0  # <- if judges to here
        // br i1 %op2, label %label3, label %label5
        ASSERT_CMPINST_USED_ONCE(use_ins);
        if (use_ins->get_use_list().size() == 0)
            return false;
        auto use2_ins =
            dynamic_cast<Instruction *>(use_ins->get_use_list().begin()->val_);
        // Skip the register only when the `icmp ne` feeds a branch directly.
        return use2_ins != nullptr and use2_ins->is_br();
    }
    if (auto arg = dynamic_cast<Argument *>(v)) // only the first ARG_MAX_R args
        return get_arg_id(arg) > ARG_MAX_R;

    assert(false && "only instruction and argument's LiveInterval exits");
    // Release builds (NDEBUG): return a defined result instead of falling off
    // the end of the function (UB). No register for an unexpected value kind.
    return true;
}

/// Prepare the allocator for a new function: forget every register assignment
/// and active interval from the previous run, mark registers 0..R as free,
/// and remember `func` as the function being allocated.
void
RegAllocator::reset(Function *func) {
    regmap.clear();
    active.clear();
    for (int r = 0; r <= R; ++r)
        used[r] = false;
    cur_func = func;
}

int
RegAllocator::ReserveForArg(const LVITS &liveints) {
    auto args = cur_func->get_args();
    auto it_int = liveints.begin();
    auto it_arg = args.begin();
    int reg;
    for (reg = 1; reg <= args.size() and reg <= ARG_MAX_R; ++reg) {
        auto arg = *it_arg;
        if (not(FLOAT ^ arg->get_type()->is_float_type())) {
            auto liveint = *it_int;
            assert(arg == liveint.second &&
                   "arg should be in order in liveints");

            used[reg] = true;
            regmap[arg] = reg;
            active.insert(liveint);
        }
        ++it_arg, ++it_int;
    }
    return reg;
}

void
RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
    reset(func);
    ReserveForArg(liveints);
    int reg;
    for (auto liveint : liveints) {
        if (FLOAT ^ liveint.second->get_type()->is_float_type())
            continue;
        if (dynamic_cast<Argument *>(liveint.second)) {
            continue;
        }
        if (no_reg_alloca(liveint.second))
            continue;
        ExpireOldIntervals(liveint);
        if (active.size() == R)
            SpillAtInterval(liveint);
        else {
            for (reg = 1; reg <= R and used[reg]; ++reg)
                ;
            if (reg == 16) {
                for (auto [interval, v] : active) {
                    cout << "already allocated: " << v->get_name() << " ~ "
                         << regmap.at(v) << endl;
                }
                assert(false);
            }
            used[reg] = true;
            regmap[liveint.second] = reg;
            active.insert(liveint);
        }
    }
}

/// Retire every active interval that ends before `liveint` starts: its
/// register is marked free and the interval leaves `active`.
///
/// Scanning stops at the first active interval that is still live. This is
/// only complete if `active` iterates in order of interval end — presumably
/// what its comparator does; confirm against the LVITS/active declaration.
void
RegAllocator::ExpireOldIntervals(LiveInterval liveint) {
    auto first_live = std::find_if(
        active.begin(), active.end(), [&liveint](const auto &entry) {
            return not(entry.first.j < liveint.first.i);
        });
    for (auto dead = active.begin(); dead != first_live; ++dead)
        used[regmap.at(dead->second)] = false;
    active.erase(active.begin(), first_live);
}

/// Called when every register of this class is occupied.
///
/// The last element of `active` (presumably the interval that ends latest) is
/// the spill candidate. If it ends after `liveint`, its register is handed to
/// `liveint`: `liveint` joins `active`, and the candidate leaves both `active`
/// and `regmap`. Otherwise, or if the candidate is an Argument, nothing
/// changes and `liveint` gets no register.
void
RegAllocator::SpillAtInterval(LiveInterval liveint) {
    const auto victim = *active.rbegin();
    const auto victim_val = victim.second;

    // Arguments keep their registers; the new interval goes without one.
    if (dynamic_cast<Argument *>(victim_val) != nullptr)
        return;
    // The candidate does not outlive the new interval: the new one is spilled.
    if (not(victim.first.j > liveint.first.j))
        return;

    const auto freed_reg = regmap.at(victim_val);
    regmap[liveint.second] = freed_reg;
    active.insert(liveint);

    active.erase(victim);
    regmap.erase(victim_val);
}