#include "LoopUnroll.hpp"

#include "BasicBlock.h"
#include "Constant.h"
#include "Function.h"
#include "Instruction.h"
#include "Type.h"

#include <vector>

using std::find;

using namespace Graph;
using namespace Analysis;

// Shorthand: obtain an integer constant for `x` via ConstantInt::get, passing
// the pass's `m_` member.
#define CONSTINT(x) ConstantInt::get(x, m_)
// Shorthand: obtain a floating-point constant for `x` via ConstantFP::get,
// passing the pass's `m_` member.
#define CONSTFP(x) ConstantFP::get(x, m_)
// Operand `i` of `instr`, passed through `right_v`. `right_v` is defined
// outside this file; presumably it maps an original value to its unrolled
// replacement -- TODO confirm against its definition.
#define op(instr, i) right_v(instr.get_operand(i))

/**
 * Pass entry point.
 *
 * For every function of `m_` that is not a mere declaration: collect the
 * back edges, derive the simple loops from them, print the blocks of each
 * simple loop, and then try to unroll each of those loops.
 */
void
LoopUnroll::run() {
    for (auto &fn : m_->get_functions()) {
        if (not fn.is_declaration()) {
            auto *func = &fn;
            auto back_edges = detect_back(func);
            auto loops = check_sloops(back_edges);

            // First report every candidate loop ...
            cout << "get simple loops for function " << func->get_name() << ":\n";
            for (auto &loop : loops) {
                cout << "\t";
                for (auto *block : loop) {
                    cout << block->get_name() << " ";
                }
                cout << "\n";
            }

            // ... then transform them one after another.
            for (auto &loop : loops) {
                unroll_loop(loop);
            }
        }
    }
}
/**
 * Fully unroll the simple loop `sl` if its trip count can be determined.
 *
 * Steps:
 *  1. Build a CountedLoop for `sl`; give up if it is UNDEF (not countable).
 *  2. Locate `pre` (the predecessor of the first loop block that is not the
 *     last loop block) and `succ` (the successor of the first loop block
 *     that lies outside the loop).
 *  3. Emit one copy of every loop instruction per emulated iteration into a
 *     chain of fresh blocks, followed by one more copy of the first block's
 *     instructions and a branch to `succ`.
 *  4. Redirect `pre`'s branch to the first fresh block, fix the CFG links,
 *     remove the original loop blocks from the function and rewrite all
 *     remaining operands through `old2new`.
 *
 * The function returns without touching the IR when `sl` is empty, when the
 * loop is not countable, or when `pre` / `succ` / the branch at the end of
 * `pre` cannot be found.
 *
 * @param sl blocks of the loop; the first element is the block carrying the
 *           loop condition, the last element the source of the back edge.
 */
void
LoopUnroll::unroll_loop(SimpleLoop &sl) {
    if (sl.empty())
        return;

    CountedLoop cl(sl);
    auto func = sl[0]->get_parent();

    switch (cl.Type) {
        case CountedLoop::UNDEF:
            return;
        case CountedLoop::INT:
            cout << "Get CountedLoop in function " << func->get_name()
                 << ":\n\t"
                 << "initial value: " << cl.initial.v
                 << ", stop value: " << cl.stop.v << "\n\t"
                 << "delta: " << cl.delta->print() << "\n\t"
                 << "loop times: " << cl.count << endl;
            break;
        case CountedLoop::FLOAT:
            cout << "Get CountedLoop in function " << func->get_name()
                 << ":\n\t"
                 << "initial value: " << cl.initial.fv
                 << ", stop value: " << cl.stop.fv << "\n\t"
                 << "delta: " << cl.delta->print() << "\n\t"
                 << "loop times: " << cl.count << endl;
            break;
    }

    // get pre block and succ block
    auto b = *sl.begin();
    auto e = *sl.rbegin();
    BasicBlock *pre = nullptr;
    BasicBlock *succ = nullptr;
    for (auto p : b->get_pre_basic_blocks()) {
        if (p != e) {
            pre = p;
            break;
        }
    }
    for (auto s : b->get_succ_basic_blocks()) {
        if (find(sl.begin(), sl.end(), s) == sl.end()) {
            succ = s;
            break;
        }
    }
    // Without an entry edge or an exit edge there is nothing to rewire;
    // leave the loop untouched instead of dereferencing an unset pointer.
    if (pre == nullptr or succ == nullptr)
        return;

    // Validate pre's terminator BEFORE creating any block, so that bailing
    // out does not leave an orphan block in the function.
    auto pre_br =
        dynamic_cast<BranchInst *>(&*pre->get_instructions().rbegin());
    if (pre_br == nullptr)
        return;

    auto BB = BasicBlock::create(m_, "", func);
    auto nb = BB; // first block of the unrolled code

    // unroll loop: copy instructions, once per emulated iteration
    bool init = true;
    for (cl.new_emulate(); cl.judge(); cl.next()) {
        for (auto bb : sl) {
            for (auto &instr : bb->get_instructions()) {
                BB = copy_instruction(instr, bb, BB, pre, succ, sl, cl, init);
            }
        }
        init = false;
    }
    // One more copy of the first block's instructions after the last
    // iteration, then leave through the loop exit.
    for (auto &instr : b->get_instructions())
        BB = copy_instruction(instr, b, BB, pre, succ, sl, cl, false);
    BranchInst::create_br(succ, BB);

    // correct pre's br
    if (pre_br->is_cond_br()) {
        assert(pre_br->get_operand(1) == b or pre_br->get_operand(2) == b);
        if (pre_br->get_operand(1) == b)
            pre_br->set_operand(1, nb);
        else
            pre_br->set_operand(2, nb);
    } else {
        assert(pre_br->get_operand(0) == b);
        pre_br->set_operand(0, nb);
    }
    /* new links */
    // nb & pre
    nb->add_pre_basic_block(pre);
    pre->add_succ_basic_block(nb);
    // BB & succ: no explicit link is added here (the original code notes
    // that the br created above maintains these links).

    // old links: remove links from old graph
    pre->remove_succ_basic_block(b);
    succ->remove_pre_basic_block(b);
    // remove blocks in simpleloop from func->get_basic_blocks()
    for (auto loop_block : sl)
        func->get_basic_blocks().remove(loop_block);
    // NOTE(review): predecessor lists of blocks reached from inside the loop
    // (the "neg" blocks) are not updated here.

    // replace use: redirect operands that still name an original loop value
    for (auto &bb : func->get_basic_blocks()) {
        for (auto &instr : bb.get_instructions()) {
            for (int i = 0; i < instr.get_num_operand(); ++i) {
                auto mapped = old2new.find(instr.get_operand(i));
                if (mapped != old2new.end()) {
                    instr.set_operand(i, mapped->second);
                }
            }
        }
    }
}

/**
 * Analyse the simple loop `sl` and try to recognise it as a counted loop.
 *
 * The constructor starts with Type == UNDEF and count == 0. It then pattern
 * matches the first block of `sl` from its last instruction backwards:
 *
 *   - the terminating branch, whose condition must be the instruction
 *     directly before it: a cmp/fcmp with the NE operator;
 *   - optionally a zext directly before that, fed by the cmp/fcmp directly
 *     before the zext; in that case this inner compare becomes `control`.
 *
 * The loop is accepted only if
 *   - one operand of `control` is a Constant (only a constant at operand 1
 *     is implemented; it becomes `stop`),
 *   - operand 0 of `control` is a phi whose operand 0 is a Constant (it
 *     becomes `initial`),
 *   - operand 2 of that phi is a BinaryInst (`delta`) combining the phi with
 *     a Constant; `reverse` records that the phi is the second operand,
 *   - emulating the loop terminates within `Threshold` iterations.
 *
 * On success Type is INT or FLOAT and `count` holds the number of
 * iterations; on any mismatch Type is reset to UNDEF.
 *
 * @param sl blocks of the loop; front() is inspected for the loop condition
 *           and back() is expected to be the phi's incoming block at
 *           operand 3.
 */
CountedLoop::CountedLoop(const Graph::SimpleLoop &sl) : Type(UNDEF), count(0) {
    auto b = sl.front();
    auto e = sl.back();

    // Declared up front because the `goto can_not_count` jumps below must
    // not cross an initialisation.
    Value *i, *control_op;
    PhiInst *phi;

    // In block `b`, walk backwards from the terminator to find the compare
    // that controls the loop and the constant stop value.
    auto rit = b->get_instructions().rbegin();
    assert(dynamic_cast<BranchInst *>(&*rit) &&
           "The end instruction of a block should be branch");
    // The branch condition must be the instruction right before the branch.
    i = (rit++)->get_operand(0);
    assert(i == &*rit);
    if (dynamic_cast<CmpInst *>(&*rit)) {
        assert(static_cast<CmpInst *>(&*rit)->get_cmp_op() == CmpInst::NE &&
               "should be neq 0");
    } else if (dynamic_cast<FCmpInst *>(&*rit)) {
        assert(static_cast<FCmpInst *>(&*rit)->get_cmp_op() == FCmpInst::NE &&
               "should be neq 0");
    } else
        assert(false && "should not have this case");
    control = dynamic_cast<Instruction *>(&*rit); // only `ne` case
    i = (rit++)->get_operand(0);
    // If the compared value is a zext, the real loop condition is the
    // cmp/fcmp that feeds the zext; step back once more and use that one.
    if (dynamic_cast<ZextInst *>(&*rit)) {
        assert(i == &*rit);
        i = (rit++)->get_operand(0);
        control = dynamic_cast<Instruction *>(&*rit);
        assert(i == &*rit &&
               (dynamic_cast<CmpInst *>(control) or
                dynamic_cast<FCmpInst *>(control)) &&
               "cmp or fcmp");
    }
    // One side of the controlling compare has to be a constant; the kind of
    // compare decides whether the loop counts in integers or in floats.
    if (dynamic_cast<Constant *>(control->get_operand(0)) or
        dynamic_cast<Constant *>(control->get_operand(1))) {
        if (dynamic_cast<FCmpInst *>(&*control)) {
            Type = FLOAT;
            auto constfloat =
                dynamic_cast<ConstantFP *>(control->get_operand(1));
            assert(constfloat &&
                   "the case const at operand(0) not implemented");
            stop.fv = constfloat->get_value();
        } else {
            Type = INT;
            auto constint =
                dynamic_cast<ConstantInt *>(control->get_operand(1));
            assert(constint && "the case const at operand(0) not implemented");
            stop.v = constint->get_value();
        }
    } else
        goto can_not_count;

    // get control value and initial value: the compared value must be a phi
    // in `b` whose operand 0 is a constant.
    control_op = control->get_operand(0);
    phi = dynamic_cast<PhiInst *>(control_op);
    if (phi == nullptr)
        goto can_not_count;
    assert(phi->get_parent() == b && "unexpected structure for while block");
    // Only the layout where operand 3 of the phi is the last loop block is
    // handled.
    assert(phi->get_operand(3) == e && "the orther case not implemented");
    if (dynamic_cast<Constant *>(phi->get_operand(0))) {
        switch (Type) {
            case INT:
                initial.v = static_cast<ConstantInt *>(phi->get_operand(0))
                                ->get_value();
                break;
            case FLOAT:
                initial.fv =
                    static_cast<ConstantFP *>(phi->get_operand(0))->get_value();
                break;
            case UNDEF:
                assert(false);
                break;
        }
    } else
        goto can_not_count;

    // get delta, maybe `control` op `const` or `const` op `control`
    delta = dynamic_cast<BinaryInst *>(phi->get_operand(2));
    if (delta == nullptr)
        goto can_not_count;
    if (delta->get_operand(0) != control_op and
        delta->get_operand(1) != control_op)
        goto can_not_count;
    if (delta->get_operand(0) == control_op) {
        // phi op const
        reverse = false;
        if (dynamic_cast<Constant *>(delta->get_operand(1)) == nullptr)
            goto can_not_count;
    } else { // control at [1], i.e. const op phi
        reverse = true;
        if (dynamic_cast<Constant *>(delta->get_operand(0)) == nullptr)
            goto can_not_count;
    }

    // check correctness

    // count loop: emulate it and give up once more than `Threshold`
    // iterations have been seen.
    new_emulate();
    for (count = 0; judge(); ++count) {
        if (count > Threshold)
            goto can_not_count;
        next();
    }

    return;
can_not_count:
    // Any failed check ends up here: mark the loop as not countable.
    Type = UNDEF;
}

bool
CountedLoop::judge() {
    bool flag;
    // cycle judge
    switch (Type) {
        case INT: {
            switch (static_cast<CmpInst *>(control)->get_cmp_op()) {
                case CmpInst::EQ:
                    flag = emulate.v == stop.v;
                    break;
                case CmpInst::NE:
                    flag = emulate.v != stop.v;
                    break;
                case CmpInst::GT:
                    flag = emulate.v > stop.v;
                    break;
                case CmpInst::GE:
                    flag = emulate.v >= stop.v;
                    break;
                case CmpInst::LT:
                    flag = emulate.v < stop.v;
                    break;
                case CmpInst::LE:
                    flag = emulate.v <= stop.v;
                    break;
            }
            break;
        }
        case FLOAT: {
            switch (static_cast<FCmpInst *>(control)->get_cmp_op()) {
                case FCmpInst::EQ:
                    flag = emulate.fv == stop.fv;
                    break;
                case FCmpInst::NE:
                    flag = emulate.fv != stop.fv;
                    break;
                case FCmpInst::GT:
                    flag = emulate.fv > stop.fv;
                    break;
                case FCmpInst::GE:
                    flag = emulate.fv >= stop.fv;
                    break;
                case FCmpInst::LT:
                    flag = emulate.fv < stop.fv;
                    break;
                case FCmpInst::LE:
                    flag = emulate.fv <= stop.fv;
                    break;
            }
            break;
        }
        case UNDEF:
            assert(false);
    }
    return flag;
}

/**
 * Advance the emulated loop variable by one iteration.
 *
 * Applies the operation of `delta` to `emulate`, using the constant operand
 * of `delta` (operand 0 when `reverse` is set, operand 1 otherwise) as the
 * other input. `reverse` means the loop variable is the right-hand operand,
 * which matters for subtraction and division.
 *
 * Supported operations: add/sub/mul/sdiv for INT loops and
 * fadd/fsub/fmul/fdiv for FLOAT loops; anything else trips an assertion.
 *
 * NOTE(review): integer division by zero and signed overflow are not
 * guarded against here.
 */
void
CountedLoop::next() {
    if (Type == INT) {
        const int constant_index = reverse ? 0 : 1;
        int step = static_cast<ConstantInt *>(delta->get_operand(constant_index))
                       ->get_value();
        const auto kind = delta->get_instr_type();
        switch (kind) {
            case Instruction::add:
                emulate.v += step;
                break;
            case Instruction::sub:
                // reverse: step - v, written as the negated v - step
                emulate.v = (reverse ? -1 : 1) * (emulate.v - step);
                break;
            case Instruction::mul:
                emulate.v *= step;
                break;
            case Instruction::sdiv:
                emulate.v = (reverse ? step / emulate.v : emulate.v / step);
                break;
            default:
                assert(false && "not implemented");
                break;
        }
    } else if (Type == FLOAT) {
        const int constant_index = reverse ? 0 : 1;
        float step = static_cast<ConstantFP *>(delta->get_operand(constant_index))
                         ->get_value();
        const auto kind = delta->get_instr_type();
        switch (kind) {
            case Instruction::fadd:
                emulate.fv += step;
                break;
            case Instruction::fsub:
                // reverse: step - fv, written as the negated fv - step
                emulate.fv = (reverse ? -1 : 1) * (emulate.fv - step);
                break;
            case Instruction::fmul:
                emulate.fv *= step;
                break;
            case Instruction::fdiv:
                emulate.fv =
                    (reverse ? (step / emulate.fv) : (emulate.fv / step));
                break;
            default:
                assert(false && "not implemented");
                break;
        }
    } else if (Type == UNDEF) {
        assert(false);
    }
}

/**
 * Filter the back edges in `belist` down to "simple" loops.
 *
 * For a back edge (tail -> head) the loop is simple when
 *   - head has exactly two successors and exactly two predecessors, and
 *   - every block met while walking from tail towards head over the first
 *     predecessor has exactly one predecessor and either exactly one
 *     successor, or two successors one of which satisfies is_neg_block().
 *
 * @param belist back edges as (source, target) pairs.
 * @return one SimpleLoop per accepted back edge, listing the blocks from
 *         head (first) to tail (last).
 */
vector<SimpleLoop>
LoopUnroll::check_sloops(const BackEdgeList &belist) const {
    vector<SimpleLoop> result;
    for (auto [tail, head] : belist) {
        // the start node should have 2*2 degree, and others have 1*1 degree
        const bool head_ok = head->get_succ_basic_blocks().size() == 2 and
                             head->get_pre_basic_blocks().size() == 2;
        if (not head_ok)
            continue;

        SimpleLoop body;
        bool simple = true;
        auto cur = tail;
        while (cur != head) {
            if (cur->get_pre_basic_blocks().size() != 1) {
                simple = false;
                break;
            }
            auto targets = cur->get_succ_basic_blocks();
            if (targets.size() != 1) {
                // one exception: a second edge leading to a neg block
                assert(targets.size() == 2);
                simple = false;
                for (auto target : targets) {
                    if (is_neg_block(target)) {
                        simple = true;
                        break;
                    }
                }
                if (not simple)
                    break;
            }
            // Walking backwards, so each block goes in front of the others.
            body.insert(body.begin(), cur);
            cur = *cur->get_pre_basic_blocks().begin();
        }

        if (simple) {
            body.insert(body.begin(), head);
            result.push_back(body);
        }
    }
    return result;
}

/**
 * Collect the back edges of `func`.
 *
 * Constructs a BackEdgeSearcher from the function's entry block and returns
 * a copy of the edge list the searcher holds afterwards.
 */
BackEdgeList
LoopUnroll::detect_back(Function *func) {
    auto *entry = func->get_entry_block();
    BackEdgeSearcher searcher(entry);
    BackEdgeList found = searcher.edges;
    return found;
}

/**
 * Depth-first traversal starting at `bb` that records back edges.
 *
 * `path` holds the blocks on the current DFS path. An edge to an already
 * visited block is stored in `edges` as a back edge when its target is still
 * on `path`; otherwise it is a cross edge and is ignored.
 */
void
BackEdgeSearcher::dfsrun(BasicBlock *bb) {
    vis[bb] = true;
    path.push_back(bb);

    for (auto target : bb->get_succ_basic_blocks()) {
        if (not vis[target]) {
            dfsrun(target);
            continue;
        }
        // Already visited: back edge if the target is on the current path.
        const bool on_path =
            find(path.rbegin(), path.rend(), target) != path.rend();
        if (on_path) {
            edges.push_back(Edge(bb, target));
        }
    }

    path.pop_back();
}

/**
 * Append a copy of `instr` to the block `BB` of the unrolled code.
 *
 * Operands are taken through the `op` macro. Value-producing copies are
 * recorded in `old2new` under the original instruction.
 *
 * Special cases:
 *  - br in the first loop block: nothing is emitted.
 *  - conditional br elsewhere in the loop: a conditional branch to the
 *    original operand-1 target and to a freshly created block is emitted,
 *    and copying continues in that fresh block.
 *  - unconditional br elsewhere: nothing is emitted (it must target the
 *    first loop block).
 *  - phi (only allowed in the first loop block): replaced by `0 + v`, where
 *    v is the emulated constant for the loop's induction phi, the phi's
 *    operand 0 on the first iteration (`init`), and operand 2 afterwards.
 *
 * @param instr instruction to copy.
 * @param bb    loop block that contains `instr`.
 * @param BB    block currently being filled.
 * @param pre   block entering the loop (only checked by assertions).
 * @param succ  loop exit block (only checked by assertions).
 * @param sl    the loop being unrolled.
 * @param cl    counted-loop description; `cl.emulate` must hold the value
 *              of the induction variable for the iteration being copied.
 * @param init  true while the first iteration is being copied.
 * @return the block subsequent instructions must be appended to.
 *
 * NOTE(review): every phi is rebuilt with add or fadd according to
 * `cl.Type`, i.e. the type of the loop counter, not of the phi itself.
 */
BasicBlock *
LoopUnroll::copy_instruction(Instruction &instr,
                             BasicBlock *bb,
                             BasicBlock *BB,
                             BasicBlock *pre,
                             BasicBlock *succ,
                             SimpleLoop &sl,
                             CountedLoop &cl,
                             bool init) {
    Value *n = nullptr; // the copy, when the instruction yields a value
    auto b = *sl.begin();
    auto e = *sl.rbegin();
    auto func = b->get_parent();
    // These are only read inside assertions.
    (void)pre;
    (void)succ;
    (void)e;

    switch (instr.get_instr_type()) {
        case Instruction::ret:
            if (instr.get_num_operand() == 0)
                ReturnInst::create_void_ret(BB);
            else
                ReturnInst::create_ret(op(instr, 0), BB);
            break;

        case Instruction::br:
            if (bb == b) { // do nothing
                assert((instr.get_operand(1) == succ or
                        instr.get_operand(2) == succ) &&
                       "unexpected structure");
            } else {
                if (instr.get_num_operand() == 3) {
                    // we know the neg block is always at operand(1)
                    auto newBB = BasicBlock::create(m_, "", func);
                    BranchInst::create_cond_br(
                        op(instr, 0),
                        static_cast<BasicBlock *>(op(instr, 1)),
                        newBB,
                        BB);
                    BB = newBB;
                } else { // do nothing
                    assert(instr.get_num_operand() == 1 and
                           instr.get_operand(0) == b);
                }
            }
            break;

        case Instruction::add:
            n = BinaryInst::create_add(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::sub:
            n = BinaryInst::create_sub(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::mul:
            n = BinaryInst::create_mul(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::sdiv:
            n = BinaryInst::create_sdiv(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::fadd:
            n = BinaryInst::create_fadd(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::fsub:
            n = BinaryInst::create_fsub(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::fmul:
            n = BinaryInst::create_fmul(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::fdiv:
            n = BinaryInst::create_fdiv(op(instr, 0), op(instr, 1), BB, m_);
            break;

        case Instruction::alloca:
            n = AllocaInst::create_alloca(
                static_cast<AllocaInst *>(&instr)->get_alloca_type(), BB);
            break;

        case Instruction::load:
            n = LoadInst::create_load(
                static_cast<LoadInst *>(&instr)->get_load_type(),
                op(instr, 0),
                BB);
            break;

        case Instruction::store:
            StoreInst::create_store(op(instr, 0), op(instr, 1), BB);
            break;

        case Instruction::cmp:
            n = CmpInst::create_cmp(
                static_cast<CmpInst *>(&instr)->get_cmp_op(),
                op(instr, 0),
                op(instr, 1),
                BB,
                m_);
            break;

        case Instruction::fcmp:
            n = FCmpInst::create_fcmp(
                static_cast<FCmpInst *>(&instr)->get_cmp_op(),
                op(instr, 0),
                op(instr, 1),
                BB,
                m_);
            break;

        case Instruction::phi: { // make it a copy: n = 0 + v
            assert(bb == b);
            if (cl.Type == CountedLoop::UNDEF) {
                // Unrolling is only attempted for recognised loops.
                assert(false);
                break;
            }
            const bool is_float = cl.Type == CountedLoop::FLOAT;

            // `cl.control` is the compare instruction; the induction
            // variable is the phi it compares, i.e. its operand 0.
            const bool is_induction_phi =
                static_cast<Value *>(&instr) == cl.control->get_operand(0);

            Value *v = nullptr;
            if (is_induction_phi) {
                // The counter's value in this iteration is known statically.
                if (is_float)
                    v = CONSTFP(cl.emulate.fv);
                else
                    v = CONSTINT(cl.emulate.v);
            } else if (init) {
                // first iteration: value coming from outside the loop
                assert(instr.get_operand(1) == pre);
                v = op(instr, 0);
            } else {
                // later iterations: value produced by the previous copy
                assert(instr.get_operand(3) == e);
                v = op(instr, 2);
            }

            if (is_float) {
                Value *zero = CONSTFP(0);
                n = BinaryInst::create_fadd(zero, v, BB, m_);
            } else {
                Value *zero = CONSTINT(0);
                n = BinaryInst::create_add(zero, v, BB, m_);
            }
            break;
        }

        case Instruction::call: {
            vector<Value *> args;
            for (int i = 1; i < instr.get_num_operand(); ++i)
                args.push_back(op(instr, i));
            n = CallInst::create(
                static_cast<Function *>(op(instr, 0)), args, BB);
            break;
        }

        case Instruction::getelementptr: {
            vector<Value *> idxs;
            for (int i = 1; i < instr.get_num_operand(); ++i)
                idxs.push_back(op(instr, i));
            n = GetElementPtrInst::create_gep(op(instr, 0), idxs, BB);
            break;
        }

        case Instruction::zext:
            n = ZextInst::create_zext(
                op(instr, 0),
                static_cast<ZextInst *>(&instr)->get_dest_type(),
                BB);
            break;

        case Instruction::fptosi:
            n = FpToSiInst::create_fptosi(
                op(instr, 0),
                static_cast<FpToSiInst *>(&instr)->get_dest_type(),
                BB);
            break;

        case Instruction::sitofp:
            n = SiToFpInst::create_sitofp(
                op(instr, 0),
                static_cast<SiToFpInst *>(&instr)->get_dest_type(),
                BB);
            break;

        default:
            assert(false && "unhandled instruction type in loop unrolling");
            break;
    }
    // Never record a missing copy: an entry without a value would later be
    // written into operands by the use-replacement step.
    if (not instr.is_void() and n != nullptr)
        old2new[&instr] = n;
    return BB;
}