Commit 04681c6c authored by Yang's avatar Yang

remove llvm

parent 606023f0
build/
.cache/
\ No newline at end of file
.cache/
.vs
out
CMakePresets.json
Folder.DotSettings.user
\ No newline at end of file
{
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Compile & Debug 1.ll",
// 要调试的程序
"program": "${workspaceFolder}/build/cminusfc",
// 命令行参数
"args": [
"-o",
"./build/1.ll",
"-emit-llvm",
"./build/1.cminus"
],
// 程序运行的目录
"cwd": "${workspaceFolder}",
// 程序运行前运行的命令(例如 build)
"preLaunchTask": "make cminusfc"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug 1.ll",
// 要调试的程序
"program": "${workspaceFolder}/build/cminusfc",
// 命令行参数
"args": [
"-o",
"./build/1.ll",
"-emit-llvm",
"./build/1.cminus"
],
// 程序运行的目录
"cwd": "${workspaceFolder}"
}
]
}
\ No newline at end of file
{
"version": "2.0.0",
"tasks": [
{
"type": "shell",
"label": "make cminusfc",
"command": "cd build && make"
}
]
}
\ No newline at end of file
......@@ -29,14 +29,6 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
find_package(FLEX REQUIRED)
find_package(BISON REQUIRED)
find_package(LLVM REQUIRED CONFIG)
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
llvm_map_components_to_libnames(
llvm_libs
support
core
)
INCLUDE_DIRECTORIES(
include
......@@ -44,11 +36,8 @@ INCLUDE_DIRECTORIES(
include/common
include/lightir
include/codegen
${LLVM_INCLUDE_DIRS}
)
add_definitions(${LLVM_DEFINITIONS})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})
......
#pragma once
#include "BasicBlock.hpp"
#include "Constant.hpp"
#include "Function.hpp"
#include "IRBuilder.hpp"
#include "Module.hpp"
......@@ -19,7 +18,7 @@ class Scope {
// exit a scope
void exit() { inner.pop_back(); }
bool in_global() { return inner.size() == 1; }
bool in_global() const { return inner.size() == 1; }
// push a name to scope
// return true if successful
......@@ -30,7 +29,7 @@ class Scope {
}
Value *find(const std::string& name) {
for (auto s = inner.rbegin(); s != inner.rend(); s++) {
for (auto s = inner.rbegin(); s != inner.rend(); ++s) {
auto iter = s->find(name);
if (iter != s->end()) {
return iter->second;
......@@ -39,8 +38,6 @@ class Scope {
// Name not found: handled here?
assert(false && "Name not found in scope");
return nullptr;
}
private:
......@@ -50,29 +47,29 @@ class Scope {
class CminusfBuilder : public ASTVisitor {
public:
CminusfBuilder() {
module = std::make_unique<Module>();
builder = std::make_unique<IRBuilder>(nullptr, module.get());
module = new Module();
builder = new IRBuilder(nullptr, module);
auto *TyVoid = module->get_void_type();
auto *TyInt32 = module->get_int32_type();
auto *TyFloat = module->get_float_type();
auto *input_type = FunctionType::get(TyInt32, {});
auto *input_fun = Function::create(input_type, "input", module.get());
auto *input_fun = Function::create(input_type, "input", module);
std::vector<Type *> output_params;
output_params.push_back(TyInt32);
auto *output_type = FunctionType::get(TyVoid, output_params);
auto *output_fun = Function::create(output_type, "output", module.get());
auto *output_fun = Function::create(output_type, "output", module);
std::vector<Type *> output_float_params;
output_float_params.push_back(TyFloat);
auto *output_float_type = FunctionType::get(TyVoid, output_float_params);
auto *output_float_fun =
Function::create(output_float_type, "outputFloat", module.get());
Function::create(output_float_type, "outputFloat", module);
auto *neg_idx_except_type = FunctionType::get(TyVoid, {});
auto *neg_idx_except_fun = Function::create(
neg_idx_except_type, "neg_idx_except", module.get());
neg_idx_except_type, "neg_idx_except", module);
scope.enter();
scope.push("input", input_fun);
......@@ -81,29 +78,31 @@ class CminusfBuilder : public ASTVisitor {
scope.push("neg_idx_except", neg_idx_except_fun);
}
std::unique_ptr<Module> getModule() { return std::move(module); }
Module* getModule() const { return module; }
~CminusfBuilder() override { delete builder; }
private:
virtual Value *visit(ASTProgram &) override final;
virtual Value *visit(ASTNum &) override final;
virtual Value *visit(ASTVarDeclaration &) override final;
virtual Value *visit(ASTFunDeclaration &) override final;
virtual Value *visit(ASTParam &) override final;
virtual Value *visit(ASTCompoundStmt &) override final;
virtual Value *visit(ASTExpressionStmt &) override final;
virtual Value *visit(ASTSelectionStmt &) override final;
virtual Value *visit(ASTIterationStmt &) override final;
virtual Value *visit(ASTReturnStmt &) override final;
virtual Value *visit(ASTAssignExpression &) override final;
virtual Value *visit(ASTSimpleExpression &) override final;
virtual Value *visit(ASTAdditiveExpression &) override final;
virtual Value *visit(ASTVar &) override final;
virtual Value *visit(ASTTerm &) override final;
virtual Value *visit(ASTCall &) override final;
std::unique_ptr<IRBuilder> builder;
Value *visit(ASTProgram &) final;
Value *visit(ASTNum &) final;
Value *visit(ASTVarDeclaration &) final;
Value *visit(ASTFunDeclaration &) final;
Value *visit(ASTParam &) final;
Value *visit(ASTCompoundStmt &) final;
Value *visit(ASTExpressionStmt &) final;
Value *visit(ASTSelectionStmt &) final;
Value *visit(ASTIterationStmt &) final;
Value *visit(ASTReturnStmt &) final;
Value *visit(ASTAssignExpression &) final;
Value *visit(ASTSimpleExpression &) final;
Value *visit(ASTAdditiveExpression &) final;
Value *visit(ASTVar &) final;
Value *visit(ASTTerm &) final;
Value *visit(ASTCall &) final;
IRBuilder* builder;
Scope scope;
std::unique_ptr<Module> module;
Module* module;
struct {
// whether require lvalue
......
......@@ -90,23 +90,23 @@ struct ASTNode {
};
struct ASTProgram : ASTNode {
virtual Value* accept(ASTVisitor &) override final;
virtual ~ASTProgram() = default;
Value* accept(ASTVisitor &) final;
~ASTProgram() override = default;
std::vector<std::shared_ptr<ASTDeclaration>> declarations;
};
struct ASTDeclaration : ASTNode {
virtual ~ASTDeclaration() = default;
~ASTDeclaration() override = default;
CminusType type;
std::string id;
};
struct ASTFactor : ASTNode {
virtual ~ASTFactor() = default;
~ASTFactor() override = default;
};
struct ASTNum : ASTFactor {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
CminusType type;
union {
int i_val;
......@@ -115,18 +115,18 @@ struct ASTNum : ASTFactor {
};
struct ASTVarDeclaration : ASTDeclaration {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTNum> num;
};
struct ASTFunDeclaration : ASTDeclaration {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::vector<std::shared_ptr<ASTParam>> params;
std::shared_ptr<ASTCompoundStmt> compound_stmt;
};
struct ASTParam : ASTNode {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
CminusType type;
std::string id;
// true if it is array param
......@@ -134,22 +134,22 @@ struct ASTParam : ASTNode {
};
struct ASTStatement : ASTNode {
virtual ~ASTStatement() = default;
~ASTStatement() override = default;
};
struct ASTCompoundStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::vector<std::shared_ptr<ASTVarDeclaration>> local_declarations;
std::vector<std::shared_ptr<ASTStatement>> statement_list;
};
struct ASTExpressionStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTExpression> expression;
};
struct ASTSelectionStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTExpression> expression;
std::shared_ptr<ASTStatement> if_statement;
// should be nullptr if no else structure exists
......@@ -157,13 +157,13 @@ struct ASTSelectionStmt : ASTStatement {
};
struct ASTIterationStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTExpression> expression;
std::shared_ptr<ASTStatement> statement;
};
struct ASTReturnStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
// should be nullptr if return void
std::shared_ptr<ASTExpression> expression;
};
......@@ -171,41 +171,41 @@ struct ASTReturnStmt : ASTStatement {
struct ASTExpression : ASTFactor {};
struct ASTAssignExpression : ASTExpression {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTVar> var;
std::shared_ptr<ASTExpression> expression;
};
struct ASTSimpleExpression : ASTExpression {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTAdditiveExpression> additive_expression_l;
std::shared_ptr<ASTAdditiveExpression> additive_expression_r;
RelOp op;
};
struct ASTVar : ASTFactor {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::string id;
// nullptr if var is of int type
std::shared_ptr<ASTExpression> expression;
};
struct ASTAdditiveExpression : ASTNode {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTAdditiveExpression> additive_expression;
AddOp op;
std::shared_ptr<ASTTerm> term;
};
struct ASTTerm : ASTNode {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTTerm> term;
MulOp op;
std::shared_ptr<ASTFactor> factor;
};
struct ASTCall : ASTFactor {
virtual Value* accept(ASTVisitor &) override final;
Value* accept(ASTVisitor &) final;
std::string id;
std::vector<std::shared_ptr<ASTExpression>> args;
};
......@@ -228,26 +228,27 @@ class ASTVisitor {
virtual Value* visit(ASTVar &) = 0;
virtual Value* visit(ASTTerm &) = 0;
virtual Value* visit(ASTCall &) = 0;
virtual ~ASTVisitor();
};
class ASTPrinter : public ASTVisitor {
public:
virtual Value* visit(ASTProgram &) override final;
virtual Value* visit(ASTNum &) override final;
virtual Value* visit(ASTVarDeclaration &) override final;
virtual Value* visit(ASTFunDeclaration &) override final;
virtual Value* visit(ASTParam &) override final;
virtual Value* visit(ASTCompoundStmt &) override final;
virtual Value* visit(ASTExpressionStmt &) override final;
virtual Value* visit(ASTSelectionStmt &) override final;
virtual Value* visit(ASTIterationStmt &) override final;
virtual Value* visit(ASTReturnStmt &) override final;
virtual Value* visit(ASTAssignExpression &) override final;
virtual Value* visit(ASTSimpleExpression &) override final;
virtual Value* visit(ASTAdditiveExpression &) override final;
virtual Value* visit(ASTVar &) override final;
virtual Value* visit(ASTTerm &) override final;
virtual Value* visit(ASTCall &) override final;
Value* visit(ASTProgram &) final;
Value* visit(ASTNum &) final;
Value* visit(ASTVarDeclaration &) final;
Value* visit(ASTFunDeclaration &) final;
Value* visit(ASTParam &) final;
Value* visit(ASTCompoundStmt &) final;
Value* visit(ASTExpressionStmt &) final;
Value* visit(ASTSelectionStmt &) final;
Value* visit(ASTIterationStmt &) final;
Value* visit(ASTReturnStmt &) final;
Value* visit(ASTAssignExpression &) final;
Value* visit(ASTSimpleExpression &) final;
Value* visit(ASTAdditiveExpression &) final;
Value* visit(ASTVar &) final;
Value* visit(ASTTerm &) final;
Value* visit(ASTCall &) final;
void add_depth() { depth += 2; }
void remove_depth() { depth -= 2; }
......
......@@ -4,9 +4,6 @@
#include "Value.hpp"
#include <list>
#include <llvm/ADT/ilist.h>
#include <llvm/ADT/ilist_node.h>
#include <set>
#include <string>
class Function;
......@@ -15,7 +12,11 @@ class Module;
class BasicBlock : public Value {
public:
~BasicBlock() = default;
BasicBlock(const BasicBlock& other) = delete;
BasicBlock(BasicBlock&& other) noexcept = delete;
BasicBlock& operator=(const BasicBlock& other) = delete;
BasicBlock& operator=(BasicBlock&& other) noexcept = delete;
~BasicBlock() override;
static BasicBlock *create(Module *m, const std::string &name,
Function *parent) {
auto prefix = name.empty() ? "" : "label_";
......@@ -30,39 +31,48 @@ class BasicBlock : public Value {
void add_succ_basic_block(BasicBlock *bb) { succ_bbs_.push_back(bb); }
void remove_pre_basic_block(BasicBlock *bb) { pre_bbs_.remove(bb); }
void remove_succ_basic_block(BasicBlock *bb) { succ_bbs_.remove(bb); }
BasicBlock* get_entry_block_of_same_function();
BasicBlock* get_entry_block_of_same_function() const;
// If the Block is terminated by ret/br
bool is_terminated() const;
// Get terminator, only accept valid case use
Instruction *get_terminator();
Instruction *get_terminator() const;
/****************api about Instruction****************/
// 在指令表最后插入指令
// 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面。
// 因此,即使终止基本块也能插入 alloca 和 phi
void add_instruction(Instruction *instr);
void add_instr_begin(Instruction *instr) { instr_list_.push_front(instr); }
void add_instr_before_end(Instruction *instr) {
instr_list_.insert(std::prev(instr_list_.end()), instr);
}
void erase_instr(Instruction *instr) { instr_list_.erase(instr); }
// 在指令链表最前面插入指令
// 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面。
void add_instr_begin(Instruction* instr);
// 绕过终止指令插入指令
// 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面。
void add_instr_before_terminator(Instruction* instr);
// 从 BasicBlock 移除 Instruction,并 delete 这个 Instruction
void erase_instr(Instruction* instr) { instr_list_.remove(instr); delete instr; }
// 从 BasicBlock 移除 Instruction,你需要自己 delete 它
void remove_instr(Instruction *instr) { instr_list_.remove(instr); }
llvm::ilist<Instruction> &get_instructions() { return instr_list_; }
// 移除的 Instruction 需要自己 delete
std::list<Instruction*> &get_instructions() { return instr_list_; }
bool empty() const { return instr_list_.empty(); }
int get_num_of_instr() const { return instr_list_.size(); }
int get_num_of_instr() const { return static_cast<int>(instr_list_.size()); }
/****************api about accessing parent****************/
Function *get_parent() { return parent_; }
Module *get_module();
Function *get_parent() const { return parent_; }
Module *get_module() const;
void erase_from_parent();
virtual std::string print() override;
std::string print() override;
private:
BasicBlock(const BasicBlock &) = delete;
explicit BasicBlock(Module *m, const std::string &name, Function *parent);
std::list<BasicBlock *> pre_bbs_;
std::list<BasicBlock *> succ_bbs_;
llvm::ilist<Instruction> instr_list_;
std::list<Instruction*> instr_list_;
Function *parent_;
};
......@@ -5,11 +5,16 @@
#include "Value.hpp"
class Constant : public User {
private:
private:
// int value;
public:
public:
Constant(const Constant& other) = delete;
Constant(Constant&& other) noexcept = delete;
Constant& operator=(const Constant& other) = delete;
Constant& operator=(Constant&& other) noexcept = delete;
Constant(Type *ty, const std::string &name = "") : User(ty, name) {}
~Constant() = default;
~Constant() override;
};
class ConstantInt : public Constant {
......@@ -18,10 +23,10 @@ class ConstantInt : public Constant {
ConstantInt(Type *ty, int val) : Constant(ty, ""), value_(val) {}
public:
int get_value() { return value_; }
int get_value() const { return value_; }
static ConstantInt *get(int val, Module *m);
static ConstantInt *get(bool val, Module *m);
virtual std::string print() override;
std::string print() override;
};
class ConstantArray : public Constant {
......@@ -31,16 +36,21 @@ class ConstantArray : public Constant {
ConstantArray(ArrayType *ty, const std::vector<Constant *> &val);
public:
~ConstantArray() = default;
ConstantArray(const ConstantArray& other) = delete;
ConstantArray(ConstantArray&& other) noexcept = delete;
ConstantArray& operator=(const ConstantArray& other) = delete;
ConstantArray& operator=(ConstantArray&& other) noexcept = delete;
~ConstantArray() override = default;
Constant *get_element_value(int index);
Constant *get_element_value(int index) const;
unsigned get_size_of_array() { return const_array.size(); }
int get_size_of_array() const { return static_cast<int>(const_array.size()); }
static ConstantArray *get(ArrayType *ty,
const std::vector<Constant *> &val);
virtual std::string print() override;
std::string print() override;
};
class ConstantZero : public Constant {
......@@ -49,7 +59,7 @@ class ConstantZero : public Constant {
public:
static ConstantZero *get(Type *ty, Module *m);
virtual std::string print() override;
std::string print() override;
};
class ConstantFP : public Constant {
......@@ -59,6 +69,6 @@ class ConstantFP : public Constant {
public:
static ConstantFP *get(float val, Module *m);
float get_value() { return val_; }
virtual std::string print() override;
float get_value() const { return val_; }
std::string print() override;
};
......@@ -2,15 +2,9 @@
#include "BasicBlock.hpp"
#include "Type.hpp"
#include "User.hpp"
#include <cassert>
#include <cstddef>
#include <iterator>
#include <list>
#include <llvm/ADT/ilist.h>
#include <llvm/ADT/ilist_node.h>
#include <map>
#include <memory>
class Module;
......@@ -18,11 +12,14 @@ class Argument;
class Type;
class FunctionType;
class Function : public Value, public llvm::ilist_node<Function> {
class Function : public Value {
public:
Function(const Function &) = delete;
Function(const Function& other) = delete;
Function(Function&& other) noexcept = delete;
Function& operator=(const Function& other) = delete;
Function& operator=(Function&& other) noexcept = delete;
Function(FunctionType *ty, const std::string &name, Module *parent);
~Function() = default;
~Function() override;
static Function *create(FunctionType *ty, const std::string &name,
Module *parent);
......@@ -38,17 +35,17 @@ class Function : public Value, public llvm::ilist_node<Function> {
// 此处 remove 的 BasicBlock, 需要手动 delete
void remove(BasicBlock *bb);
BasicBlock *get_entry_block() { return basic_blocks_.front(); }
BasicBlock *get_entry_block() const { return basic_blocks_.front(); }
std::list<BasicBlock*> &get_basic_blocks() { return basic_blocks_; }
std::list<Argument> &get_args() { return arguments_; }
bool is_declaration() { return basic_blocks_.empty(); }
bool is_declaration() const { return basic_blocks_.empty(); }
void set_instr_name();
std::string print();
std::string print() override;
// 用于检查函数的基本块是否存在问题
void check_for_block_relation_error();
void check_for_block_relation_error() const;
private:
std::list<BasicBlock*> basic_blocks_;
......@@ -60,14 +57,17 @@ class Function : public Value, public llvm::ilist_node<Function> {
// Argument of Function, does not contain actual value
class Argument : public Value {
public:
Argument(const Argument &) = delete;
Argument(const Argument& other) = delete;
Argument(Argument&& other) noexcept = delete;
Argument& operator=(const Argument& other) = delete;
Argument& operator=(Argument&& other) noexcept = delete;
explicit Argument(Type *ty, const std::string &name = "",
Function *f = nullptr, unsigned arg_no = 0)
: Value(ty, name), parent_(f), arg_no_(arg_no) {}
virtual ~Argument() {}
~Argument() override = default;
inline const Function *get_parent() const { return parent_; }
inline Function *get_parent() { return parent_; }
const Function *get_parent() const { return parent_; }
Function *get_parent() { return parent_; }
/// For example in "void foo(int a, float b)" a is 0 and b is 1.
unsigned get_arg_no() const {
......@@ -75,7 +75,7 @@ class Argument : public Value {
return arg_no_;
}
virtual std::string print() override;
std::string print() override;
private:
Function *parent_;
......
......@@ -3,21 +3,24 @@
#include "Constant.hpp"
#include "User.hpp"
#include <llvm/ADT/ilist_node.h>
class Module;
class GlobalVariable : public User, public llvm::ilist_node<GlobalVariable> {
class GlobalVariable : public User {
private:
bool is_const_;
Constant *init_val_;
GlobalVariable(std::string name, Module *m, Type *ty, bool is_const,
GlobalVariable(const std::string& name, Module *m, Type *ty, bool is_const,
Constant *init = nullptr);
public:
GlobalVariable(const GlobalVariable &) = delete;
static GlobalVariable *create(std::string name, Module *m, Type *ty,
GlobalVariable(const GlobalVariable& other) = delete;
GlobalVariable(GlobalVariable&& other) noexcept = delete;
GlobalVariable& operator=(const GlobalVariable& other) = delete;
GlobalVariable& operator=(GlobalVariable&& other) noexcept = delete;
static GlobalVariable *create(const std::string& name, Module *m, Type *ty,
bool is_const, Constant *init);
virtual ~GlobalVariable() = default;
Constant *get_init() { return init_val_; }
bool is_const() { return is_const_; }
std::string print();
~GlobalVariable() override = default;
Constant *get_init() const { return init_val_; }
bool is_const() const { return is_const_; }
std::string print() override;
};
#pragma once
#include "BasicBlock.hpp"
#include "Function.hpp"
#include "Instruction.hpp"
#include "Value.hpp"
......@@ -11,124 +10,159 @@ class IRBuilder {
Module *m_;
public:
IRBuilder(BasicBlock *bb, Module *m) : BB_(bb), m_(m){};
IRBuilder(const IRBuilder& other) = delete;
IRBuilder(IRBuilder&& other) noexcept = delete;
IRBuilder& operator=(const IRBuilder& other) = delete;
IRBuilder& operator=(IRBuilder&& other) noexcept = delete;
IRBuilder(BasicBlock *bb, Module *m) : BB_(bb), m_(m){}
~IRBuilder() = default;
Module *get_module() { return m_; }
BasicBlock *get_insert_block() { return this->BB_; }
Module *get_module() const { return m_; }
BasicBlock *get_insert_block() const { return this->BB_; }
void set_insert_point(BasicBlock *bb) {
this->BB_ = bb;
} // 在某个基本块中插入指令
IBinaryInst *create_iadd(Value *lhs, Value *rhs) {
IBinaryInst *create_iadd(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_add(lhs, rhs, this->BB_);
} // 创建加法指令(以及其他算术指令)
IBinaryInst *create_isub(Value *lhs, Value *rhs) {
IBinaryInst *create_isub(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_sub(lhs, rhs, this->BB_);
}
IBinaryInst *create_imul(Value *lhs, Value *rhs) {
IBinaryInst *create_imul(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_mul(lhs, rhs, this->BB_);
}
IBinaryInst *create_isdiv(Value *lhs, Value *rhs) {
IBinaryInst *create_isdiv(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_sdiv(lhs, rhs, this->BB_);
}
ICmpInst *create_icmp_eq(Value *lhs, Value *rhs) {
ICmpInst *create_icmp_eq(Value *lhs, Value *rhs) const
{
return ICmpInst::create_eq(lhs, rhs, this->BB_);
}
ICmpInst *create_icmp_ne(Value *lhs, Value *rhs) {
ICmpInst *create_icmp_ne(Value *lhs, Value *rhs) const
{
return ICmpInst::create_ne(lhs, rhs, this->BB_);
}
ICmpInst *create_icmp_gt(Value *lhs, Value *rhs) {
ICmpInst *create_icmp_gt(Value *lhs, Value *rhs) const
{
return ICmpInst::create_gt(lhs, rhs, this->BB_);
}
ICmpInst *create_icmp_ge(Value *lhs, Value *rhs) {
ICmpInst *create_icmp_ge(Value *lhs, Value *rhs) const
{
return ICmpInst::create_ge(lhs, rhs, this->BB_);
}
ICmpInst *create_icmp_lt(Value *lhs, Value *rhs) {
ICmpInst *create_icmp_lt(Value *lhs, Value *rhs) const
{
return ICmpInst::create_lt(lhs, rhs, this->BB_);
}
ICmpInst *create_icmp_le(Value *lhs, Value *rhs) {
ICmpInst *create_icmp_le(Value *lhs, Value *rhs) const
{
return ICmpInst::create_le(lhs, rhs, this->BB_);
}
CallInst *create_call(Value *func, std::vector<Value *> args) {
return CallInst::create_call(static_cast<Function *>(func), args,
CallInst *create_call(Value *func, const std::vector<Value *>& args) const
{
return CallInst::create_call(dynamic_cast<Function *>(func), args,
this->BB_);
}
BranchInst *create_br(BasicBlock *if_true) {
BranchInst *create_br(BasicBlock *if_true) const
{
return BranchInst::create_br(if_true, this->BB_);
}
BranchInst *create_cond_br(Value *cond, BasicBlock *if_true,
BasicBlock *if_false) {
BasicBlock *if_false) const
{
return BranchInst::create_cond_br(cond, if_true, if_false, this->BB_);
}
ReturnInst *create_ret(Value *val) {
ReturnInst *create_ret(Value *val) const
{
return ReturnInst::create_ret(val, this->BB_);
}
ReturnInst *create_void_ret() {
ReturnInst *create_void_ret() const
{
return ReturnInst::create_void_ret(this->BB_);
}
GetElementPtrInst *create_gep(Value *ptr, std::vector<Value *> idxs) {
GetElementPtrInst *create_gep(Value *ptr, const std::vector<Value *>& idxs) const
{
return GetElementPtrInst::create_gep(ptr, idxs, this->BB_);
}
StoreInst *create_store(Value *val, Value *ptr) {
StoreInst *create_store(Value *val, Value *ptr) const
{
return StoreInst::create_store(val, ptr, this->BB_);
}
LoadInst *create_load(Value *ptr) {
LoadInst *create_load(Value *ptr) const
{
assert(ptr->get_type()->is_pointer_type() &&
"ptr must be pointer type");
return LoadInst::create_load(ptr, this->BB_);
}
AllocaInst *create_alloca(Type *ty) {
return AllocaInst::create_alloca(ty, this->BB_);
}
AllocaInst *create_alloca_begin(Type *ty) {
return AllocaInst::create_alloca_begin(ty, this->BB_);
AllocaInst *create_alloca(Type *ty) const
{
return AllocaInst::create_alloca(ty, this->BB_->get_entry_block_of_same_function());
}
ZextInst *create_zext(Value *val, Type *ty) {
ZextInst *create_zext(Value *val, Type *ty) const
{
return ZextInst::create_zext(val, ty, this->BB_);
}
SiToFpInst *create_sitofp(Value *val, Type *ty) {
SiToFpInst *create_sitofp(Value *val, Type *ty) const
{
return SiToFpInst::create_sitofp(val, this->BB_);
}
FpToSiInst *create_fptosi(Value *val, Type *ty) {
FpToSiInst *create_fptosi(Value *val, Type *ty) const
{
return FpToSiInst::create_fptosi(val, ty, this->BB_);
}
FCmpInst *create_fcmp_ne(Value *lhs, Value *rhs) {
FCmpInst *create_fcmp_ne(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fne(lhs, rhs, this->BB_);
}
FCmpInst *create_fcmp_lt(Value *lhs, Value *rhs) {
FCmpInst *create_fcmp_lt(Value *lhs, Value *rhs) const
{
return FCmpInst::create_flt(lhs, rhs, this->BB_);
}
FCmpInst *create_fcmp_le(Value *lhs, Value *rhs) {
FCmpInst *create_fcmp_le(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fle(lhs, rhs, this->BB_);
}
FCmpInst *create_fcmp_ge(Value *lhs, Value *rhs) {
FCmpInst *create_fcmp_ge(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fge(lhs, rhs, this->BB_);
}
FCmpInst *create_fcmp_gt(Value *lhs, Value *rhs) {
FCmpInst *create_fcmp_gt(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fgt(lhs, rhs, this->BB_);
}
FCmpInst *create_fcmp_eq(Value *lhs, Value *rhs) {
FCmpInst *create_fcmp_eq(Value *lhs, Value *rhs) const
{
return FCmpInst::create_feq(lhs, rhs, this->BB_);
}
FBinaryInst *create_fadd(Value *lhs, Value *rhs) {
FBinaryInst *create_fadd(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fadd(lhs, rhs, this->BB_);
}
FBinaryInst *create_fsub(Value *lhs, Value *rhs) {
FBinaryInst *create_fsub(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fsub(lhs, rhs, this->BB_);
}
FBinaryInst *create_fmul(Value *lhs, Value *rhs) {
FBinaryInst *create_fmul(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fmul(lhs, rhs, this->BB_);
}
FBinaryInst *create_fdiv(Value *lhs, Value *rhs) {
FBinaryInst *create_fdiv(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fdiv(lhs, rhs, this->BB_);
}
};
#pragma once
#include "BasicBlock.hpp"
#include "Constant.hpp"
#include "Function.hpp"
#include "GlobalVariable.hpp"
#include "Instruction.hpp"
#include "Module.hpp"
#include "Type.hpp"
#include "User.hpp"
#include "Value.hpp"
std::string print_as_op(Value *v, bool print_ty);
......
This diff is collapsed.
......@@ -2,13 +2,9 @@
#include "Function.hpp"
#include "GlobalVariable.hpp"
#include "Instruction.hpp"
#include "Type.hpp"
#include "Value.hpp"
#include <list>
#include <llvm/ADT/ilist.h>
#include <llvm/ADT/ilist_node.h>
#include <map>
#include <memory>
#include <string>
......@@ -18,14 +14,18 @@ class Function;
class Module {
public:
Module();
~Module() = default;
Type *get_void_type();
Type *get_label_type();
IntegerType *get_int1_type();
IntegerType *get_int32_type();
~Module();
Module(const Module& other) = delete;
Module(Module&& other) noexcept = delete;
Module& operator=(const Module& other) = delete;
Module& operator=(Module&& other) noexcept = delete;
Type *get_void_type() const;
Type *get_label_type() const;
IntegerType *get_int1_type() const;
IntegerType *get_int32_type() const;
PointerType *get_int32_ptr_type();
FloatType *get_float_type();
FloatType *get_float_type() const;
PointerType *get_float_ptr_type();
PointerType *get_pointer_type(Type *contained);
......@@ -33,27 +33,27 @@ class Module {
FunctionType *get_function_type(Type *retty, std::vector<Type *> &args);
void add_function(Function *f);
llvm::ilist<Function> &get_functions();
std::list<Function*> &get_functions();
void add_global_variable(GlobalVariable *g);
llvm::ilist<GlobalVariable> &get_global_variable();
std::list<GlobalVariable*> &get_global_variable();
void set_print_name();
std::string print();
private:
// The global variables in the module
llvm::ilist<GlobalVariable> global_list_;
std::list<GlobalVariable*> global_list_;
// The functions in the module
llvm::ilist<Function> function_list_;
std::unique_ptr<IntegerType> int1_ty_;
std::unique_ptr<IntegerType> int32_ty_;
std::unique_ptr<Type> label_ty_;
std::unique_ptr<Type> void_ty_;
std::unique_ptr<FloatType> float32_ty_;
std::map<Type *, std::unique_ptr<PointerType>> pointer_map_;
std::map<std::pair<Type *, int>, std::unique_ptr<ArrayType>> array_map_;
std::list<Function*> function_list_;
IntegerType* int1_ty_;
IntegerType* int32_ty_;
Type* label_ty_;
Type* void_ty_;
FloatType* float32_ty_;
std::map<Type *, PointerType*> pointer_map_;
std::map<std::pair<Type *, int>, ArrayType*> array_map_;
std::map<std::pair<Type *, std::vector<Type *>>,
std::unique_ptr<FunctionType>>
FunctionType*>
function_map_;
};
......@@ -12,7 +12,12 @@ class FloatType;
class Type {
public:
enum TypeID {
Type(const Type& other) = delete;
Type(Type&& other) noexcept = delete;
Type& operator=(const Type& other) = delete;
Type& operator=(Type&& other) noexcept = delete;
enum TypeID: uint8_t {
VoidTyID, // Void
LabelTyID, // Labels, e.g., BasicBlock
IntegerTyID, // Integers, include 32 bits and 1 bit
......@@ -25,11 +30,7 @@ class Type {
explicit Type(TypeID tid, Module *m);
// 生成 vptr, 使调试时能够显示 IntegerType 等 Type 的信息
#ifdef DEBUG_BUILD
virtual
#endif
~Type() = default;
virtual ~Type();
TypeID get_type_id() const { return tid_; }
......@@ -59,7 +60,13 @@ class Type {
class IntegerType : public Type {
public:
explicit IntegerType(unsigned num_bits, Module *m);
IntegerType(const IntegerType& other) = delete;
IntegerType(IntegerType&& other) noexcept = delete;
IntegerType& operator=(const IntegerType& other) = delete;
IntegerType& operator=(IntegerType&& other) noexcept = delete;
explicit IntegerType(unsigned num_bits, Module *m);
~IntegerType() override;
unsigned get_num_bits() const;
......@@ -69,10 +76,16 @@ class IntegerType : public Type {
class FunctionType : public Type {
public:
FunctionType(Type *result, std::vector<Type *> params);
FunctionType(const FunctionType& other) = delete;
FunctionType(FunctionType&& other) noexcept = delete;
FunctionType& operator=(const FunctionType& other) = delete;
FunctionType& operator=(FunctionType&& other) noexcept = delete;
static bool is_valid_return_type(Type *ty);
static bool is_valid_argument_type(Type *ty);
FunctionType(Type *result, const std::vector<Type *>& params);
~FunctionType() override;
static bool is_valid_return_type(const Type *ty);
static bool is_valid_argument_type(const Type *ty);
static FunctionType *get(Type *result, std::vector<Type *> params);
......@@ -90,9 +103,15 @@ class FunctionType : public Type {
class ArrayType : public Type {
public:
ArrayType(Type *contained, unsigned num_elements);
ArrayType(const ArrayType& other) = delete;
ArrayType(ArrayType&& other) noexcept = delete;
ArrayType& operator=(const ArrayType& other) = delete;
ArrayType& operator=(ArrayType&& other) noexcept = delete;
ArrayType(Type *contained, unsigned num_elements);
~ArrayType() override;
static bool is_valid_element_type(Type *ty);
static bool is_valid_element_type(const Type *ty);
static ArrayType *get(Type *contained, unsigned num_elements);
......@@ -106,7 +125,13 @@ class ArrayType : public Type {
class PointerType : public Type {
public:
PointerType(Type *contained);
PointerType(const PointerType& other) = delete;
PointerType(PointerType&& other) noexcept = delete;
PointerType& operator=(const PointerType& other) = delete;
PointerType& operator=(PointerType&& other) noexcept = delete;
PointerType(Type *contained);
~PointerType() override;
Type *get_element_type() const { return contained_; }
static PointerType *get(Type *contained);
......@@ -117,8 +142,12 @@ class PointerType : public Type {
class FloatType : public Type {
public:
FloatType(Module *m);
static FloatType *get(Module *m);
FloatType(const FloatType& other) = delete;
FloatType(FloatType&& other) noexcept = delete;
FloatType& operator=(const FloatType& other) = delete;
FloatType& operator=(FloatType&& other) noexcept = delete;
private:
FloatType(Module *m);
~FloatType() override;
static FloatType *get(Module *m);
};
......@@ -6,14 +6,18 @@
class User : public Value {
public:
User(Type *ty, const std::string &name = "") : Value(ty, name){};
virtual ~User() { remove_all_operands(); }
User(const User& other) = delete;
User(User&& other) noexcept = delete;
User& operator=(const User& other) = delete;
User& operator=(User&& other) noexcept = delete;
User(Type *ty, const std::string &name = "") : Value(ty, name){}
~User() override;
const std::vector<Value *> &get_operands() const { return operands_; }
unsigned get_num_operand() const { return operands_.size(); }
unsigned get_num_operand() const { return static_cast<unsigned>(operands_.size()); }
// start from 0
Value *get_operand(unsigned i) const { return operands_.at(i); };
Value *get_operand(unsigned i) const { return operands_.at(i); }
// start from 0
void set_operand(unsigned i, Value *v);
void add_operand(Value *v);
......
#pragma once
#include <functional>
#include <iostream>
#include <list>
#include <string>
#include <cassert>
......@@ -13,43 +12,48 @@ struct Use;
class Value {
public:
explicit Value(Type *ty, const std::string &name = "")
: type_(ty), name_(name){};
virtual ~Value() { replace_all_use_with(nullptr); }
Value(const Value& other) = delete;
Value(Value&& other) noexcept = delete;
Value& operator=(const Value& other) = delete;
Value& operator=(Value&& other) noexcept = delete;
std::string get_name() const { return name_; };
explicit Value(Type *ty, std::string name = "")
: type_(ty), name_(std::move(name)){}
virtual ~Value();
std::string get_name() const { return name_; }
Type *get_type() const { return type_; }
const std::list<Use> &get_use_list() const { return use_list_; }
bool set_name(std::string name);
bool set_name(const std::string& name);
void add_use(User *user, unsigned arg_no);
void remove_use(User *user, unsigned arg_no);
void replace_all_use_with(Value *new_val);
void replace_use_with_if(Value *new_val, std::function<bool(Use *)> pred);
void replace_all_use_with(Value *new_val) const;
void replace_use_with_if(Value *new_val, const std::function<bool(Use *)>& should_replace);
virtual std::string print() = 0;
template<typename T>
T *as()
{
static_assert(std::is_base_of<Value, T>::value, "T must be a subclass of Value");
static_assert(std::is_base_of_v<Value, T>, "T must be a subclass of Value");
const auto ptr = dynamic_cast<T*>(this);
assert(ptr && "dynamic_cast failed");
return ptr;
}
template<typename T>
[[nodiscard]] const T* as() const {
static_assert(std::is_base_of<Value, T>::value, "T must be a subclass of Value");
const T* as() const {
static_assert(std::is_base_of_v<Value, T>, "T must be a subclass of Value");
const auto ptr = dynamic_cast<const T*>(this);
assert(ptr);
return ptr;
}
// is 接口
template <typename T>
[[nodiscard]] bool is() const {
static_assert(std::is_base_of<Value, T>::value, "T must be a subclass of Value");
bool is() const {
static_assert(std::is_base_of_v<Value, T>, "T must be a subclass of Value");
return dynamic_cast<const T*>(this);
}
......
......@@ -5,21 +5,24 @@
#include "cminusf_builder.hpp"
#define CONST_FP(num) ConstantFP::get((float)num, module.get())
#define CONST_INT(num) ConstantInt::get(num, module.get())
#define CONST_FP(num) ConstantFP::get((float)(num), module)
#define CONST_INT(num) ConstantInt::get((num), module)
// types
Type *VOID_T;
Type *INT1_T;
Type *INT32_T;
Type *INT32PTR_T;
Type *FLOAT_T;
Type *FLOATPTR_T;
bool promote(IRBuilder *builder, Value **l_val_p, Value **r_val_p) {
namespace
{
Type* VOID_T;
Type* INT1_T;
Type* INT32_T;
Type* INT32PTR_T;
Type* FLOAT_T;
Type* FLOATPTR_T;
}
static bool promote(const IRBuilder *builder, Value **l_val_p, Value **r_val_p) {
bool is_int = false;
auto &l_val = *l_val_p;
auto &r_val = *r_val_p;
auto& l_val = *l_val_p;
auto& r_val = *r_val_p;
if (l_val->get_type() == r_val->get_type()) {
is_int = l_val->get_type()->is_integer_type();
} else {
......@@ -63,7 +66,7 @@ Value* CminusfBuilder::visit(ASTNum &node) {
}
Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
Type *var_type = nullptr;
Type *var_type;
if (node.type == TYPE_INT) {
var_type = module->get_int32_type();
} else {
......@@ -72,14 +75,14 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
if (scope.in_global()) {
if (node.num == nullptr) {
auto *initializer = ConstantZero::get(var_type, module.get());
auto *var = GlobalVariable::create(node.id, module.get(), var_type,
auto *initializer = ConstantZero::get(var_type, module);
auto *var = GlobalVariable::create(node.id, module, var_type,
false, initializer);
scope.push(node.id, var);
} else {
auto *array_type = ArrayType::get(var_type, node.num->i_val);
auto *initializer = ConstantZero::get(array_type, module.get());
auto *var = GlobalVariable::create(node.id, module.get(),
auto *initializer = ConstantZero::get(array_type, module);
auto *var = GlobalVariable::create(node.id, module,
array_type, false, initializer);
scope.push(node.id, var);
}
......@@ -88,7 +91,7 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
auto nowBB = builder->get_insert_block();
auto entryBB = context.func->get_entry_block();
builder->set_insert_point(entryBB);
auto *var = builder->create_alloca_begin(var_type);
auto *var = builder->create_alloca(var_type);
builder->set_insert_point(nowBB);
scope.push(node.id, var);
} else {
......@@ -96,7 +99,7 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
auto entryBB = context.func->get_entry_block();
builder->set_insert_point(entryBB);
auto *array_type = ArrayType::get(var_type, node.num->i_val);
auto *var = builder->create_alloca_begin(array_type);
auto *var = builder->create_alloca(array_type);
builder->set_insert_point(nowBB);
scope.push(node.id, var);
}
......@@ -105,8 +108,7 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
}
Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
FunctionType *fun_type;
Type *ret_type;
Type *ret_type;
std::vector<Type *> param_types;
if (node.type == TYPE_INT)
ret_type = INT32_T;
......@@ -131,11 +133,11 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
}
}
fun_type = FunctionType::get(ret_type, param_types);
auto func = Function::create(fun_type, node.id, module.get());
FunctionType* fun_type = FunctionType::get(ret_type, param_types);
auto func = Function::create(fun_type, node.id, module);
scope.push(node.id, func);
context.func = func;
auto funBB = BasicBlock::create(module.get(), "entry", func);
auto funBB = BasicBlock::create(module, "entry", func);
builder->set_insert_point(funBB);
scope.enter();
context.pre_enter_scope = true;
......@@ -145,7 +147,7 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
}
for (unsigned int i = 0; i < node.params.size(); ++i) {
if (node.params[i]->isarray) {
Value *array_alloc = nullptr;
Value *array_alloc;
if (node.params[i]->type == TYPE_INT) {
array_alloc = builder->create_alloca(INT32PTR_T);
} else {
......@@ -154,7 +156,7 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
builder->create_store(args[i], array_alloc);
scope.push(node.params[i]->id, array_alloc);
} else {
Value *alloc = nullptr;
Value *alloc;
if (node.params[i]->type == TYPE_INT) {
alloc = builder->create_alloca(INT32_T);
} else {
......@@ -216,10 +218,10 @@ Value* CminusfBuilder::visit(ASTExpressionStmt &node) {
Value* CminusfBuilder::visit(ASTSelectionStmt &node) {
auto *ret_val = node.expression->accept(*this);
auto *trueBB = BasicBlock::create(module.get(), "", context.func);
auto *trueBB = BasicBlock::create(module, "", context.func);
BasicBlock *falseBB{};
auto *contBB = BasicBlock::create(module.get(), "", context.func);
Value *cond_val = nullptr;
auto *contBB = BasicBlock::create(module, "", context.func);
Value *cond_val;
if (ret_val->get_type()->is_integer_type()) {
cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
} else {
......@@ -229,7 +231,7 @@ Value* CminusfBuilder::visit(ASTSelectionStmt &node) {
if (node.else_statement == nullptr) {
builder->create_cond_br(cond_val, trueBB, contBB);
} else {
falseBB = BasicBlock::create(module.get(), "", context.func);
falseBB = BasicBlock::create(module, "", context.func);
builder->create_cond_br(cond_val, trueBB, falseBB);
}
builder->set_insert_point(trueBB);
......@@ -254,16 +256,16 @@ Value* CminusfBuilder::visit(ASTSelectionStmt &node) {
}
Value* CminusfBuilder::visit(ASTIterationStmt &node) {
auto *exprBB = BasicBlock::create(module.get(), "", context.func);
auto *exprBB = BasicBlock::create(module, "", context.func);
if (not builder->get_insert_block()->is_terminated()) {
builder->create_br(exprBB);
}
builder->set_insert_point(exprBB);
auto *ret_val = node.expression->accept(*this);
auto *trueBB = BasicBlock::create(module.get(), "", context.func);
auto *contBB = BasicBlock::create(module.get(), "", context.func);
Value *cond_val = nullptr;
auto *trueBB = BasicBlock::create(module, "", context.func);
auto *contBB = BasicBlock::create(module, "", context.func);
Value *cond_val;
if (ret_val->get_type()->is_integer_type()) {
cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
} else {
......@@ -312,7 +314,7 @@ Value* CminusfBuilder::visit(ASTVar &node) {
var->get_type()->get_pointer_element_type()->is_pointer_type();
bool should_return_lvalue = context.require_lvalue;
context.require_lvalue = false;
Value *ret_val = nullptr;
Value *ret_val;
if (node.expression == nullptr) {
if (should_return_lvalue) {
ret_val = var;
......@@ -327,14 +329,13 @@ Value* CminusfBuilder::visit(ASTVar &node) {
}
} else {
auto *val = node.expression->accept(*this);
Value *is_neg = nullptr;
auto *exceptBB = BasicBlock::create(module.get(), "", context.func);
auto *contBB = BasicBlock::create(module.get(), "", context.func);
auto *exceptBB = BasicBlock::create(module, "", context.func);
auto *contBB = BasicBlock::create(module, "", context.func);
if (val->get_type()->is_float_type()) {
val = builder->create_fptosi(val, INT32_T);
}
is_neg = builder->create_icmp_lt(val, CONST_INT(0));
Value* is_neg = builder->create_icmp_lt(val, CONST_INT(0));
builder->create_cond_br(is_neg, exceptBB, contBB);
builder->set_insert_point(exceptBB);
......@@ -349,7 +350,7 @@ Value* CminusfBuilder::visit(ASTVar &node) {
}
builder->set_insert_point(contBB);
Value *tmp_ptr = nullptr;
Value *tmp_ptr;
if (is_int || is_float) {
tmp_ptr = builder->create_gep(var, {val});
} else if (is_ptr) {
......@@ -391,7 +392,7 @@ Value* CminusfBuilder::visit(ASTSimpleExpression &node) {
auto *l_val = node.additive_expression_l->accept(*this);
auto *r_val = node.additive_expression_r->accept(*this);
bool is_int = promote(&*builder, &l_val, &r_val);
bool is_int = promote(builder, &l_val, &r_val);
Value *cmp = nullptr;
switch (node.op) {
case OP_LT:
......@@ -448,7 +449,7 @@ Value* CminusfBuilder::visit(ASTAdditiveExpression &node) {
auto *l_val = node.additive_expression->accept(*this);
auto *r_val = node.term->accept(*this);
bool is_int = promote(&*builder, &l_val, &r_val);
bool is_int = promote(builder, &l_val, &r_val);
Value *ret_val = nullptr;
switch (node.op) {
case OP_PLUS:
......@@ -476,7 +477,7 @@ Value* CminusfBuilder::visit(ASTTerm &node) {
auto *l_val = node.term->accept(*this);
auto *r_val = node.factor->accept(*this);
bool is_int = promote(&*builder, &l_val, &r_val);
bool is_int = promote(builder, &l_val, &r_val);
Value *ret_val = nullptr;
switch (node.op) {
......@@ -513,7 +514,7 @@ Value* CminusfBuilder::visit(ASTCall &node) {
}
}
args.push_back(arg_val);
param_type++;
++param_type;
}
return builder->create_call(static_cast<Function *>(func), args);
......
......@@ -46,7 +46,7 @@ int main(int argc, char **argv) {
ASTPrinter printer;
ast.run_visitor(printer);
} else {
std::unique_ptr<Module> m;
Module* m;
CminusfBuilder builder;
ast.run_visitor(builder);
m = builder.getModule();
......@@ -58,11 +58,13 @@ int main(int argc, char **argv) {
output_stream << "source_filename = " << abs_path << "\n\n";
output_stream << m->print();
} else if (config.emitasm) {
CodeGen codegen(m.get());
CodeGen codegen(m);
codegen.run();
output_stream << codegen.print();
}
delete m;
// TODO: lab4 (IR optimization or codegen)
}
......
......@@ -15,16 +15,16 @@ void CodeGen::allocate() {
// 为指令结果分配栈空间
for (auto bb : context.func->get_basic_blocks()) {
for (auto& instr : bb->get_instructions()) {
for (auto instr : bb->get_instructions()) {
// 每个非 void 的定值都分配栈空间
if (not instr.is_void()) {
auto size = instr.get_type()->get_size();
if (not instr->is_void()) {
auto size = instr->get_type()->get_size();
offset = offset + size;
context.offset_map[&instr] = -static_cast<int>(offset);
context.offset_map[instr] = -static_cast<int>(offset);
}
// alloca 的副作用:分配额外空间
if (instr.is_alloca()) {
auto *alloca_inst = static_cast<AllocaInst *>(&instr);
if (instr->is_alloca()) {
auto *alloca_inst = dynamic_cast<AllocaInst *>(instr);
auto alloc_size = alloca_inst->get_alloca_type()->get_size();
offset += alloc_size;
}
......@@ -36,20 +36,20 @@ void CodeGen::allocate() {
}
void CodeGen::copy_stmt() {
for (auto &succ : context.bb->get_succ_basic_blocks()) {
for (auto &inst : succ->get_instructions()) {
if (inst.is_phi()) {
for (auto succ : context.bb->get_succ_basic_blocks()) {
for (auto inst : succ->get_instructions()) {
if (inst->is_phi()) {
// 遍历后继块中 phi 的定值 bb
for (unsigned i = 1; i < inst.get_operands().size(); i += 2) {
for (unsigned i = 1; i < inst->get_operands().size(); i += 2) {
// phi 的定值 bb 是当前翻译块
if (inst.get_operand(i) == context.bb) {
auto *lvalue = inst.get_operand(i - 1);
if (inst->get_operand(i) == context.bb) {
auto *lvalue = inst->get_operand(i - 1);
if (lvalue->get_type()->is_float_type()) {
load_to_freg(lvalue, FReg::fa(0));
store_from_freg(&inst, FReg::fa(0));
store_from_freg(inst, FReg::fa(0));
} else {
load_to_greg(lvalue, Reg::a(0));
store_from_greg(&inst, Reg::a(0));
store_from_greg(inst, Reg::a(0));
}
break;
}
......@@ -373,16 +373,16 @@ void CodeGen::run() {
append_inst(".text", ASMInstruction::Atrribute);
append_inst(".section", {".bss", "\"aw\"", "@nobits"},
ASMInstruction::Atrribute);
for (auto &global : m->get_global_variable()) {
for (auto global : m->get_global_variable()) {
auto size =
global.get_type()->get_pointer_element_type()->get_size();
append_inst(".globl", {global.get_name()},
global->get_type()->get_pointer_element_type()->get_size();
append_inst(".globl", {global->get_name()},
ASMInstruction::Atrribute);
append_inst(".type", {global.get_name(), "@object"},
append_inst(".type", {global->get_name(), "@object"},
ASMInstruction::Atrribute);
append_inst(".size", {global.get_name(), std::to_string(size)},
append_inst(".size", {global->get_name(), std::to_string(size)},
ASMInstruction::Atrribute);
append_inst(global.get_name(), ASMInstruction::Label);
append_inst(global->get_name(), ASMInstruction::Label);
append_inst(".space", {std::to_string(size)},
ASMInstruction::Atrribute);
}
......@@ -390,31 +390,31 @@ void CodeGen::run() {
// 函数代码段
output.emplace_back(".text", ASMInstruction::Atrribute);
for (auto &func : m->get_functions()) {
if (not func.is_declaration()) {
for (auto func : m->get_functions()) {
if (not func->is_declaration()) {
// 更新 context
context.clear();
context.func = &func;
context.func = func;
// 函数信息
append_inst(".globl", {func.get_name()}, ASMInstruction::Atrribute);
append_inst(".type", {func.get_name(), "@function"},
append_inst(".globl", {func->get_name()}, ASMInstruction::Atrribute);
append_inst(".type", {func->get_name(), "@function"},
ASMInstruction::Atrribute);
append_inst(func.get_name(), ASMInstruction::Label);
append_inst(func->get_name(), ASMInstruction::Label);
// 分配函数栈帧
allocate();
// 生成 prologue
gen_prologue();
for (auto bb : func.get_basic_blocks()) {
for (auto bb : func->get_basic_blocks()) {
context.bb = bb;
append_inst(label_name(context.bb), ASMInstruction::Label);
for (auto &instr : bb->get_instructions()) {
for (auto instr : bb->get_instructions()) {
// For debug
append_inst(instr.print(), ASMInstruction::Comment);
context.inst = &instr; // 更新 context
switch (instr.get_instr_type()) {
append_inst(instr->print(), ASMInstruction::Comment);
context.inst = instr; // 更新 context
switch (instr->get_instr_type()) {
case Instruction::ret:
gen_ret();
break;
......
......@@ -396,11 +396,13 @@ Value* ASTCall::accept(ASTVisitor &visitor) { return visitor.visit(*this); }
#define _DEBUG_PRINT_N_(N) \
{ std::cout << std::string(N, '-'); }
ASTVisitor::~ASTVisitor() = default;
Value* ASTPrinter::visit(ASTProgram &node) {
_DEBUG_PRINT_N_(depth);
std::cout << "program" << std::endl;
add_depth();
for (auto decl : node.declarations) {
for (auto &decl : node.declarations) {
decl->accept(*this);
}
remove_depth();
......@@ -437,7 +439,7 @@ Value* ASTPrinter::visit(ASTFunDeclaration &node) {
_DEBUG_PRINT_N_(depth);
std::cout << "fun-declaration: " << node.id << std::endl;
add_depth();
for (auto param : node.params) {
for (auto &param : node.params) {
param->accept(*this);
}
......@@ -459,11 +461,11 @@ Value* ASTPrinter::visit(ASTCompoundStmt &node) {
_DEBUG_PRINT_N_(depth);
std::cout << "compound-stmt" << std::endl;
add_depth();
for (auto decl : node.local_declarations) {
for (auto &decl : node.local_declarations) {
decl->accept(*this);
}
for (auto stmt : node.statement_list) {
for (auto &stmt : node.statement_list) {
stmt->accept(*this);
}
remove_depth();
......@@ -625,7 +627,7 @@ Value* ASTPrinter::visit(ASTCall &node) {
_DEBUG_PRINT_N_(depth);
std::cout << "call: " << node.id << "()" << std::endl;
add_depth();
for (auto arg : node.args) {
for (auto &arg : node.args) {
arg->accept(*this);
}
remove_depth();
......
......@@ -13,14 +13,14 @@ BasicBlock::BasicBlock(Module *m, const std::string &name = "",
parent_->add_basic_block(this);
}
Module *BasicBlock::get_module() { return get_parent()->get_parent(); }
Module *BasicBlock::get_module() const { return get_parent()->get_parent(); }
void BasicBlock::erase_from_parent() { this->get_parent()->remove(this); }
bool BasicBlock::is_terminated() const {
if (instr_list_.empty()) {
return false;
}
switch (instr_list_.back().get_instr_type()) {
switch (instr_list_.back()->get_instr_type()) {
case Instruction::ret:
case Instruction::br:
return true;
......@@ -29,17 +29,49 @@ bool BasicBlock::is_terminated() const {
}
}
Instruction *BasicBlock::get_terminator() {
Instruction *BasicBlock::get_terminator() const
{
assert(is_terminated() &&
"Trying to get terminator from an bb which is not terminated");
return &instr_list_.back();
return instr_list_.back();
}
void BasicBlock::add_instruction(Instruction *instr) {
if (instr->is_alloca() || instr->is_phi())
{
auto it = instr_list_.begin();
for (; it != instr_list_.end() && ((*it)->is_alloca() || (*it)->is_phi()); ++it){}
instr_list_.emplace(it, instr);
return;
}
assert(not is_terminated() && "Inserting instruction to terminated bb");
instr_list_.push_back(instr);
}
void BasicBlock::add_instr_begin(Instruction* instr)
{
if (instr->is_alloca() || instr->is_phi())
instr_list_.push_front(instr);
else
{
auto it = instr_list_.begin();
for (; it != instr_list_.end() && ((*it)->is_alloca() || (*it)->is_phi()); ++it) {}
instr_list_.emplace(it, instr);
}
}
void BasicBlock::add_instr_before_terminator(Instruction* instr) {
if (instr->is_alloca() || instr->is_phi())
{
auto it = instr_list_.begin();
for (; it != instr_list_.end() && ((*it)->is_alloca() || (*it)->is_phi()); ++it) {}
instr_list_.emplace(it, instr);
return;
}
if (!is_terminated()) instr_list_.emplace_back(instr);
else instr_list_.insert(std::prev(instr_list_.end()), instr);
}
std::string BasicBlock::print() {
std::string bb_ir;
bb_ir += this->get_name();
......@@ -61,16 +93,22 @@ std::string BasicBlock::print() {
bb_ir += "; Error: Block without parent!";
}
bb_ir += "\n";
for (auto &instr : this->get_instructions()) {
for (auto instr : this->get_instructions()) {
bb_ir += " ";
bb_ir += instr.print();
bb_ir += instr->print();
bb_ir += "\n";
}
return bb_ir;
}
BasicBlock* BasicBlock::get_entry_block_of_same_function(){
assert((not (parent_ == nullptr)) && "bb have no parent function");
BasicBlock::~BasicBlock()
{
for (auto inst : instr_list_) delete inst;
}
BasicBlock* BasicBlock::get_entry_block_of_same_function() const
{
assert(parent_ != nullptr && "bb have no parent function");
return parent_->get_entry_block();
}
\ No newline at end of file
}
......@@ -14,5 +14,4 @@ add_library(
target_link_libraries(
IR_lib
LLVMSupport
)
#include "Constant.hpp"
#include <cstring>
#include "Module.hpp"
#include <iostream>
#include <memory>
#include <sstream>
#include <unordered_map>
......@@ -27,6 +29,8 @@ static std::unordered_map<std::pair<float, Module *>,
cached_float;
static std::unordered_map<Type *, std::unique_ptr<ConstantZero>> cached_zero;
Constant::~Constant() = default;
ConstantInt *ConstantInt::get(int val, Module *m) {
if (cached_int.find(std::make_pair(val, m)) != cached_int.end())
return cached_int[std::make_pair(val, m)].get();
......@@ -62,7 +66,8 @@ ConstantArray::ConstantArray(ArrayType *ty, const std::vector<Constant *> &val)
this->const_array.assign(val.begin(), val.end());
}
Constant *ConstantArray::get_element_value(int index) {
Constant *ConstantArray::get_element_value(int index) const
{
return this->const_array[index];
}
......@@ -76,7 +81,7 @@ std::string ConstantArray::print() {
const_ir += this->get_type()->print();
const_ir += " ";
const_ir += "[";
for (unsigned i = 0; i < this->get_size_of_array(); i++) {
for (int i = 0; i < this->get_size_of_array(); i++) {
Constant *element = get_element_value(i);
if (!dynamic_cast<ConstantArray *>(get_element_value(i))) {
const_ir += element->get_type()->print();
......@@ -102,7 +107,9 @@ std::string ConstantFP::print() {
std::stringstream fp_ir_ss;
std::string fp_ir;
double val = this->get_value();
fp_ir_ss << "0x" << std::hex << *(uint64_t *)&val << std::endl;
uint64_t uval;
memcpy(&uval, &val, sizeof(double));
fp_ir_ss << "0x" << std::hex << uval << '\n';
fp_ir_ss >> fp_ir;
return fp_ir;
}
......
......@@ -15,13 +15,19 @@ Function::Function(FunctionType *ty, const std::string &name, Module *parent)
arguments_.emplace_back(ty->get_param_type(i), "", this, i);
}
}
Function::~Function()
{
for (auto bb : basic_blocks_) delete bb;
}
Function *Function::create(FunctionType *ty, const std::string &name,
Module *parent) {
return new Function(ty, name, parent);
}
FunctionType *Function::get_function_type() const {
return static_cast<FunctionType *>(get_type());
return dynamic_cast<FunctionType *>(get_type());
}
Type *Function::get_return_type() const {
......@@ -32,7 +38,7 @@ unsigned Function::get_num_of_args() const {
return get_function_type()->get_num_of_args();
}
unsigned Function::get_num_basic_blocks() const { return basic_blocks_.size(); }
unsigned Function::get_num_basic_blocks() const { return static_cast<unsigned>(basic_blocks_.size()); }
Module *Function::get_parent() const { return parent_; }
......@@ -65,11 +71,11 @@ void Function::set_instr_name() {
seq.insert({bb, seq_num});
}
}
for (auto &instr : bb->get_instructions()) {
if (!instr.is_void() && seq.find(&instr) == seq.end()) {
for (auto instr : bb->get_instructions()) {
if (!instr->is_void() && seq.find(instr) == seq.end()) {
auto seq_num = seq.size() + seq_cnt_;
if (instr.set_name("op" + std::to_string(seq_num))) {
seq.insert({&instr, seq_num});
if (instr->set_name("op" + std::to_string(seq_num))) {
seq.insert({instr, seq_num});
}
}
}
......@@ -96,9 +102,9 @@ std::string Function::print() {
for (unsigned i = 0; i < this->get_num_of_args(); i++) {
if (i)
func_ir += ", ";
func_ir += static_cast<FunctionType *>(this->get_type())
->get_param_type(i)
->print();
func_ir += dynamic_cast<FunctionType *>(this->get_type())
->get_param_type(i)
->print();
}
} else {
for (auto &arg : get_args()) {
......@@ -115,7 +121,7 @@ std::string Function::print() {
} else {
func_ir += " {";
func_ir += "\n";
for (auto &bb : this->get_basic_blocks()) {
for (auto bb : this->get_basic_blocks()) {
func_ir += bb->print();
}
func_ir += "}";
......@@ -133,7 +139,7 @@ std::string Argument::print() {
}
void Function::check_for_block_relation_error()
void Function::check_for_block_relation_error() const
{
// 检查函数的基本块表是否包含所有且仅包含 get_parent 是本函数的基本块
std::unordered_set<BasicBlock*> bbs;
......@@ -172,7 +178,7 @@ void Function::check_for_block_relation_error()
for (auto bb : basic_blocks_)
{
assert((!bb->get_instructions().empty()) && "发现了空基本块");
auto b = &bb->get_instructions().back();
auto b = bb->get_instructions().back();
assert((b->is_br() || b->is_ret()) && "发现了无 terminator 基本块");
assert((b->is_br() || bb->get_succ_basic_blocks().empty()) && "某基本块末尾是 ret 指令但是后继块表不是空的, 或者末尾是 br 但后继表为空");
}
......@@ -184,12 +190,14 @@ void Function::check_for_block_relation_error()
std::unordered_set<BasicBlock*> br_get;
for (auto suc : bb->get_succ_basic_blocks())
suc_table.emplace(suc);
auto& ops = bb->get_instructions().back().get_operands();
auto& ops = bb->get_instructions().back()->get_operands();
for (auto i : ops)
{
auto bb2 = dynamic_cast<BasicBlock*>(i);
if(bb2 != nullptr) br_get.emplace(bb2);
}
// 这三个检查保证有问题会报错,但不保证每次报错的基本块都相同
// 例如 A, B 两个基本块都存在问题,可能有时 A 报错有时 B 报错
for(auto i : suc_table)
assert(br_get.count(i) && "基本块 A 的后继块有 B,但 B 并未在 A 的 br 指令中出现");
for(auto i : br_get)
......@@ -207,7 +215,7 @@ void Function::check_for_block_relation_error()
}
}
// 检查基本块前驱后继关系是否是与 branch 指令对应
for (auto& bb : basic_blocks_)
for (auto bb : basic_blocks_)
{
for (auto pre : bb->get_pre_basic_blocks())
{
......@@ -225,11 +233,11 @@ void Function::check_for_block_relation_error()
}
}
// 检查指令 parent 设置
for (auto& bb : basic_blocks_)
for (auto bb : basic_blocks_)
{
for (auto& inst : bb->get_instructions())
for (auto inst : bb->get_instructions())
{
assert ((inst.get_parent() == bb) && "基本块 A 指令表包含指令 b, 但是 b 的 get_parent 函数不返回 A");
assert ((inst->get_parent() == bb) && "基本块 A 指令表包含指令 b, 但是 b 的 get_parent 函数不返回 A");
}
}
}
#include "GlobalVariable.hpp"
#include "IRprinter.hpp"
#include "Module.hpp"
GlobalVariable::GlobalVariable(std::string name, Module *m, Type *ty,
GlobalVariable::GlobalVariable(const std::string& name, Module *m, Type *ty,
bool is_const, Constant *init)
: User(ty, name), is_const_(is_const), init_val_(init) {
m->add_global_variable(this);
......@@ -10,7 +11,7 @@ GlobalVariable::GlobalVariable(std::string name, Module *m, Type *ty,
}
} // global操作数为initval
GlobalVariable *GlobalVariable::create(std::string name, Module *m, Type *ty,
GlobalVariable *GlobalVariable::create(const std::string& name, Module *m, Type *ty,
bool is_const,
Constant *init = nullptr) {
return new GlobalVariable(name, m, PointerType::get(ty), is_const, init);
......
#include "IRprinter.hpp"
#include "Instruction.hpp"
#include <cassert>
#include <type_traits>
#include "Constant.hpp"
#include "Function.hpp"
#include "GlobalVariable.hpp"
std::string print_as_op(Value *v, bool print_ty) {
std::string op_ir;
......@@ -10,9 +12,7 @@ std::string print_as_op(Value *v, bool print_ty) {
op_ir += " ";
}
if (dynamic_cast<GlobalVariable *>(v)) {
op_ir += "@" + v->get_name();
} else if (dynamic_cast<Function *>(v)) {
if (dynamic_cast<GlobalVariable *>(v) || dynamic_cast<Function*>(v)) {
op_ir += "@" + v->get_name();
} else if (dynamic_cast<Constant *>(v)) {
op_ir += v->print();
......@@ -91,7 +91,8 @@ std::string print_instr_op_name(Instruction::OpID id) {
assert(false && "Must be bug");
}
template <class BinInst> std::string print_binary_inst(const BinInst &inst) {
template <class BinInst>
static std::string print_binary_inst(const BinInst &inst) {
std::string instr_ir;
instr_ir += "%";
instr_ir += inst.get_name();
......@@ -112,7 +113,8 @@ template <class BinInst> std::string print_binary_inst(const BinInst &inst) {
std::string IBinaryInst::print() { return print_binary_inst(*this); }
std::string FBinaryInst::print() { return print_binary_inst(*this); }
template <class CMP> std::string print_cmp_inst(const CMP &inst) {
template <class CMP>
static std::string print_cmp_inst(const CMP &inst) {
std::string cmp_type;
if (inst.is_cmp())
cmp_type = "icmp";
......
......@@ -7,7 +7,6 @@
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>
......@@ -17,8 +16,10 @@ Instruction::Instruction(Type *ty, OpID id, BasicBlock *parent)
parent->add_instruction(this);
}
Function *Instruction::get_function() { return parent_->get_parent(); }
Module *Instruction::get_module() { return parent_->get_module(); }
Instruction::~Instruction() = default;
Function *Instruction::get_function() const { return parent_->get_parent(); }
Module *Instruction::get_module() const { return parent_->get_module(); }
std::string Instruction::get_instr_op_name() const {
return print_instr_op_name(op_id_);
......@@ -122,12 +123,12 @@ FCmpInst *FCmpInst::create_fne(Value *v1, Value *v2, BasicBlock *bb) {
return create(fne, v1, v2, bb);
}
CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
CallInst::CallInst(Function *func, const std::vector<Value *>& args, BasicBlock *bb)
: BaseInst<CallInst>(func->get_return_type(), call, bb) {
assert(func->get_type()->is_function_type() && "Not a function");
assert((func->get_num_of_args() == args.size()) && "Wrong number of args");
add_operand(func);
auto func_type = static_cast<FunctionType *>(func->get_type());
auto func_type = dynamic_cast<FunctionType *>(func->get_type());
for (unsigned i = 0; i < args.size(); i++) {
assert(func_type->get_param_type(i) == args[i]->get_type() &&
"CallInst: Wrong arg type");
......@@ -135,13 +136,13 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
}
}
CallInst *CallInst::create_call(Function *func, std::vector<Value *> args,
CallInst *CallInst::create_call(Function *func, const std::vector<Value *>& args,
BasicBlock *bb) {
return create(func, args, bb);
}
FunctionType *CallInst::get_function_type() const {
return static_cast<FunctionType *>(get_operand(0)->get_type());
return dynamic_cast<FunctionType *>(get_operand(0)->get_type());
}
BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false,
......@@ -170,10 +171,10 @@ BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false,
BranchInst::~BranchInst() {
std::list<BasicBlock *> succs;
if (is_cond_br()) {
succs.push_back(static_cast<BasicBlock *>(get_operand(1)));
succs.push_back(static_cast<BasicBlock *>(get_operand(2)));
succs.push_back(dynamic_cast<BasicBlock *>(get_operand(1)));
succs.push_back(dynamic_cast<BasicBlock *>(get_operand(2)));
} else {
succs.push_back(static_cast<BasicBlock *>(get_operand(0)));
succs.push_back(dynamic_cast<BasicBlock *>(get_operand(0)));
}
for (auto succ_bb : succs) {
if (succ_bb) {
......@@ -214,20 +215,20 @@ ReturnInst *ReturnInst::create_void_ret(BasicBlock *bb) {
bool ReturnInst::is_void_ret() const { return get_num_operand() == 0; }
GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs,
GetElementPtrInst::GetElementPtrInst(Value *ptr, const std::vector<Value *>& idxs,
BasicBlock *bb)
: BaseInst<GetElementPtrInst>(PointerType::get(get_element_type(ptr, idxs)),
getelementptr, bb) {
add_operand(ptr);
for (unsigned i = 0; i < idxs.size(); i++) {
Value *idx = idxs[i];
assert(idx->get_type()->is_integer_type() && "Index is not integer");
for (auto idx : idxs)
{
assert(idx->get_type()->is_integer_type() && "Index is not integer");
add_operand(idx);
}
}
Type *GetElementPtrInst::get_element_type(Value *ptr,
std::vector<Value *> idxs) {
Type *GetElementPtrInst::get_element_type(const Value *ptr,
const std::vector<Value *>& idxs) {
assert(ptr->get_type()->is_pointer_type() &&
"GetElementPtrInst ptr is not a pointer");
......@@ -236,14 +237,14 @@ Type *GetElementPtrInst::get_element_type(Value *ptr,
"GetElementPtrInst ptr is wrong type" &&
(ty->is_array_type() || ty->is_integer_type() || ty->is_float_type()));
if (ty->is_array_type()) {
ArrayType *arr_ty = static_cast<ArrayType *>(ty);
ArrayType *arr_ty = dynamic_cast<ArrayType *>(ty);
for (unsigned i = 1; i < idxs.size(); i++) {
ty = arr_ty->get_element_type();
if (i < idxs.size() - 1) {
assert(ty->is_array_type() && "Index error!");
}
if (ty->is_array_type()) {
arr_ty = static_cast<ArrayType *>(ty);
arr_ty = dynamic_cast<ArrayType *>(ty);
}
}
}
......@@ -255,7 +256,7 @@ Type *GetElementPtrInst::get_element_type() const {
}
GetElementPtrInst *GetElementPtrInst::create_gep(Value *ptr,
std::vector<Value *> idxs,
const std::vector<Value *>& idxs,
BasicBlock *bb) {
return create(ptr, idxs, bb);
}
......@@ -287,7 +288,7 @@ LoadInst *LoadInst::create_load(Value *ptr, BasicBlock *bb) {
AllocaInst::AllocaInst(Type *ty, BasicBlock *bb)
: BaseInst<AllocaInst>(PointerType::get(ty), alloca, bb) {
static const std::array allowed_alloc_type = {
static constexpr std::array allowed_alloc_type = {
Type::IntegerTyID, Type::FloatTyID, Type::ArrayTyID, Type::PointerTyID};
assert(std::find(allowed_alloc_type.begin(), allowed_alloc_type.end(),
ty->get_type_id()) != allowed_alloc_type.end() &&
......@@ -298,26 +299,13 @@ AllocaInst *AllocaInst::create_alloca(Type *ty, BasicBlock *bb) {
return create(ty, bb);
}
AllocaInst *AllocaInst::create_alloca_begin(Type *ty, BasicBlock *bb) {
auto ret = create(ty, nullptr);
if(bb != nullptr)
{
ret->set_parent(bb);
if (bb->is_terminated())
bb->add_instr_before_end(ret);
else
bb->add_instruction(ret);
}
return ret;
}
ZextInst::ZextInst(Value *val, Type *ty, BasicBlock *bb)
: BaseInst<ZextInst>(ty, zext, bb) {
assert(val->get_type()->is_integer_type() &&
"ZextInst operand is not integer");
assert(ty->is_integer_type() && "ZextInst destination type is not integer");
assert((static_cast<IntegerType *>(val->get_type())->get_num_bits() <
static_cast<IntegerType *>(ty)->get_num_bits()) &&
assert((dynamic_cast<IntegerType *>(val->get_type())->get_num_bits() <
dynamic_cast<IntegerType *>(ty)->get_num_bits()) &&
"ZextInst operand bit size is not smaller than destination type bit "
"size");
add_operand(val);
......@@ -358,8 +346,8 @@ SiToFpInst *SiToFpInst::create_sitofp(Value *val, BasicBlock *bb) {
return create(val, bb->get_module()->get_float_type(), bb);
}
PhiInst::PhiInst(Type *ty, std::vector<Value *> vals,
std::vector<BasicBlock *> val_bbs, BasicBlock *bb)
PhiInst::PhiInst(Type *ty, const std::vector<Value *>& vals,
const std::vector<BasicBlock *>& val_bbs, BasicBlock *bb)
: BaseInst<PhiInst>(ty, phi) {
assert(vals.size() == val_bbs.size() && "Unmatched vals and bbs");
for (unsigned i = 0; i < vals.size(); i++) {
......@@ -371,7 +359,7 @@ PhiInst::PhiInst(Type *ty, std::vector<Value *> vals,
}
PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb,
std::vector<Value *> vals,
std::vector<BasicBlock *> val_bbs) {
const std::vector<Value *>& vals,
const std::vector<BasicBlock *>& val_bbs) {
return create(ty, vals, val_bbs, bb);
}
......@@ -6,61 +6,75 @@
#include <string>
Module::Module() {
void_ty_ = std::make_unique<Type>(Type::VoidTyID, this);
label_ty_ = std::make_unique<Type>(Type::LabelTyID, this);
int1_ty_ = std::make_unique<IntegerType>(1, this);
int32_ty_ = std::make_unique<IntegerType>(32, this);
float32_ty_ = std::make_unique<FloatType>(this);
void_ty_ = new Type(Type::VoidTyID, this);
label_ty_ = new Type(Type::LabelTyID, this);
int1_ty_ = new IntegerType(1, this);
int32_ty_ = new IntegerType(32, this);
float32_ty_ = new FloatType(this);
}
Type *Module::get_void_type() { return void_ty_.get(); }
Type *Module::get_label_type() { return label_ty_.get(); }
IntegerType *Module::get_int1_type() { return int1_ty_.get(); }
IntegerType *Module::get_int32_type() { return int32_ty_.get(); }
FloatType *Module::get_float_type() { return float32_ty_.get(); }
Module::~Module()
{
delete void_ty_;
delete label_ty_;
delete int1_ty_;
delete int32_ty_;
delete float32_ty_;
for (auto& i : pointer_map_) delete i.second;
for (auto& i : array_map_) delete i.second;
for (auto& i : function_map_) delete i.second;
for (auto i : function_list_) delete i;
for (auto i : global_list_) delete i;
}
Type *Module::get_void_type() const { return void_ty_; }
Type *Module::get_label_type() const { return label_ty_; }
IntegerType *Module::get_int1_type() const { return int1_ty_; }
IntegerType *Module::get_int32_type() const { return int32_ty_; }
FloatType *Module::get_float_type() const { return float32_ty_; }
PointerType *Module::get_int32_ptr_type() {
return get_pointer_type(int32_ty_.get());
return get_pointer_type(int32_ty_);
}
PointerType *Module::get_float_ptr_type() {
return get_pointer_type(float32_ty_.get());
return get_pointer_type(float32_ty_);
}
PointerType *Module::get_pointer_type(Type *contained) {
if (pointer_map_.find(contained) == pointer_map_.end()) {
pointer_map_[contained] = std::make_unique<PointerType>(contained);
pointer_map_[contained] = new PointerType(contained);
}
return pointer_map_[contained].get();
return pointer_map_[contained];
}
ArrayType *Module::get_array_type(Type *contained, unsigned num_elements) {
if (array_map_.find({contained, num_elements}) == array_map_.end()) {
array_map_[{contained, num_elements}] =
std::make_unique<ArrayType>(contained, num_elements);
new ArrayType(contained, num_elements);
}
return array_map_[{contained, num_elements}].get();
return array_map_[{contained, num_elements}];
}
FunctionType *Module::get_function_type(Type *retty,
std::vector<Type *> &args) {
if (not function_map_.count({retty, args})) {
function_map_[{retty, args}] =
std::make_unique<FunctionType>(retty, args);
new FunctionType(retty, args);
}
return function_map_[{retty, args}].get();
return function_map_[{retty, args}];
}
void Module::add_function(Function *f) { function_list_.push_back(f); }
llvm::ilist<Function> &Module::get_functions() { return function_list_; }
std::list<Function*> &Module::get_functions() { return function_list_; }
void Module::add_global_variable(GlobalVariable *g) {
global_list_.push_back(g);
}
llvm::ilist<GlobalVariable> &Module::get_global_variable() {
std::list<GlobalVariable*> &Module::get_global_variable() {
return global_list_;
}
void Module::set_print_name() {
for (auto &func : this->get_functions()) {
func.set_instr_name();
for (auto func : this->get_functions()) {
func->set_instr_name();
}
return;
}
......@@ -68,12 +82,12 @@ void Module::set_print_name() {
std::string Module::print() {
set_print_name();
std::string module_ir;
for (auto &global_val : this->global_list_) {
module_ir += global_val.print();
for (auto global_val : this->global_list_) {
module_ir += global_val->print();
module_ir += "\n";
}
for (auto &func : this->function_list_) {
module_ir += func.print();
for (auto func : this->function_list_) {
module_ir += func->print();
module_ir += "\n";
}
return module_ir;
......
......@@ -3,31 +3,32 @@
#include <array>
#include <cassert>
#include <stdexcept>
Type::Type(TypeID tid, Module *m) {
tid_ = tid;
m_ = m;
}
Type::~Type() = default;
bool Type::is_int1_type() const {
return is_integer_type() and
static_cast<const IntegerType *>(this)->get_num_bits() == 1;
dynamic_cast<const IntegerType *>(this)->get_num_bits() == 1;
}
bool Type::is_int32_type() const {
return is_integer_type() and
static_cast<const IntegerType *>(this)->get_num_bits() == 32;
dynamic_cast<const IntegerType *>(this)->get_num_bits() == 32;
}
Type *Type::get_pointer_element_type() const {
if (this->is_pointer_type())
return static_cast<const PointerType *>(this)->get_element_type();
return dynamic_cast<const PointerType *>(this)->get_element_type();
assert(false and "get_pointer_element_type() called on non-pointer type");
}
Type *Type::get_array_element_type() const {
if (this->is_array_type())
return static_cast<const ArrayType *>(this)->get_element_type();
return dynamic_cast<const ArrayType *>(this)->get_element_type();
assert(false and "get_array_element_type() called on non-array type");
}
......@@ -42,7 +43,7 @@ unsigned Type::get_size() const {
assert(false && "Type::get_size(): unexpected int type bits");
}
case ArrayTyID: {
auto array_type = static_cast<const ArrayType *>(this);
auto array_type = dynamic_cast<const ArrayType *>(this);
auto element_size = array_type->get_element_type()->get_size();
auto num_elements = array_type->get_num_of_elements();
return element_size * num_elements;
......@@ -71,20 +72,20 @@ std::string Type::print() const {
case IntegerTyID:
type_ir += "i";
type_ir += std::to_string(
static_cast<const IntegerType *>(this)->get_num_bits());
dynamic_cast<const IntegerType *>(this)->get_num_bits());
break;
case FunctionTyID:
type_ir +=
static_cast<const FunctionType *>(this)->get_return_type()->print();
dynamic_cast<const FunctionType *>(this)->get_return_type()->print();
type_ir += " (";
for (unsigned i = 0;
i < static_cast<const FunctionType *>(this)->get_num_of_args();
i < dynamic_cast<const FunctionType *>(this)->get_num_of_args();
i++) {
if (i)
type_ir += ", ";
type_ir += static_cast<const FunctionType *>(this)
->get_param_type(i)
->print();
type_ir += dynamic_cast<const FunctionType *>(this)
->get_param_type(i)
->print();
}
type_ir += ")";
break;
......@@ -95,10 +96,10 @@ std::string Type::print() const {
case ArrayTyID:
type_ir += "[";
type_ir += std::to_string(
static_cast<const ArrayType *>(this)->get_num_of_elements());
dynamic_cast<const ArrayType *>(this)->get_num_of_elements());
type_ir += " x ";
type_ir +=
static_cast<const ArrayType *>(this)->get_element_type()->print();
dynamic_cast<const ArrayType *>(this)->get_element_type()->print();
type_ir += "]";
break;
case FloatTyID:
......@@ -113,9 +114,11 @@ std::string Type::print() const {
IntegerType::IntegerType(unsigned num_bits, Module *m)
: Type(Type::IntegerTyID, m), num_bits_(num_bits) {}
IntegerType::~IntegerType() = default;
unsigned IntegerType::get_num_bits() const { return num_bits_; }
FunctionType::FunctionType(Type *result, std::vector<Type *> params)
FunctionType::FunctionType(Type *result, const std::vector<Type *>& params)
: Type(Type::FunctionTyID, nullptr) {
assert(is_valid_return_type(result) && "Invalid return type for function!");
result_ = result;
......@@ -127,11 +130,13 @@ FunctionType::FunctionType(Type *result, std::vector<Type *> params)
}
}
bool FunctionType::is_valid_return_type(Type *ty) {
FunctionType::~FunctionType() = default;
bool FunctionType::is_valid_return_type(const Type *ty) {
return ty->is_integer_type() || ty->is_void_type() || ty->is_float_type();
}
bool FunctionType::is_valid_argument_type(Type *ty) {
bool FunctionType::is_valid_argument_type(const Type *ty) {
return ty->is_integer_type() || ty->is_pointer_type() ||
ty->is_float_type();
}
......@@ -140,7 +145,7 @@ FunctionType *FunctionType::get(Type *result, std::vector<Type *> params) {
return result->get_module()->get_function_type(result, params);
}
unsigned FunctionType::get_num_of_args() const { return args_.size(); }
unsigned FunctionType::get_num_of_args() const { return static_cast<unsigned>(args_.size()); }
Type *FunctionType::get_param_type(unsigned i) const { return args_[i]; }
......@@ -154,7 +159,9 @@ ArrayType::ArrayType(Type *contained, unsigned num_elements)
contained_ = contained;
}
bool ArrayType::is_valid_element_type(Type *ty) {
ArrayType::~ArrayType() = default;
bool ArrayType::is_valid_element_type(const Type *ty) {
return ty->is_integer_type() || ty->is_array_type() || ty->is_float_type();
}
......@@ -164,7 +171,7 @@ ArrayType *ArrayType::get(Type *contained, unsigned num_elements) {
PointerType::PointerType(Type *contained)
: Type(Type::PointerTyID, contained->get_module()), contained_(contained) {
static const std::array allowed_elem_type = {
static constexpr std::array allowed_elem_type = {
Type::IntegerTyID, Type::FloatTyID, Type::ArrayTyID, Type::PointerTyID};
auto elem_type_id = contained->get_type_id();
assert(std::find(allowed_elem_type.begin(), allowed_elem_type.end(),
......@@ -172,10 +179,14 @@ PointerType::PointerType(Type *contained)
"Not allowed type for pointer");
}
PointerType::~PointerType() = default;
PointerType *PointerType::get(Type *contained) {
return contained->get_module()->get_pointer_type(contained);
}
FloatType::FloatType(Module *m) : Type(Type::FloatTyID, m) {}
FloatType::~FloatType() = default;
FloatType *FloatType::get(Module *m) { return m->get_float_type(); }
......@@ -2,6 +2,8 @@
#include <cassert>
User::~User() { remove_all_operands(); }
void User::set_operand(unsigned i, Value *v) {
assert(i < operands_.size() && "set_operand out of index");
if (operands_[i]) { // old operand
......@@ -15,7 +17,7 @@ void User::set_operand(unsigned i, Value *v) {
void User::add_operand(Value *v) {
assert(v != nullptr && "bad use: add_operand(nullptr)");
v->add_use(this, operands_.size());
v->add_use(this, static_cast<unsigned>(operands_.size()));
operands_.push_back(v);
}
......
#include "Value.hpp"
#include "Type.hpp"
#include "User.hpp"
#include <cassert>
bool Value::set_name(std::string name) {
if (name_ == "") {
Value::~Value() { replace_all_use_with(nullptr); }
bool Value::set_name(const std::string& name) {
if (name_.empty()) {
name_ = name;
return true;
}
......@@ -21,21 +21,22 @@ void Value::remove_use(User *user, unsigned arg_no) {
use_list_.remove_if([&](const Use &use) { return use == target_use; });
}
void Value::replace_all_use_with(Value *new_val) {
void Value::replace_all_use_with(Value *new_val) const
{
if (this == new_val)
return;
while (use_list_.size()) {
while (!use_list_.empty()) {
auto use = use_list_.begin();
use->val_->set_operand(use->arg_no_, new_val);
}
}
void Value::replace_use_with_if(Value *new_val,
std::function<bool(Use *)> should_replace) {
const std::function<bool(Use *)>& should_replace) {
if (this == new_val)
return;
for (auto iter = use_list_.begin(); iter != use_list_.end();) {
auto use = *iter++;
auto &use = *iter++;
if (not should_replace(&use))
continue;
use.val_->set_operand(use.arg_no_, new_val);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment