#include "IRprinter.hpp"

#include <algorithm>
#include <cassert>
#include <string>

#include "Instruction.hpp"
#include "Constant.hpp"
#include "Function.hpp"
#include "GlobalVariable.hpp"

/// Render `v` the way it appears in operand position of a printed instruction.
///
/// @param v        value to render; it is dereferenced unconditionally.
/// @param print_ty when true, the value's type and a single space come first.
/// @return "@<name>" for a GlobalVariable or Function, the value's own print()
///         text for any other Constant, and "%<name>" for everything else.
std::string print_as_op(Value *v, bool print_ty) {
    std::string text;
    if (print_ty) {
        text.append(v->get_type()->print());
        text.append(" ");
    }

    // Global symbols are checked before the generic constant case.
    const bool is_global_symbol =
        dynamic_cast<GlobalVariable *>(v) != nullptr || dynamic_cast<Function *>(v) != nullptr;
    if (is_global_symbol) {
        text += "@" + v->get_name();
        return text;
    }

    if (dynamic_cast<Constant *>(v) != nullptr) {
        text += v->print();
        return text;
    }

    text += "%" + v->get_name();
    return text;
}

static std::string safe_print_op_as_op(const Instruction* v, unsigned idx, bool print_ty) {
    if (v->get_num_operand() <= idx) return  "op<unknown>";
    Value* val = v->get_operand(idx);
    if (val == nullptr) return "op<null>";
    std::string op_ir;
    if (print_ty) {
        auto ty = val->get_type();
        if (ty == nullptr) op_ir += "t<null> ";
        else
        {
            op_ir += ty->safe_print();
            op_ir += " ";
        }
    }

    if (dynamic_cast<GlobalVariable*>(val) || dynamic_cast<Function*>(val)) {
        op_ir += "@" + val->safe_get_name_or_ptr();
    }
    else if (auto constant = dynamic_cast<Constant*>(val)) {
        op_ir += constant->safe_print_help();
    }
    else {
        op_ir += "%" + val->safe_get_name_or_ptr();
    }

    return op_ir;
}

/// Crash-tolerant counterpart of print_as_op.
///
/// @param v        value to render; a null pointer is reported, not dereferenced.
/// @param print_ty when true, the value's type (or "t<null>" if it has none)
///                 and a space are emitted first.
/// @return "op<null>" for a null `v`; otherwise "@<name>" for a GlobalVariable
///         or Function, safe_print_help() text for any other Constant, and
///         "%<name>" for everything else, with names taken from
///         safe_get_name_or_ptr().
static std::string safe_print_as_op(const Value* v, bool print_ty) {
    // Same placeholder safe_print_op_as_op uses for a null operand slot.
    if (v == nullptr) return "op<null>";

    std::string op_ir;
    if (print_ty) {
        auto ty = v->get_type();
        if (ty == nullptr) {
            op_ir += "t<null> ";
        } else {
            op_ir += ty->safe_print();
            op_ir += " ";
        }
    }

    if (dynamic_cast<const GlobalVariable*>(v) || dynamic_cast<const Function*>(v)) {
        op_ir += "@" + v->safe_get_name_or_ptr();
    } else if (auto constant = dynamic_cast<const Constant*>(v)) {
        op_ir += constant->safe_print_help();
    } else {
        op_ir += "%" + v->safe_get_name_or_ptr();
    }

    return op_ir;
}

/// Map an opcode to the mnemonic used in printed IR without ever asserting.
///
/// @param id opcode to translate.
/// @return the mnemonic for a known opcode, or "inst<unknown>" for any value
///         that no case below covers.
static const char* safe_print_instr_op_name(Instruction::OpID id) {
    const char* mnemonic = "inst<unknown>";
    switch (id) {
    // Terminators
    case Instruction::ret:           mnemonic = "ret";           break;
    case Instruction::br:            mnemonic = "br";            break;
    // Integer arithmetic
    case Instruction::add:           mnemonic = "add";           break;
    case Instruction::sub:           mnemonic = "sub";           break;
    case Instruction::mul:           mnemonic = "mul";           break;
    case Instruction::sdiv:          mnemonic = "sdiv";          break;
    // Floating-point arithmetic
    case Instruction::fadd:          mnemonic = "fadd";          break;
    case Instruction::fsub:          mnemonic = "fsub";          break;
    case Instruction::fmul:          mnemonic = "fmul";          break;
    case Instruction::fdiv:          mnemonic = "fdiv";          break;
    // Memory
    case Instruction::alloca:        mnemonic = "alloca";        break;
    case Instruction::load:          mnemonic = "load";          break;
    case Instruction::store:         mnemonic = "store";         break;
    // Integer comparison predicates
    case Instruction::ge:            mnemonic = "sge";           break;
    case Instruction::gt:            mnemonic = "sgt";           break;
    case Instruction::le:            mnemonic = "sle";           break;
    case Instruction::lt:            mnemonic = "slt";           break;
    case Instruction::eq:            mnemonic = "eq";            break;
    case Instruction::ne:            mnemonic = "ne";            break;
    // Floating-point comparison predicates ("u"-prefixed spellings)
    case Instruction::fge:           mnemonic = "uge";           break;
    case Instruction::fgt:           mnemonic = "ugt";           break;
    case Instruction::fle:           mnemonic = "ule";           break;
    case Instruction::flt:           mnemonic = "ult";           break;
    case Instruction::feq:           mnemonic = "ueq";           break;
    case Instruction::fne:           mnemonic = "une";           break;
    // Everything else
    case Instruction::phi:           mnemonic = "phi";           break;
    case Instruction::call:          mnemonic = "call";          break;
    case Instruction::getelementptr: mnemonic = "getelementptr"; break;
    case Instruction::zext:          mnemonic = "zext";          break;
    case Instruction::fptosi:        mnemonic = "fptosi";        break;
    case Instruction::sitofp:        mnemonic = "sitofp";        break;
    }
    return mnemonic;
}

/// Map an opcode to the mnemonic used in printed IR.
///
/// An opcode outside the table is treated as a programming error and trips an
/// assertion. In builds where assertions are compiled out the function still
/// returns a value ("inst<unknown>") instead of running off its end.
///
/// @param id opcode to translate.
/// @return the mnemonic for `id`.
std::string print_instr_op_name(Instruction::OpID id) {
    switch (id) {
    case Instruction::ret:
        return "ret";
    case Instruction::br:
        return "br";
    case Instruction::add:
        return "add";
    case Instruction::sub:
        return "sub";
    case Instruction::mul:
        return "mul";
    case Instruction::sdiv:
        return "sdiv";
    case Instruction::fadd:
        return "fadd";
    case Instruction::fsub:
        return "fsub";
    case Instruction::fmul:
        return "fmul";
    case Instruction::fdiv:
        return "fdiv";
    case Instruction::alloca:
        return "alloca";
    case Instruction::load:
        return "load";
    case Instruction::store:
        return "store";
    case Instruction::ge:
        return "sge";
    case Instruction::gt:
        return "sgt";
    case Instruction::le:
        return "sle";
    case Instruction::lt:
        return "slt";
    case Instruction::eq:
        return "eq";
    case Instruction::ne:
        return "ne";
    case Instruction::fge:
        return "uge";
    case Instruction::fgt:
        return "ugt";
    case Instruction::fle:
        return "ule";
    case Instruction::flt:
        return "ult";
    case Instruction::feq:
        return "ueq";
    case Instruction::fne:
        return "une";
    case Instruction::phi:
        return "phi";
    case Instruction::call:
        return "call";
    case Instruction::getelementptr:
        return "getelementptr";
    case Instruction::zext:
        return "zext";
    case Instruction::fptosi:
        return "fptosi";
    case Instruction::sitofp:
        return "sitofp";
    }
    assert(false && "Must be bug");
    // Only reachable when NDEBUG removes the assert above: flowing off the end
    // of a value-returning function is undefined behaviour, so return the same
    // placeholder safe_print_instr_op_name uses.
    return "inst<unknown>";
}

std::string Instruction::safe_print() const
{
    switch (op_id_)
    {
        case ret:
            {
                std::string instr_ir;
                instr_ir += get_instr_op_name();
                instr_ir += " ";
                if (!dynamic_cast<const ReturnInst*>(this)->is_void_ret()) {
                    instr_ir += this->get_operand(0)->get_type()->print();
                    instr_ir += " ";
                    instr_ir += print_as_op(this->get_operand(0), false);
                }
                else {
                    instr_ir += "void";
                }
                return instr_ir;
            }
        case br:
            {
                std::string instr_ir;
                instr_ir += safe_print_instr_op_name(get_instr_type());
                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, true);
                if (dynamic_cast<const BranchInst*>(this)->is_cond_br()) {
                    instr_ir += ", ";
                    instr_ir += safe_print_op_as_op(this, 1, true);
                    instr_ir += ", ";
                    instr_ir += safe_print_op_as_op(this, 2, true);
                }
                return instr_ir;
            }
        case add:
        case sub:
        case mul:
        case sdiv:
        case fadd:
        case fsub:
        case fmul:
        case fdiv:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = ";
                instr_ir += safe_print_instr_op_name(op_id_);
                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, true);
                instr_ir += ", ";
                instr_ir += safe_print_op_as_op(this, 1, true);
                return instr_ir;
            }
        case alloca:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = ";
                instr_ir += get_instr_op_name();
                instr_ir += " ";
                auto ty = get_type();
                if (ty == nullptr || !ty->is_pointer_type())
                {
                    instr_ir += "t<unknown>";
                }
                else
                {
                    auto ty2 = ty->get_pointer_element_type();
                    if (ty2 == nullptr)
                        instr_ir += "t<null>";
                    else instr_ir += ty2->safe_print();
                }
                return instr_ir;
            }
        case load:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = ";
                instr_ir += safe_print_instr_op_name(get_instr_type());
                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, true);
                return instr_ir;
            }
        case store:
            {
                std::string instr_ir;
                instr_ir += safe_print_instr_op_name(get_instr_type());
                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, true);
                instr_ir += ", ";
                instr_ir += safe_print_op_as_op(this, 1, true);
                return instr_ir;
            }
        case ge:
        case gt:
        case le:
        case lt:
        case eq:
        case ne:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = icmp ";
                instr_ir += safe_print_instr_op_name(op_id_);
                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, true);
                instr_ir += ", ";
                instr_ir += safe_print_op_as_op(this, 1, true);
                return instr_ir;
            }
        case fge:
        case fgt:
        case fle:
        case flt:
        case feq:
        case fne:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = fcmp ";
                instr_ir += safe_print_instr_op_name(op_id_);
                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, true);
                instr_ir += ", ";
                instr_ir += safe_print_op_as_op(this, 1, true);
                return instr_ir;
            }
        case phi:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = ";
                instr_ir += safe_print_instr_op_name(get_instr_type());
                instr_ir += " ";
                for (unsigned i = 0; i < this->get_num_operand() / 2; i++) {
                    if (i > 0)
                        instr_ir += ", ";
                    instr_ir += "[ ";
                    instr_ir += safe_print_op_as_op(this, i << 1, true);
                    instr_ir += ", ";
                    instr_ir += safe_print_op_as_op(this, (i << 1) + 1, true);
                    instr_ir += " ]";
                }
                return instr_ir;
            }
        case call:
            {
                std::string instr_ir;
                auto ty = get_type();
                if (ty == nullptr || !ty->is_void_type())
                {
                    instr_ir += safe_print_as_op(this, true);
                    instr_ir += " = ";
                }
                instr_ir += safe_print_instr_op_name(op_id_);
                instr_ir += " ";

                if (get_num_operand() == 0)
                {
                    return instr_ir + "<error no operand>";
                }

                auto op0 = get_operand(0);

                if (op0 == nullptr)
                {
                    instr_ir += "ft<null> ";
                }
                else
                {
                    if (auto fty = dynamic_cast<FunctionType*>(op0->get_type()))
                    {
                        auto ty3 = fty->get_return_type();
                        if (ty3 == nullptr)
                            instr_ir += "t<null> ";
                        else
                            instr_ir += ty3->safe_print();
                    }
                    else
                    {
                        instr_ir += "<error not func> ";
                        instr_ir += safe_print_as_op(op0, true);
                        return instr_ir;
                    }
                }

                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, false);
                instr_ir += "(";
                for (unsigned i = 1; i < this->get_num_operand(); i++) {
                    if (i > 1)
                        instr_ir += ", ";
                    instr_ir += safe_print_op_as_op(this, i, true);
                }
                instr_ir += ")";
                return instr_ir;
            }
        case getelementptr:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = ";
                instr_ir += safe_print_instr_op_name(get_instr_type());
                instr_ir += " ";
                for (unsigned i = 0; i < this->get_num_operand(); i++) {
                    if (i > 0)
                        instr_ir += ", ";
                    instr_ir += safe_print_op_as_op(this, i, true);
                }
                return instr_ir;
            }
        case zext:
        case fptosi:
        case sitofp:
            {
                std::string instr_ir;
                instr_ir += safe_print_as_op(this, true);
                instr_ir += " = ";
                instr_ir += safe_print_instr_op_name(get_instr_type());
                instr_ir += " ";
                instr_ir += safe_print_op_as_op(this, 0, true);
                return instr_ir;
            }
    }
    std::string str;
    str += safe_print_as_op(this, true);
    str += " = unknown inst";
    return str;
}

/// Shared printer for two-operand arithmetic instructions.
///
/// Produces "%<name> = <op> <type of operand 0> <operand 0>, <operand 1>".
/// The second operand carries its own type only when that type is not the
/// same Type object as the first operand's.
///
/// @param inst instruction to print; both operands are dereferenced.
/// @return the textual form of `inst`.
template <class BinInst>
static std::string print_binary_inst(const BinInst &inst) {
    std::string text("%");
    text.append(inst.get_name());
    text.append(" = ");
    text.append(inst.get_instr_op_name());
    text.append(" ");

    auto lhs = inst.get_operand(0);
    text.append(lhs->get_type()->print());
    text.append(" ");
    text.append(print_as_op(lhs, false));
    text.append(", ");

    // Spell out the right-hand type only on a type mismatch.
    auto rhs = inst.get_operand(1);
    const bool types_differ = !(inst.get_operand(0)->get_type() == rhs->get_type());
    text.append(print_as_op(rhs, types_differ));
    return text;
}
std::string IBinaryInst::print() { return print_binary_inst(*this); }

std::string FBinaryInst::print() { return print_binary_inst(*this); }

/// Shared printer for comparison instructions.
///
/// Produces "%<name> = <icmp|fcmp> <predicate> <type of operand 0> <operand 0>,
/// <operand 1>". The keyword is "icmp" when inst.is_cmp() holds and "fcmp" when
/// inst.is_fcmp() holds; anything else trips an assertion. The second operand
/// carries its own type only when that type is not the same Type object as the
/// first operand's.
///
/// @param inst comparison to print; both operands are dereferenced.
/// @return the textual form of `inst`.
template <class CMP>
static std::string print_cmp_inst(const CMP &inst) {
    const char *keyword = "";
    if (inst.is_cmp()) {
        keyword = "icmp";
    } else if (inst.is_fcmp()) {
        keyword = "fcmp";
    } else {
        assert(false && "Unexpected case");
    }

    std::string text("%");
    text.append(inst.get_name());
    text.append(" = ");
    text.append(keyword);
    text.append(" ");
    text.append(inst.get_instr_op_name());
    text.append(" ");

    auto lhs = inst.get_operand(0);
    text.append(lhs->get_type()->print());
    text.append(" ");
    text.append(print_as_op(lhs, false));
    text.append(", ");

    // Spell out the right-hand type only on a type mismatch.
    auto rhs = inst.get_operand(1);
    const bool types_differ = !(inst.get_operand(0)->get_type() == rhs->get_type());
    text.append(print_as_op(rhs, types_differ));
    return text;
}
std::string ICmpInst::print() { return print_cmp_inst(*this); }

std::string FCmpInst::print() { return print_cmp_inst(*this); }

std::string CallInst::print() {
    std::string instr_ir;
    if (!this->is_void()) {
        instr_ir += "%";
        instr_ir += this->get_name();
        instr_ir += " = ";
    }
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    instr_ir += this->get_function_type()->get_return_type()->print();

    instr_ir += " ";
    assert(dynamic_cast<Function *>(this->get_operand(0)) &&
           "Wrong call operand function");
    instr_ir += print_as_op(this->get_operand(0), false);
    instr_ir += "(";
    for (unsigned i = 1; i < this->get_num_operand(); i++) {
        if (i > 1)
            instr_ir += ", ";
        instr_ir += this->get_operand(i)->get_type()->print();
        instr_ir += " ";
        instr_ir += print_as_op(this->get_operand(i), false);
    }
    instr_ir += ")";
    return instr_ir;
}

std::string BranchInst::print() {
    std::string instr_ir;
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    instr_ir += print_as_op(this->get_operand(0), true);
    if (is_cond_br()) {
        instr_ir += ", ";
        instr_ir += print_as_op(this->get_operand(1), true);
        instr_ir += ", ";
        instr_ir += print_as_op(this->get_operand(2), true);
    }
    return instr_ir;
}

std::string ReturnInst::print() {
    std::string instr_ir;
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    if (!is_void_ret()) {
        instr_ir += this->get_operand(0)->get_type()->print();
        instr_ir += " ";
        instr_ir += print_as_op(this->get_operand(0), false);
    } else {
        instr_ir += "void";
    }

    return instr_ir;
}

std::string GetElementPtrInst::print() {
    std::string instr_ir;
    instr_ir += "%";
    instr_ir += this->get_name();
    instr_ir += " = ";
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    assert(this->get_operand(0)->get_type()->is_pointer_type());
    instr_ir +=
        this->get_operand(0)->get_type()->get_pointer_element_type()->print();
    instr_ir += ", ";
    for (unsigned i = 0; i < this->get_num_operand(); i++) {
        if (i > 0)
            instr_ir += ", ";
        instr_ir += this->get_operand(i)->get_type()->print();
        instr_ir += " ";
        instr_ir += print_as_op(this->get_operand(i), false);
    }
    return instr_ir;
}

std::string StoreInst::print() {
    std::string instr_ir;
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    instr_ir += this->get_operand(0)->get_type()->print();
    instr_ir += " ";
    instr_ir += print_as_op(this->get_operand(0), false);
    instr_ir += ", ";
    instr_ir += print_as_op(this->get_operand(1), true);
    return instr_ir;
}

std::string LoadInst::print() {
    std::string instr_ir;
    instr_ir += "%";
    instr_ir += this->get_name();
    instr_ir += " = ";
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    assert(this->get_operand(0)->get_type()->is_pointer_type());
    instr_ir +=
        this->get_operand(0)->get_type()->get_pointer_element_type()->print();
    instr_ir += ",";
    instr_ir += " ";
    instr_ir += print_as_op(this->get_operand(0), true);
    return instr_ir;
}

std::string AllocaInst::print() {
    std::string instr_ir;
    instr_ir += "%";
    instr_ir += this->get_name();
    instr_ir += " = ";
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    instr_ir += get_alloca_type()->print();
    return instr_ir;
}

std::string ZextInst::print() {
    std::string instr_ir;
    instr_ir += "%";
    instr_ir += this->get_name();
    instr_ir += " = ";
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    instr_ir += this->get_operand(0)->get_type()->print();
    instr_ir += " ";
    instr_ir += print_as_op(this->get_operand(0), false);
    instr_ir += " to ";
    instr_ir += this->get_dest_type()->print();
    return instr_ir;
}

std::string FpToSiInst::print() {
    std::string instr_ir;
    instr_ir += "%";
    instr_ir += this->get_name();
    instr_ir += " = ";
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    instr_ir += this->get_operand(0)->get_type()->print();
    instr_ir += " ";
    instr_ir += print_as_op(this->get_operand(0), false);
    instr_ir += " to ";
    instr_ir += this->get_dest_type()->print();
    return instr_ir;
}

std::string SiToFpInst::print() {
    std::string instr_ir;
    instr_ir += "%";
    instr_ir += this->get_name();
    instr_ir += " = ";
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    instr_ir += this->get_operand(0)->get_type()->print();
    instr_ir += " ";
    instr_ir += print_as_op(this->get_operand(0), false);
    instr_ir += " to ";
    instr_ir += this->get_dest_type()->print();
    return instr_ir;
}

/// Print this phi as "%<name> = phi <type> [ <value>, <block> ], ...".
///
/// Operands are read two at a time, one bracketed entry per pair. When the
/// parent block has more predecessors than the phi has pairs, every
/// predecessor that does not occur among the operands gets an extra
/// "[ undef, <block> ]" entry.
///
/// A phi that has no operands yet is printed with its own type (there is no
/// operand 0 to take the type from), and its first entry is not preceded by a
/// separator.
///
/// @return the textual form of the phi.
std::string PhiInst::print() {
    std::string instr_ir;
    instr_ir += "%";
    instr_ir += this->get_name();
    instr_ir += " = ";
    instr_ir += get_instr_op_name();
    instr_ir += " ";
    if (this->get_num_operand() > 0) {
        instr_ir += this->get_operand(0)->get_type()->print();
    } else {
        // No incoming pair has been added yet; use the phi's own type rather
        // than reading a non-existent operand.
        instr_ir += this->get_type()->print();
    }
    instr_ir += " ";

    // Tracks whether a ", " is needed before the next bracketed entry.
    bool wrote_entry = false;
    for (unsigned i = 0; i < this->get_num_operand() / 2; i++) {
        if (wrote_entry)
            instr_ir += ", ";
        instr_ir += "[ ";
        instr_ir += print_as_op(this->get_operand(2 * i), false);
        instr_ir += ", ";
        instr_ir += print_as_op(this->get_operand(2 * i + 1), false);
        instr_ir += " ]";
        wrote_entry = true;
    }

    // Bind each container once so begin()/end() always refer to the same object.
    const auto &incoming = this->get_operands();
    const auto &predecessors = this->get_parent()->get_pre_basic_blocks();
    if (this->get_num_operand() / 2 < predecessors.size()) {
        for (auto pre_bb : predecessors) {
            const bool listed =
                std::find(incoming.begin(), incoming.end(),
                          static_cast<Value *>(pre_bb)) != incoming.end();
            if (listed)
                continue;
            // This predecessor has no entry in the phi: print it as undef.
            if (wrote_entry)
                instr_ir += ", ";
            instr_ir += "[ undef, " + print_as_op(pre_bb, false) + " ]";
            wrote_entry = true;
        }
    }
    return instr_ir;
}
