/*
 * 声明:本代码为 2020 秋 中国科大编译原理(李诚)课程实验参考实现。
 * 请不要以任何方式,将本代码上传到可以公开访问的站点或仓库
 */

#include "cminusf_builder.hpp"

#include "logging.hpp"

// Shorthands for IR constants. Both expand to `module.get()`, so they can only
// be used where the builder's `module` member is in scope.
// The argument is fully parenthesized so that e.g. CONST_FP(a + b) converts the
// whole expression to float rather than just its first operand.
#define CONST_FP(num) ConstantFP::get(static_cast<float>(num), module.get())
#define CONST_INT(num) ConstantInt::get((num), module.get())

// TODO: Global Variable Declarations
// You can define global variables here
// to store state. You can expand these
// definitions if you need to.

// Result of the most recently visited node. Visitors hand their result back
// to the caller through this variable.
Value *cur_value{nullptr};
// Function whose body is currently being generated.
Function *cur_fun{nullptr};

// Cached IR types, filled in by visit(ASTProgram &) before any declaration
// is visited.
Type *VOID_T{nullptr};
Type *INT1_T{nullptr};
Type *INT32_T{nullptr};
Type *INT32PTR_T{nullptr};
Type *FLOAT_T{nullptr};
Type *FLOATPTR_T{nullptr};

/*
 * Use CminusfBuilder::Scope to manage nested scopes:
 * scope.enter: enter a new scope
 * scope.exit: leave the current scope
 * scope.push: bind a name to a value in the current scope
 *             (returns false if the name is already bound there)
 * scope.find: return the value bound to the name, or nullptr if it is unbound
 */

// Report a fatal code-generation error and terminate the process.
// `message` is written through LOG_ERROR; control never returns to the caller.
void error_exit(std::string message) {
    LOG_ERROR << message;
    std::abort();
}

// Bring the two operands of a binary arithmetic/relational operation to a
// common type.
//
// If one operand is int32 and the other is float, the int32 operand is
// converted with sitofp; both references are updated in place. Afterwards the
// operands must be both int32 or both float, otherwise error_exit aborts.
// `util` names the operation in the error message.
//
// Returns true when the (converted) operands are float, false when int32.
bool CminusfBuilder::type_cast(Value *&lvalue, Value *&rvalue, std::string util) {
    Type *ltype = lvalue->get_type();
    Type *rtype = rvalue->get_type();
    if (not Type::is_eq_type(ltype, rtype)) {
        if (Type::is_eq_type(ltype, INT32_T) and Type::is_eq_type(rtype, FLOAT_T))
            lvalue = builder->create_sitofp(lvalue, FLOAT_T);
        else if (Type::is_eq_type(rtype, INT32_T) and Type::is_eq_type(ltype, FLOAT_T))
            rvalue = builder->create_sitofp(rvalue, FLOAT_T);
    }

    // Both operands must be checked. A single type cannot be int32 and float at
    // the same time, so `not int or not float` would reject every pair.
    const bool both_int = Type::is_eq_type(lvalue->get_type(), INT32_T) and
                          Type::is_eq_type(rvalue->get_type(), INT32_T);
    const bool both_float = Type::is_eq_type(lvalue->get_type(), FLOAT_T) and
                            Type::is_eq_type(rvalue->get_type(), FLOAT_T);
    if (not both_int and not both_float)
        error_exit("not supported cross-type " + util);
    return both_float;
}

// Entry point of code generation: initialise the cached IR types, then
// generate every top-level declaration in source order.
void CminusfBuilder::visit(ASTProgram &node) {
    VOID_T = Type::get_void_type(module.get());
    INT1_T = Type::get_int1_type(module.get());
    INT32_T = Type::get_int32_type(module.get());
    INT32PTR_T = Type::get_int32_ptr_type(module.get());
    FLOAT_T = Type::get_float_type(module.get());
    FLOATPTR_T = Type::get_float_ptr_type(module.get());

    // Iterate by reference: copying each shared_ptr would cost an atomic
    // refcount increment/decrement per declaration for no benefit.
    for (auto &decl : node.declarations) {
        decl->accept(*this);
    }
}

// Turn a numeric literal into an IR constant and leave it in cur_value.
// Literals that are neither int nor float abort.
void CminusfBuilder::visit(ASTNum &node) {
    if (node.type == TYPE_INT) {
        cur_value = CONST_INT(node.i_val);
    } else if (node.type == TYPE_FLOAT) {
        cur_value = CONST_FP(node.f_val);
    } else {
        error_exit("ASTNum is not int or float");
    }
}

// Declare an int/float variable or fixed-size array and bind its storage
// address to the name in the current scope.
//
// At global scope the storage is a zero-initialised GlobalVariable; inside a
// function it is an alloca at the current insert point. cur_value is left
// holding the storage address.
//
// Aborts on:
//  - a non-integer or non-positive array size;
//  - a redefinition within the same scope.
void CminusfBuilder::visit(ASTVarDeclaration &node) {
    Type *elem_type = nullptr;
    if (node.type == TYPE_INT)
        elem_type = INT32_T;
    else if (node.type == TYPE_FLOAT)
        elem_type = FLOAT_T;
    else
        error_exit(node.num ? "Variable type(array) is not int or float"
                            : "Variable type(not array) is not int or float");

    Type *var_type = elem_type;
    if (node.num) {
        // The size is a literal, so read it directly from the AST node. Emitting
        // an fptosi and then treating that instruction as a ConstantInt is
        // undefined behaviour, and no instruction can be emitted at global scope.
        if (node.num->type != TYPE_INT)
            error_exit("array size of " + node.id + " is not an integer");
        int size = node.num->i_val;
        if (size <= 0)
            error_exit("array size[" + std::to_string(size) + "] <= 0");
        var_type = Type::get_array_type(elem_type, size);
    }

    if (scope.in_global()) {
        // Outside any function there is no insert block, so globals cannot be
        // allocas; they become module-level variables initialised to zero.
        cur_value = GlobalVariable::create(node.id, module.get(), var_type, false,
                                           ConstantZero::get(var_type, module.get()));
    } else {
        cur_value = builder->create_alloca(var_type);
    }

    if (not scope.push(node.id, cur_value))
        error_exit("variable redefined: " + node.id);
}

// Generate an IR function for a function declaration.
//
// Steps:
//  1. Build the FunctionType from the declared return and parameter types
//     (array parameters become int* / float*).
//  2. Bind the function name in the enclosing scope.
//  3. Open a new scope for the parameters and generate the body.
//  4. If the body falls off its end, append a default return: void, 0.0 or 0.
void CminusfBuilder::visit(ASTFunDeclaration &node) {
    Type *ret_type;
    if (node.type == TYPE_INT)
        ret_type = INT32_T;
    else if (node.type == TYPE_FLOAT)
        ret_type = FLOAT_T;
    else
        ret_type = VOID_T;

    std::vector<Type *> param_types;
    for (auto &param : node.params) {
        switch (param->type) {
            case TYPE_INT:
                param_types.push_back(param->isarray ? INT32PTR_T : INT32_T);
                break;
            case TYPE_FLOAT:
                param_types.push_back(param->isarray ? FLOATPTR_T : FLOAT_T);
                break;
            case TYPE_VOID:
                // `void` contributes no IR argument and may only appear before
                // any real parameter.
                if (not param_types.empty())
                    error_exit("function parameters weird");
                break;
        }
    }

    FunctionType *fun_type = FunctionType::get(ret_type, param_types);
    auto fun = Function::create(fun_type, node.id, module.get());
    cur_fun = fun;
    if (not scope.push(node.id, fun))
        error_exit("function redefined: " + node.id);

    auto funBB = BasicBlock::create(module.get(), "entry", fun);
    builder->set_insert_point(funBB);
    scope.enter();

    // Walk the IR arguments in step with the non-void AST parameters. Indexing
    // an argument vector by AST position would read past its end whenever the
    // parameter list holds a `void` entry, which has no IR argument.
    auto arg_it = fun->arg_begin();
    for (auto &param : node.params) {
        if (param->type == TYPE_VOID)
            continue;
        cur_value = *arg_it;
        ++arg_it;
        param->accept(*this);
    }

    node.compound_stmt->accept(*this);

    // Default return when control can fall off the end of the body.
    if (builder->get_insert_block()->get_terminator() == nullptr) {
        if (cur_fun->get_return_type()->is_void_type())
            builder->create_void_ret();
        else if (cur_fun->get_return_type()->is_float_type())
            builder->create_ret(CONST_FP(0.));
        else
            builder->create_ret(CONST_INT(0));
    }
    scope.exit();
}

// Bind one function parameter in the current scope. The caller must place
// the corresponding IR argument in cur_value before visiting.
//
// - Scalar int/float parameters: the argument is copied into a fresh stack
//   slot. The name is then bound to an address, exactly like a local
//   variable, so it can be stored to and loaded from.
// - Array parameters: already element pointers (int* / float*), bound directly.
// - `void`: binds nothing.
//
// Aborts if the name is already bound in this scope.
void CminusfBuilder::visit(ASTParam &node) {
    Type *slot_type = nullptr;
    switch (node.type) {
        case TYPE_INT:
            slot_type = INT32_T;
            break;
        case TYPE_FLOAT:
            slot_type = FLOAT_T;
            break;
        case TYPE_VOID:
            return;
        default:
            error_exit("parameter type is not int, float or void");
    }

    Value *binding = cur_value;
    if (not node.isarray) {
        binding = builder->create_alloca(slot_type);
        builder->create_store(cur_value, binding);  // (value, pointer)
    }
    if (not scope.push(node.id, binding))
        error_exit("parameter redefined: " + node.id);
}

// Generate a `{ declarations statements }` block.
//
// The block gets its own scope, closed when the block ends. Sibling blocks
// (e.g. the two branches of an if) may therefore declare the same names.
// Once a statement terminates the current basic block (e.g. a return), the
// remaining statements are unreachable and are not generated.
void CminusfBuilder::visit(ASTCompoundStmt &node) {
    scope.enter();

    for (auto &decl : node.local_declarations) {
        decl->accept(*this);
    }

    for (auto &stmt : node.statement_list) {
        stmt->accept(*this);
        if (builder->get_insert_block()->get_terminator() != nullptr)
            break;
    }

    scope.exit();
}

// Evaluate an expression statement for its effects (assignments, calls).
// The expression may be null (presumably an empty `;` statement); nothing is
// emitted in that case.
void CminusfBuilder::visit(ASTExpressionStmt &node) {
    if (node.expression != nullptr)
        node.expression->accept(*this);
}

// Generate `if (expression) if_statement [else else_statement]`.
//
// The condition is reduced to an i1:
//  - int32 and float values are compared against zero;
//  - an i1 (a comparison result) is used as is;
//  - any other type aborts.
// Each branch that does not already end in a terminator (e.g. a return) jumps
// to a merge block, where code generation continues.
void CminusfBuilder::visit(ASTSelectionStmt &node) {
    node.expression->accept(*this);
    Value *cond = cur_value;
    if (Type::is_eq_type(cond->get_type(), INT32_T))
        cond = builder->create_icmp_ne(cond, CONST_INT(0));
    else if (Type::is_eq_type(cond->get_type(), FLOAT_T))
        cond = builder->create_fcmp_ne(cond, CONST_FP(0.));
    else if (not Type::is_eq_type(cond->get_type(), INT1_T))
        error_exit("if condition is not int or float");

    BasicBlock *then_bb = BasicBlock::create(module.get(), "", cur_fun);
    BasicBlock *else_bb = nullptr;
    if (node.else_statement)
        else_bb = BasicBlock::create(module.get(), "", cur_fun);
    BasicBlock *merge_bb = BasicBlock::create(module.get(), "", cur_fun);

    // Without an else branch, a false condition goes straight to the merge block.
    builder->create_cond_br(cond, then_bb, else_bb ? else_bb : merge_bb);

    builder->set_insert_point(then_bb);
    node.if_statement->accept(*this);
    if (builder->get_insert_block()->get_terminator() == nullptr)
        builder->create_br(merge_bb);

    if (else_bb) {
        builder->set_insert_point(else_bb);
        node.else_statement->accept(*this);
        if (builder->get_insert_block()->get_terminator() == nullptr)
            builder->create_br(merge_bb);
    }

    builder->set_insert_point(merge_bb);
}

// Generate `while (expression) statement`.
//
// Control flow:
//  - Control jumps to a condition block, which evaluates the condition
//    (int32/float compared against zero, i1 used as is) and branches to the
//    body or the exit block.
//  - The body jumps back to the condition block unless it already ended in a
//    terminator.
//  - Code generation continues in the exit block.
void CminusfBuilder::visit(ASTIterationStmt &node) {
    BasicBlock *cond_bb = BasicBlock::create(module.get(), "", cur_fun);
    BasicBlock *body_bb = BasicBlock::create(module.get(), "", cur_fun);
    BasicBlock *exit_bb = BasicBlock::create(module.get(), "", cur_fun);

    if (builder->get_insert_block()->get_terminator() == nullptr)
        builder->create_br(cond_bb);

    builder->set_insert_point(cond_bb);
    node.expression->accept(*this);
    Value *cond = cur_value;
    if (Type::is_eq_type(cond->get_type(), INT32_T))
        cond = builder->create_icmp_ne(cond, CONST_INT(0));
    else if (Type::is_eq_type(cond->get_type(), FLOAT_T))
        cond = builder->create_fcmp_ne(cond, CONST_FP(0.));
    else if (not Type::is_eq_type(cond->get_type(), INT1_T))
        error_exit("while condition is not int or float");
    builder->create_cond_br(cond, body_bb, exit_bb);

    builder->set_insert_point(body_bb);
    node.statement->accept(*this);
    if (builder->get_insert_block()->get_terminator() == nullptr)
        builder->create_br(cond_bb);

    builder->set_insert_point(exit_bb);
}

// Generate a return statement for the function being built (cur_fun).
//
// - A returned value is converted to the function's return type
//   (int32 <-> float).
// - A value-less `return;` in a non-void function returns zero of the return
//   type instead of an ill-typed `ret void`.
// - Returning a value from a void function, or any other type mismatch, aborts.
void CminusfBuilder::visit(ASTReturnStmt &node) {
    Type *ret_type = cur_fun->get_return_type();

    if (node.expression == nullptr) {
        if (ret_type->is_void_type())
            builder->create_void_ret();
        else if (ret_type->is_float_type())
            builder->create_ret(CONST_FP(0.));
        else
            builder->create_ret(CONST_INT(0));
        return;
    }

    if (ret_type->is_void_type())
        error_exit("return with a value in a void function");

    node.expression->accept(*this);
    Type *val_type = cur_value->get_type();
    if (not Type::is_eq_type(val_type, ret_type)) {
        if (Type::is_eq_type(ret_type, INT32_T) and Type::is_eq_type(val_type, FLOAT_T))
            cur_value = builder->create_fptosi(cur_value, INT32_T);
        else if (Type::is_eq_type(ret_type, FLOAT_T) and Type::is_eq_type(val_type, INT32_T))
            cur_value = builder->create_sitofp(cur_value, FLOAT_T);
        else
            error_exit("unsupported return type conversion");
    }
    builder->create_ret(cur_value);
}

// Compute the ADDRESS denoted by a variable reference (`x` or `x[expr]`) and
// leave it in cur_value. Callers that need the value must load from it.
//
// For an indexed reference the index is evaluated first; a float index is
// truncated with fptosi and any other non-int32 index aborts. The element
// address depends on how the variable is stored:
//  - pointer to an array object ([N x T], a local or global array):
//    gep(base, 0, idx);
//  - element pointer (T*, e.g. an array parameter): gep(base, idx).
// NOTE(review): no runtime check for negative indices is emitted.
void CminusfBuilder::visit(ASTVar &node) {
    Value *base = scope.find(node.id);
    if (base == nullptr)
        error_exit(node.id + " not declared!");

    if (not node.expression) {  // a = 1;
        cur_value = base;
        return;
    }

    // a[i] = 1;
    node.expression->accept(*this);
    Value *index = cur_value;
    if (not Type::is_eq_type(index->get_type(), INT32_T)) {
        if (Type::is_eq_type(index->get_type(), FLOAT_T))
            index = builder->create_fptosi(index, INT32_T);
        else
            error_exit("unexpected type!");
    }

    Type *base_type = base->get_type();
    if (not base_type->is_pointer_type())
        error_exit(node.id + " cannot be indexed");

    // For an array object, the first gep index steps over the pointer and the
    // second selects the element. A single index would yield another [N x T]*.
    if (base_type->get_pointer_element_type()->is_array_type())
        cur_value = builder->create_gep(base, {CONST_INT(0), index});
    else
        cur_value = builder->create_gep(base, {index});
}

// Generate `var = expression`.
//
// Steps:
//  1. Visit the left side to obtain the destination address.
//  2. Visit the right side to obtain the value.
//  3. Convert the value to the type stored at the destination (int32 <-> float)
//     and store it.
// cur_value is left holding the stored value, i.e. the value of the
// assignment expression.
void CminusfBuilder::visit(ASTAssignExpression &node) {
    node.var->accept(*this);
    Value *addr = cur_value;
    node.expression->accept(*this);
    Value *value = cur_value;

    if (not addr->get_type()->is_pointer_type())
        error_exit("Bad assignment!");

    // Convert to the type of the memory actually written. The variable's own
    // storage type is always a pointer, and for arrays it is not the element
    // type, so it cannot drive the conversion.
    Type *dest_type = addr->get_type()->get_pointer_element_type();
    if (not Type::is_eq_type(dest_type, value->get_type())) {
        if (Type::is_eq_type(dest_type, INT32_T) and Type::is_eq_type(value->get_type(), FLOAT_T))
            value = builder->create_fptosi(value, INT32_T);
        else if (Type::is_eq_type(dest_type, FLOAT_T) and Type::is_eq_type(value->get_type(), INT32_T))
            value = builder->create_sitofp(value, FLOAT_T);
        else
            error_exit("Bad assignment!");
    }

    // IRBuilder::create_store takes (value, pointer).
    builder->create_store(value, addr);
    cur_value = value;
}

// Generate a relational expression `left op right`, or just `left` when there
// is no right operand.
//
// Both sides are evaluated left to right and unified by type_cast. The
// comparison is an fcmp when the operands are float and an icmp otherwise.
// The comparison result is left in cur_value.
void CminusfBuilder::visit(ASTSimpleExpression &node) {
    node.additive_expression_l->accept(*this);
    if (not node.additive_expression_r)
        return;

    Value *lhs = cur_value;
    node.additive_expression_r->accept(*this);
    Value *rhs = cur_value;

    if (type_cast(lhs, rhs, "cmp")) {
        switch (node.op) {
            case OP_LE: cur_value = builder->create_fcmp_le(lhs, rhs); break;
            case OP_LT: cur_value = builder->create_fcmp_lt(lhs, rhs); break;
            case OP_GT: cur_value = builder->create_fcmp_gt(lhs, rhs); break;
            case OP_GE: cur_value = builder->create_fcmp_ge(lhs, rhs); break;
            case OP_EQ: cur_value = builder->create_fcmp_eq(lhs, rhs); break;
            case OP_NEQ: cur_value = builder->create_fcmp_ne(lhs, rhs); break;
        }
    } else {
        switch (node.op) {
            case OP_LE: cur_value = builder->create_icmp_le(lhs, rhs); break;
            case OP_LT: cur_value = builder->create_icmp_lt(lhs, rhs); break;
            case OP_GT: cur_value = builder->create_icmp_gt(lhs, rhs); break;
            case OP_GE: cur_value = builder->create_icmp_ge(lhs, rhs); break;
            case OP_EQ: cur_value = builder->create_icmp_eq(lhs, rhs); break;
            case OP_NEQ: cur_value = builder->create_icmp_ne(lhs, rhs); break;
        }
    }
}

// Generate `additive_expression addop term`, or just `term` when there is no
// left operand.
//
// The C-minus grammar is left-recursive
// (`additive-expression -> additive-expression addop term | term`), so
// `additive_expression` is the LEFT operand and `term` the RIGHT one.
// Visiting them in that order makes `a - b` compute a - b (not b - a) and
// keeps left-to-right evaluation of side effects. Operands are unified by
// type_cast; the result is left in cur_value.
void CminusfBuilder::visit(ASTAdditiveExpression &node) {
    if (node.additive_expression == nullptr) {
        node.term->accept(*this);
        return;
    }

    node.additive_expression->accept(*this);
    Value *lvalue = cur_value;
    node.term->accept(*this);
    Value *rvalue = cur_value;

    bool float_type = type_cast(lvalue, rvalue, "addop");
    switch (node.op) {
        case OP_PLUS:
            if (float_type)
                cur_value = builder->create_fadd(lvalue, rvalue);
            else
                cur_value = builder->create_iadd(lvalue, rvalue);
            break;
        case OP_MINUS:
            if (float_type)
                cur_value = builder->create_fsub(lvalue, rvalue);
            else
                cur_value = builder->create_isub(lvalue, rvalue);
            break;
    }
}

// Generate `term mulop factor`, or just `factor` when there is no left
// operand.
//
// The C-minus grammar is left-recursive (`term -> term mulop factor | factor`),
// so `term` is the LEFT operand and `factor` the RIGHT one. Visiting them in
// that order makes `a / b` compute a / b (not b / a). Operands are unified by
// type_cast; the result is left in cur_value.
void CminusfBuilder::visit(ASTTerm &node) {
    if (node.term == nullptr) {
        node.factor->accept(*this);
        return;
    }

    node.term->accept(*this);
    Value *lvalue = cur_value;
    node.factor->accept(*this);
    Value *rvalue = cur_value;

    bool float_type = type_cast(lvalue, rvalue, "mulop");
    switch (node.op) {
        case OP_MUL:
            if (float_type)
                cur_value = builder->create_fmul(lvalue, rvalue);
            else
                cur_value = builder->create_imul(lvalue, rvalue);
            break;
        case OP_DIV:
            if (float_type)
                cur_value = builder->create_fdiv(lvalue, rvalue);
            else
                cur_value = builder->create_isdiv(lvalue, rvalue);
            break;
    }
}

// Generate a call `id(args...)` and leave the call instruction in cur_value.
//
// Checks that `id` is bound to a function and that the argument count
// matches. Each argument is then converted to its parameter type:
//  - a pointer to an array object ([N x T]) decays to a pointer to its first
//    element when the parameter is T*;
//  - int32 and float are converted into each other.
// Any other mismatch aborts.
void CminusfBuilder::visit(ASTCall &node) {
    Value *callee = scope.find(node.id);
    if (callee == nullptr)
        error_exit("function " + node.id + " not declared");
    // The name may be bound to a variable; check before downcasting.
    if (not callee->get_type()->is_function_type())
        error_exit(node.id + " is not a function");
    auto func = static_cast<Function *>(callee);

    if (node.args.size() != func->get_num_of_args())
        error_exit("expect " + std::to_string(func->get_num_of_args()) + " params, but " +
                   std::to_string(node.args.size()) + " is given");

    std::vector<Value *> args;
    args.reserve(node.args.size());
    for (std::size_t i = 0; i != node.args.size(); ++i) {
        Type *param_type = func->get_function_type()->get_param_type(i);
        node.args[i]->accept(*this);
        Value *arg = cur_value;
        Type *arg_type = arg->get_type();

        if (not Type::is_eq_type(param_type, arg_type)) {
            if (param_type->is_pointer_type()) {
                // int[] to int* or float[] to float*: the argument must point
                // to an array whose element type matches the parameter's pointee.
                Type *pointee = arg_type->is_pointer_type() ? arg_type->get_pointer_element_type() : nullptr;
                if (pointee == nullptr or not pointee->is_array_type() or
                    not Type::is_eq_type(param_type->get_pointer_element_type(), pointee->get_array_element_type()))
                    error_exit("expected right pointer type");
                // Indices must be IR constants; a bare `{0, 0}` would build a
                // vector of two null Value pointers.
                arg = builder->create_gep(arg, {CONST_INT(0), CONST_INT(0)});
            } else if (Type::is_eq_type(param_type, FLOAT_T) and Type::is_eq_type(arg_type, INT32_T)) {
                arg = builder->create_sitofp(arg, FLOAT_T);
            } else if (Type::is_eq_type(param_type, INT32_T) and Type::is_eq_type(arg_type, FLOAT_T)) {
                arg = builder->create_fptosi(arg, INT32_T);
            } else {
                error_exit("unexpected case when casting arguments for function call " + node.id);
            }
        }

        // arg now has the parameter's type
        args.push_back(arg);
    }
    cur_value = builder->create_call(func, args);
}