/*
 * 声明:本代码为 2020 秋 中国科大编译原理(李诚)课程实验参考实现。
 * 请不要以任何方式,将本代码上传到可以公开访问的站点或仓库
 */

#include "cminusf_builder.hpp"

#include "logging.hpp"

// Shorthands for building IR constants. They refer to `module`, which in the
// methods below is the builder's module; CONST_FP casts its argument to float.
#define CONST_FP(num) ConstantFP::get((float)num, module.get())
#define CONST_INT(num) ConstantInt::get(num, module.get())

// File-scope builder state shared by the visitor methods below.
// (These are plain globals, not CminusfBuilder members.)

// LV ("want lvalue"): when true on entry to visit(ASTVar &), the variable's
// address is left in cur_value instead of its loaded value. Assignment
// expressions set it while visiting their left-hand side.
bool LV = false;
// The Value produced by the most recently visited node: expressions write
// their result here, declarations leave the pointer to the new storage here.
Value *cur_value = nullptr;
// function that is being built
Function *cur_fun = nullptr;

// Cached IR types; assigned at the start of visit(ASTProgram &).
Type *VOID_T;
Type *INT1_T;
Type *INT32_T;
Type *INT32PTR_T;
Type *FLOAT_T;
Type *FLOATPTR_T;

// Zero initializers for i32 / float globals; assigned at the start of
// visit(ASTProgram &).
ConstantZero *I32Initializer;
ConstantZero *FloatInitializer;

/*
 * use CMinusfBuilder::Scope to construct scopes
 * scope.enter: enter a new scope
 * scope.exit: exit current scope
 * scope.push: add a new binding to current scope
 * scope.find: find and return the value bound to the name
 */

// Logs `s` as an error and aborts the process; this function never returns.
void error_exit(std::string s) {
    const std::string &message = s;
    LOG_ERROR << message;
    std::abort();
}

// Prepares the two operands of a binary arithmetic / comparison operator.
//
// On return, lvalue and rvalue have the same type, and that type is either
// i32 or float:
//   * i1 operands (e.g. results of comparisons) are zero-extended to i32
//     first, so `true` becomes 1 -- a signed conversion (sitofp) applied
//     directly to an i1 would turn `true` into -1.0;
//   * if one side is float and the other is integer, the integer side is
//     converted with sitofp.
// `util` names the operator in the error message; an operand that is neither
// integer nor float aborts via error_exit.
void CminusfBuilder::biop_type_check(Value *&lvalue, Value *&rvalue, std::string util) {
    auto is_arith = [](Value *v) { return v->get_type()->is_integer_type() or v->get_type()->is_float_type(); };

    // we only support computing among i1, i32 and float
    if (not is_arith(lvalue) or not is_arith(rvalue))
        error_exit("not supported type cast for " + util);

    // Normalize booleans first, so the mixed int/float case below only ever
    // sees i32 on the integer side.
    if (Type::is_eq_type(lvalue->get_type(), INT1_T))
        lvalue = builder->create_zext(lvalue, INT32_T);
    if (Type::is_eq_type(rvalue->get_type(), INT1_T))
        rvalue = builder->create_zext(rvalue, INT32_T);

    // Mixed int/float: promote the integer side to float.
    if (lvalue->get_type()->is_integer_type() and rvalue->get_type()->is_float_type())
        lvalue = builder->create_sitofp(lvalue, FLOAT_T);
    else if (lvalue->get_type()->is_float_type() and rvalue->get_type()->is_integer_type())
        rvalue = builder->create_sitofp(rvalue, FLOAT_T);
}

// Converts `value` in place into an i1 truth value:
//   float -> (value != 0.0), i32 -> (value != 0); an i1 is left untouched.
// `value` must be an integer or a float (checked with assert).
void CminusfBuilder::cast_to_i1(Value *&value) {
    Type *ty = value->get_type();
    assert(ty->is_integer_type() or ty->is_float_type());
    if (ty->is_float_type()) {
        value = builder->create_fcmp_ne(value, CONST_FP(0));
        return;
    }
    if (Type::is_eq_type(ty, INT32_T))
        value = builder->create_icmp_ne(value, CONST_INT(0));
}

// Entry point: caches the primitive IR types and the zero initializers, then
// emits every top-level declaration in source order.
void CminusfBuilder::visit(ASTProgram &node) {
    auto *type_module = module.get();
    VOID_T = Type::get_void_type(type_module);
    INT1_T = Type::get_int1_type(type_module);
    INT32_T = Type::get_int32_type(type_module);
    INT32PTR_T = Type::get_int32_ptr_type(type_module);
    FLOAT_T = Type::get_float_type(type_module);
    FLOATPTR_T = Type::get_float_ptr_type(type_module);

    auto *ir_module = builder->get_module();
    I32Initializer = ConstantZero::get(INT32_T, ir_module);
    FloatInitializer = ConstantZero::get(FLOAT_T, ir_module);

    for (auto &decl : node.declarations)
        decl->accept(*this);
}

// Materializes a numeric literal as an IR constant in cur_value
// (i32 for integer literals, float for floating literals).
void CminusfBuilder::visit(ASTNum &node) {
    if (node.type == TYPE_INT) {
        cur_value = CONST_INT(node.i_val);
    } else if (node.type == TYPE_FLOAT) {
        cur_value = CONST_FP(node.f_val);
    } else {
        error_exit("ASTNum is not int or float");
    }
}

// Declares a scalar or fixed-size array variable and binds it in the current scope.
//
// With no insertion block (i.e. outside any function) a GlobalVariable with a
// zero initializer of the variable's own type is created; otherwise an
// alloca of that type is emitted. The bound value is a pointer to the
// storage and is also left in cur_value.
// Aborts on a non-integer or non-positive array size, an element type other
// than int/float, or a redefinition in the same scope.
void CminusfBuilder::visit(ASTVarDeclaration &node) {
    bool global = (builder->get_insert_block() == nullptr);

    // Array size must be a positive integer literal; no implicit cast here.
    int size = 0;
    if (node.num) {
        if (node.num->type != TYPE_INT)
            error_exit("size of array has non-integer type");
        size = node.num->i_val;
        if (size <= 0)
            error_exit("array size[" + std::to_string(size) + "] <= 0");
    }

    Type *elem_type = nullptr;
    switch (node.type) {
        case TYPE_INT:
            elem_type = INT32_T;
            break;
        case TYPE_FLOAT:
            elem_type = FLOAT_T;
            break;
        default:
            error_exit(node.num ? "Array element type is not int or float"
                                : "Variable type(not array) is not int or float");
    }

    Type *var_type = elem_type;
    if (node.num)
        var_type = Type::get_array_type(elem_type, size);

    if (global)
        cur_value = GlobalVariable::create(node.id, builder->get_module(), var_type, false,
                                           ConstantZero::get(var_type, builder->get_module()));
    else
        cur_value = builder->create_alloca(var_type);
    assert(cur_value->get_type()->is_pointer_type() && "IF SEE THIS: API ERROR");

    if (not scope.push(node.id, cur_value))
        error_exit("variable redefined: " + node.id);
    LOG_DEBUG << "add entry: " << node.id << " " << cur_value;
}

// Emits an IR function for a function definition.
//
// The function is bound in the enclosing scope before its body is built, so
// recursive calls resolve; a duplicate name aborts. Parameters live in a
// fresh scope, and each incoming argument is handed to visit(ASTParam &)
// through cur_value. If the body can fall off the end without a terminator,
// a default return (void, 0.0 or 0) is appended.
// Finally the insertion point is cleared: visit(ASTVarDeclaration &) treats
// "no insertion block" as global scope, so declarations that follow this
// function must not see the function's last block.
void CminusfBuilder::visit(ASTFunDeclaration &node) {
    Type *ret_type;
    if (node.type == TYPE_INT)
        ret_type = INT32_T;
    else if (node.type == TYPE_FLOAT)
        ret_type = FLOAT_T;
    else
        ret_type = VOID_T;

    // Array parameters are passed as pointers to their element type.
    std::vector<Type *> param_types;
    for (auto &param : node.params) {
        switch (param->type) {
            case TYPE_INT:
                param_types.push_back(param->isarray ? INT32PTR_T : INT32_T);
                break;
            case TYPE_FLOAT:
                param_types.push_back(param->isarray ? FLOATPTR_T : FLOAT_T);
                break;
            case TYPE_VOID:
                // `void` is only meaningful as the sole entry, i.e. `f(void)`.
                if (not param_types.empty())
                    error_exit("function parameters weird");
                break;
        }
    }

    FunctionType *fun_type = FunctionType::get(ret_type, param_types);
    auto fun = Function::create(fun_type, node.id, module.get());
    cur_fun = fun;
    if (not scope.push(node.id, fun))
        error_exit("function redefined: " + node.id);

    auto funBB = BasicBlock::create(module.get(), "entry", fun);
    builder->set_insert_point(funBB);
    scope.enter();

    // Bind each declared parameter to its incoming argument. A `void`
    // placeholder has no matching argument and is skipped.
    auto arg_it = fun->arg_begin();
    for (auto &param : node.params) {
        if (param->type == TYPE_VOID)
            continue;
        cur_value = *arg_it;
        ++arg_it;
        param->accept(*this);
    }

    node.compound_stmt->accept(*this);

    // default return value
    if (builder->get_insert_block()->get_terminator() == nullptr) {
        if (cur_fun->get_return_type()->is_void_type())
            builder->create_void_ret();
        else if (cur_fun->get_return_type()->is_float_type())
            builder->create_ret(CONST_FP(0.));
        else
            builder->create_ret(CONST_INT(0));
    }
    scope.exit();

    // Leave function context so following declarations are emitted as globals.
    builder->set_insert_point(nullptr);
}

// Spills one incoming function argument (passed in via cur_value) to a fresh
// stack slot and binds the parameter name to that slot in the current scope.
//
// Scalar parameters get an i32 / float slot. Array parameters arrive as
// i32* / float* (see the parameter types built for the function), so they
// get a pointer-typed slot and the store is type-correct. A `void`
// placeholder binds nothing. A duplicate parameter name aborts.
void CminusfBuilder::visit(ASTParam &node) {
    Value *arg = cur_value;
    Type *slot_type = nullptr;
    switch (node.type) {
        case TYPE_INT:
            slot_type = node.isarray ? INT32PTR_T : INT32_T;
            break;
        case TYPE_FLOAT:
            slot_type = node.isarray ? FLOATPTR_T : FLOAT_T;
            break;
        case TYPE_VOID:
            return;
        default:
            error_exit("parameter type is not int, float or void");
    }

    cur_value = builder->create_alloca(slot_type);
    builder->create_store(arg, cur_value);
    if (not scope.push(node.id, cur_value))
        error_exit("parameter redefined: " + node.id);
}

// Emits a `{ ... }` block inside a new scope: local declarations first, then
// statements. Emission stops right after the first statement that leaves the
// current basic block terminated (e.g. a `return`), because anything after
// it would be appended behind a terminator.
void CminusfBuilder::visit(ASTCompoundStmt &node) {
    scope.enter();

    for (auto &local : node.local_declarations)
        local->accept(*this);

    for (auto &statement : node.statement_list) {
        statement->accept(*this);
        bool block_closed = builder->get_insert_block()->get_terminator() != nullptr;
        if (block_closed)
            break;
    }

    scope.exit();
}

// Evaluates an expression statement for its side effects; a bare `;` emits nothing.
void CminusfBuilder::visit(ASTExpressionStmt &node) {
    if (node.expression == nullptr)
        return;
    node.expression->accept(*this);
}

// Emits `if (cond) S1 [else S2]`.
//
// The condition is converted to i1. Control then branches to a fresh "then"
// block and to either an "else" block or the join block. After each arm, a
// branch to the join block is added only if the arm's final block is still
// open. An arm that ends in `return` already has its terminator, and a second
// one would be invalid IR. Emission continues in the join block afterwards.
void CminusfBuilder::visit(ASTSelectionStmt &node) {
    scope.enter();
    node.expression->accept(*this);
    auto cond = cur_value;
    cast_to_i1(cond);

    auto ifBB = BasicBlock::create(builder->get_module(), "", cur_fun);
    auto endBB = BasicBlock::create(builder->get_module(), "", cur_fun);

    auto branch_to_end_if_open = [&]() {
        if (builder->get_insert_block()->get_terminator() == nullptr)
            builder->create_br(endBB);
    };

    if (node.else_statement) {
        auto elseBB = BasicBlock::create(builder->get_module(), "", cur_fun);
        builder->create_cond_br(cond, ifBB, elseBB);

        builder->set_insert_point(ifBB);
        node.if_statement->accept(*this);
        branch_to_end_if_open();

        builder->set_insert_point(elseBB);
        node.else_statement->accept(*this);
        branch_to_end_if_open();
    } else {
        builder->create_cond_br(cond, ifBB, endBB);

        builder->set_insert_point(ifBB);
        node.if_statement->accept(*this);
        branch_to_end_if_open();
    }

    builder->set_insert_point(endBB);
    scope.exit();
}

// Emits `while (cond) S` as three blocks:
//   HEAD evaluates the condition and branches to BODY or END,
//   BODY runs the statement and jumps back to HEAD,
//   END is where emission continues afterwards.
// The back edge is only added when the body's last block is still open. A
// body ending in `return` already has its terminator, and a second
// terminator would be invalid IR.
void CminusfBuilder::visit(ASTIterationStmt &node) {
    scope.enter();
    auto HEAD = BasicBlock::create(builder->get_module(), "", cur_fun);
    auto BODY = BasicBlock::create(builder->get_module(), "", cur_fun);
    auto END = BasicBlock::create(builder->get_module(), "", cur_fun);

    builder->create_br(HEAD);

    builder->set_insert_point(HEAD);
    node.expression->accept(*this);
    auto cond = cur_value;
    cast_to_i1(cond);
    builder->create_cond_br(cond, BODY, END);

    builder->set_insert_point(BODY);
    node.statement->accept(*this);
    if (builder->get_insert_block()->get_terminator() == nullptr)
        builder->create_br(HEAD);

    builder->set_insert_point(END);
    scope.exit();
}

// Emits a `return` statement.
//
// `return;` emits `ret void`. For `return e;`, e is converted to the current
// function's return type:
//   * an i1 is first zero-extended to i32, so `true` stays 1 / 1.0 rather
//     than the -1 that a signed conversion of an i1 would produce;
//   * i32 -> float via sitofp, float -> i32 via fptosi.
// Returning a value from a void function, or a non-numeric value, aborts.
void CminusfBuilder::visit(ASTReturnStmt &node) {
    if (node.expression == nullptr) {
        builder->create_void_ret();
        return;
    }

    node.expression->accept(*this);
    Type *ret_type = cur_fun->get_return_type();
    if (not Type::is_eq_type(ret_type, cur_value->get_type())) {
        Type *val_type = cur_value->get_type();
        if (ret_type->is_void_type())
            error_exit("void function should not return a value");
        // return type can only be int, float or void
        if (not val_type->is_integer_type() and not val_type->is_float_type())
            error_exit("unsupported return type");

        if (Type::is_eq_type(val_type, INT1_T))
            cur_value = builder->create_zext(cur_value, INT32_T);

        if (ret_type->is_float_type()) {
            if (not cur_value->get_type()->is_float_type())
                cur_value = builder->create_sitofp(cur_value, FLOAT_T);
        } else if (cur_value->get_type()->is_float_type()) {
            cur_value = builder->create_fptosi(cur_value, INT32_T);
        }
    }

    builder->create_ret(cur_value);
}

// Evaluates a variable reference `id` or `id[expr]`.
//
// The scope binds each variable name to a pointer to its storage. The
// pointee type decides how a subscript is applied:
//   * array       ([N x T]*) -> gep {0, idx}
//   * pointer slot (T**)     -> load the base pointer, then gep {idx}
//   * scalar      (T*)       -> subscripting is an error
// The index is converted to i32 (float via fptosi, i1 via zext). A negative
// index branches to `neg_idx_except` before the access.
//
// If LV is set on entry, the address is left in cur_value; otherwise the
// value is loaded. An un-subscripted array in value context decays to a
// pointer to its first element (gep {0, 0}) instead of being loaded as a
// whole aggregate, so it can be passed where a T* is expected.
void CminusfBuilder::visit(ASTVar &node) {
    bool want_address = LV;
    auto memory = scope.find(node.id);
    if (memory == nullptr)
        error_exit("variable " + node.id + " not declared");
    LOG_DEBUG << "find entry: " << node.id << " " << memory;
    if (not memory->get_type()->is_pointer_type())
        error_exit(node.id + " is not a variable");
    Type *pointee = memory->get_type()->get_pointer_element_type();

    Value *addr = memory;
    if (node.expression) {
        if (not pointee->is_array_type() and not pointee->is_pointer_type())
            error_exit("variable " + node.id + " is not an array");

        // The index itself is always evaluated as an rvalue.
        LV = false;
        node.expression->accept(*this);
        Value *idx = cur_value;
        if (Type::is_eq_type(idx->get_type(), INT1_T))
            idx = builder->create_zext(idx, INT32_T);
        else if (Type::is_eq_type(idx->get_type(), FLOAT_T))
            idx = builder->create_fptosi(idx, INT32_T);
        else if (not Type::is_eq_type(idx->get_type(), INT32_T))
            error_exit("bad type for subscription");

        // Runtime guard: a negative index calls neg_idx_except first.
        auto is_negative = builder->create_icmp_lt(idx, CONST_INT(0));
        auto except_func = scope.find("neg_idx_except");
        auto TBB = BasicBlock::create(builder->get_module(), "", cur_fun);
        auto passBB = BasicBlock::create(builder->get_module(), "", cur_fun);
        builder->create_cond_br(is_negative, TBB, passBB);

        builder->set_insert_point(TBB);
        builder->create_call(except_func, {});
        builder->create_br(passBB);

        builder->set_insert_point(passBB);

        if (pointee->is_array_type()) {
            addr = builder->create_gep(memory, {CONST_INT(0), idx});
        } else {
            // Pointer-valued slot: fetch the base pointer, then index it.
            auto base = builder->create_load(memory);
            addr = builder->create_gep(base, {idx});
        }
    }

    LV = want_address;
    if (LV)
        cur_value = addr;
    else if (not node.expression and pointee->is_array_type())
        cur_value = builder->create_gep(addr, {CONST_INT(0), CONST_INT(0)});
    else
        cur_value = builder->create_load(addr);
}

// Emits `var = expr`.
//
// The left side is visited with LV set, so it yields an address. The target
// must hold an integer or a float. The right side is converted to the
// target type: i1 is zero-extended to i32 first (so `true` stores as 1 /
// 1.0), then i32 <-> float via sitofp / fptosi. The value is stored, and the
// converted value is left in cur_value so the assignment can be used as an
// expression.
void CminusfBuilder::visit(ASTAssignExpression &node) {
    LV = true;
    node.var->accept(*this);
    LV = false;
    auto addr = cur_value;
    node.expression->accept(*this);

    Type *target = addr->get_type()->get_pointer_element_type();
    assert(target != nullptr);
    // Arrays / pointer slots are not assignable as a whole.
    if (not target->is_integer_type() and not target->is_float_type())
        error_exit("bad type for assignment");

    if (not Type::is_eq_type(target, cur_value->get_type())) {
        Type *src = cur_value->get_type();
        if (not src->is_integer_type() and not src->is_float_type())
            error_exit("bad type for assignment");

        if (Type::is_eq_type(src, INT1_T))
            cur_value = builder->create_zext(cur_value, INT32_T);

        if (target->is_float_type()) {
            if (not cur_value->get_type()->is_float_type())
                cur_value = builder->create_sitofp(cur_value, FLOAT_T);
        } else if (cur_value->get_type()->is_float_type()) {
            cur_value = builder->create_fptosi(cur_value, INT32_T);
        }
    }

    builder->create_store(cur_value, addr);
}

// Emits a relational expression `l OP r` (OP in <=, <, >, >=, ==, !=), or just
// the single additive expression when there is no comparison. Operands are
// unified by biop_type_check; the i1 result comes from fcmp for float operands
// and icmp for i32 operands.
void CminusfBuilder::visit(ASTSimpleExpression &node) {
    if (not node.additive_expression_r) {
        node.additive_expression_l->accept(*this);
        return;
    }

    node.additive_expression_l->accept(*this);
    Value *lhs = cur_value;
    node.additive_expression_r->accept(*this);
    Value *rhs = cur_value;
    biop_type_check(lhs, rhs, "cmp");

    if (lhs->get_type()->is_float_type()) {
        switch (node.op) {
            case OP_LE: cur_value = builder->create_fcmp_le(lhs, rhs); break;
            case OP_LT: cur_value = builder->create_fcmp_lt(lhs, rhs); break;
            case OP_GT: cur_value = builder->create_fcmp_gt(lhs, rhs); break;
            case OP_GE: cur_value = builder->create_fcmp_ge(lhs, rhs); break;
            case OP_EQ: cur_value = builder->create_fcmp_eq(lhs, rhs); break;
            case OP_NEQ: cur_value = builder->create_fcmp_ne(lhs, rhs); break;
        }
    } else {
        switch (node.op) {
            case OP_LE: cur_value = builder->create_icmp_le(lhs, rhs); break;
            case OP_LT: cur_value = builder->create_icmp_lt(lhs, rhs); break;
            case OP_GT: cur_value = builder->create_icmp_gt(lhs, rhs); break;
            case OP_GE: cur_value = builder->create_icmp_ge(lhs, rhs); break;
            case OP_EQ: cur_value = builder->create_icmp_eq(lhs, rhs); break;
            case OP_NEQ: cur_value = builder->create_icmp_ne(lhs, rhs); break;
        }
    }
}

// Emits `lhs + rhs` / `lhs - rhs` (the left operand is itself an additive
// expression), or just the term when there is no operator. Operands are
// unified by biop_type_check: float arithmetic if they are float, i32 otherwise.
void CminusfBuilder::visit(ASTAdditiveExpression &node) {
    if (node.additive_expression == nullptr) {
        node.term->accept(*this);
        return;
    }

    node.additive_expression->accept(*this);
    Value *lhs = cur_value;
    node.term->accept(*this);
    Value *rhs = cur_value;
    biop_type_check(lhs, rhs, "addop");

    const bool use_float = lhs->get_type()->is_float_type();
    if (node.op == OP_PLUS) {
        if (use_float)
            cur_value = builder->create_fadd(lhs, rhs);
        else
            cur_value = builder->create_iadd(lhs, rhs);
    } else if (node.op == OP_MINUS) {
        if (use_float)
            cur_value = builder->create_fsub(lhs, rhs);
        else
            cur_value = builder->create_isub(lhs, rhs);
    }
}

// Emits `lhs * rhs` / `lhs / rhs` (the left operand is itself a term), or just
// the factor when there is no operator. Operands are unified by
// biop_type_check; integer division is signed (isdiv).
void CminusfBuilder::visit(ASTTerm &node) {
    if (not node.term) {
        node.factor->accept(*this);
        return;
    }

    node.term->accept(*this);
    Value *lhs = cur_value;
    node.factor->accept(*this);
    Value *rhs = cur_value;
    biop_type_check(lhs, rhs, "mul");

    const bool use_float = lhs->get_type()->is_float_type();
    if (node.op == OP_MUL) {
        if (use_float)
            cur_value = builder->create_fmul(lhs, rhs);
        else
            cur_value = builder->create_imul(lhs, rhs);
    } else if (node.op == OP_DIV) {
        if (use_float)
            cur_value = builder->create_fdiv(lhs, rhs);
        else
            cur_value = builder->create_isdiv(lhs, rhs);
    }
}

// Emits a call `id(args...)` and leaves its result in cur_value.
//
// The name must be bound to a function (checked by its type before casting),
// and the argument count must match. Each argument is converted to its
// parameter type:
//   * numeric parameters: an i1 is zero-extended to i32 first (so `true` is
//     passed as 1 / 1.0), then i32 <-> float via sitofp / fptosi;
//   * pointer parameters (T*): the argument must already be a T*, or be a
//     pointer to an array of T, which is converted with gep {0, 0}.
void CminusfBuilder::visit(ASTCall &node) {
    Value *callee = scope.find(node.id);
    if (callee == nullptr or not callee->get_type()->is_function_type())
        error_exit("function " + node.id + " not declared");
    auto func = static_cast<Function *>(callee);

    if (node.args.size() != func->get_num_of_args())
        error_exit("expect " + std::to_string(func->get_num_of_args()) + " params, but " +
                   std::to_string(node.args.size()) + " is given");

    std::vector<Value *> args;
    args.reserve(node.args.size());
    for (std::size_t i = 0; i != node.args.size(); ++i) {
        Type *param_type = func->get_function_type()->get_param_type(i);
        node.args[i]->accept(*this);

        if (not Type::is_eq_type(param_type, cur_value->get_type())) {
            Type *arg_type = cur_value->get_type();
            if (param_type->is_pointer_type()) {
                // Only an array address ([N x T]*) can be adapted to T*.
                Type *pointee = arg_type->is_pointer_type() ? arg_type->get_pointer_element_type() : nullptr;
                if (pointee == nullptr or not pointee->is_array_type() or
                    not Type::is_eq_type(param_type->get_pointer_element_type(), pointee->get_array_element_type()))
                    error_exit("expected right pointer type");
                // int[] to int* or float[] to float*
                cur_value = builder->create_gep(cur_value, {CONST_INT(0), CONST_INT(0)});
            } else if (param_type->is_integer_type() or param_type->is_float_type()) {
                if (not arg_type->is_integer_type() and not arg_type->is_float_type())
                    error_exit("unexpected type cast!");

                if (Type::is_eq_type(arg_type, INT1_T))
                    cur_value = builder->create_zext(cur_value, INT32_T);

                if (param_type->is_float_type()) {
                    if (not cur_value->get_type()->is_float_type())
                        cur_value = builder->create_sitofp(cur_value, FLOAT_T);
                } else if (cur_value->get_type()->is_float_type()) {
                    cur_value = builder->create_fptosi(cur_value, INT32_T);
                }
            } else {
                error_exit("unexpected case when casting arguments for function call " + node.id);
            }
        }

        // now cur_value fits the param type
        args.push_back(cur_value);
    }
    cur_value = builder->create_call(func, args);
}