Commit 04681c6c authored by Yang's avatar Yang

remove llvm

parent 606023f0
build/ build/
.cache/ .cache/
.vs
out
CMakePresets.json
Folder.DotSettings.user
\ No newline at end of file
{
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Compile & Debug 1.ll",
// 要调试的程序
"program": "${workspaceFolder}/build/cminusfc",
// 命令行参数
"args": [
"-o",
"./build/1.ll",
"-emit-llvm",
"./build/1.cminus"
],
// 程序运行的目录
"cwd": "${workspaceFolder}",
// 程序运行前运行的命令(例如 build)
"preLaunchTask": "make cminusfc"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug 1.ll",
// 要调试的程序
"program": "${workspaceFolder}/build/cminusfc",
// 命令行参数
"args": [
"-o",
"./build/1.ll",
"-emit-llvm",
"./build/1.cminus"
],
// 程序运行的目录
"cwd": "${workspaceFolder}"
}
]
}
\ No newline at end of file
{
"version": "2.0.0",
"tasks": [
{
"type": "shell",
"label": "make cminusfc",
"command": "cd build && make"
}
]
}
\ No newline at end of file
...@@ -29,14 +29,6 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) ...@@ -29,14 +29,6 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
find_package(FLEX REQUIRED) find_package(FLEX REQUIRED)
find_package(BISON REQUIRED) find_package(BISON REQUIRED)
find_package(LLVM REQUIRED CONFIG)
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
llvm_map_components_to_libnames(
llvm_libs
support
core
)
INCLUDE_DIRECTORIES( INCLUDE_DIRECTORIES(
include include
...@@ -44,11 +36,8 @@ INCLUDE_DIRECTORIES( ...@@ -44,11 +36,8 @@ INCLUDE_DIRECTORIES(
include/common include/common
include/lightir include/lightir
include/codegen include/codegen
${LLVM_INCLUDE_DIRS}
) )
add_definitions(${LLVM_DEFINITIONS})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR})
......
#pragma once #pragma once
#include "BasicBlock.hpp" #include "BasicBlock.hpp"
#include "Constant.hpp"
#include "Function.hpp" #include "Function.hpp"
#include "IRBuilder.hpp" #include "IRBuilder.hpp"
#include "Module.hpp" #include "Module.hpp"
...@@ -19,7 +18,7 @@ class Scope { ...@@ -19,7 +18,7 @@ class Scope {
// exit a scope // exit a scope
void exit() { inner.pop_back(); } void exit() { inner.pop_back(); }
bool in_global() { return inner.size() == 1; } bool in_global() const { return inner.size() == 1; }
// push a name to scope // push a name to scope
// return true if successful // return true if successful
...@@ -30,7 +29,7 @@ class Scope { ...@@ -30,7 +29,7 @@ class Scope {
} }
Value *find(const std::string& name) { Value *find(const std::string& name) {
for (auto s = inner.rbegin(); s != inner.rend(); s++) { for (auto s = inner.rbegin(); s != inner.rend(); ++s) {
auto iter = s->find(name); auto iter = s->find(name);
if (iter != s->end()) { if (iter != s->end()) {
return iter->second; return iter->second;
...@@ -39,8 +38,6 @@ class Scope { ...@@ -39,8 +38,6 @@ class Scope {
// Name not found: handled here? // Name not found: handled here?
assert(false && "Name not found in scope"); assert(false && "Name not found in scope");
return nullptr;
} }
private: private:
...@@ -50,29 +47,29 @@ class Scope { ...@@ -50,29 +47,29 @@ class Scope {
class CminusfBuilder : public ASTVisitor { class CminusfBuilder : public ASTVisitor {
public: public:
CminusfBuilder() { CminusfBuilder() {
module = std::make_unique<Module>(); module = new Module();
builder = std::make_unique<IRBuilder>(nullptr, module.get()); builder = new IRBuilder(nullptr, module);
auto *TyVoid = module->get_void_type(); auto *TyVoid = module->get_void_type();
auto *TyInt32 = module->get_int32_type(); auto *TyInt32 = module->get_int32_type();
auto *TyFloat = module->get_float_type(); auto *TyFloat = module->get_float_type();
auto *input_type = FunctionType::get(TyInt32, {}); auto *input_type = FunctionType::get(TyInt32, {});
auto *input_fun = Function::create(input_type, "input", module.get()); auto *input_fun = Function::create(input_type, "input", module);
std::vector<Type *> output_params; std::vector<Type *> output_params;
output_params.push_back(TyInt32); output_params.push_back(TyInt32);
auto *output_type = FunctionType::get(TyVoid, output_params); auto *output_type = FunctionType::get(TyVoid, output_params);
auto *output_fun = Function::create(output_type, "output", module.get()); auto *output_fun = Function::create(output_type, "output", module);
std::vector<Type *> output_float_params; std::vector<Type *> output_float_params;
output_float_params.push_back(TyFloat); output_float_params.push_back(TyFloat);
auto *output_float_type = FunctionType::get(TyVoid, output_float_params); auto *output_float_type = FunctionType::get(TyVoid, output_float_params);
auto *output_float_fun = auto *output_float_fun =
Function::create(output_float_type, "outputFloat", module.get()); Function::create(output_float_type, "outputFloat", module);
auto *neg_idx_except_type = FunctionType::get(TyVoid, {}); auto *neg_idx_except_type = FunctionType::get(TyVoid, {});
auto *neg_idx_except_fun = Function::create( auto *neg_idx_except_fun = Function::create(
neg_idx_except_type, "neg_idx_except", module.get()); neg_idx_except_type, "neg_idx_except", module);
scope.enter(); scope.enter();
scope.push("input", input_fun); scope.push("input", input_fun);
...@@ -81,29 +78,31 @@ class CminusfBuilder : public ASTVisitor { ...@@ -81,29 +78,31 @@ class CminusfBuilder : public ASTVisitor {
scope.push("neg_idx_except", neg_idx_except_fun); scope.push("neg_idx_except", neg_idx_except_fun);
} }
std::unique_ptr<Module> getModule() { return std::move(module); } Module* getModule() const { return module; }
~CminusfBuilder() override { delete builder; }
private: private:
virtual Value *visit(ASTProgram &) override final; Value *visit(ASTProgram &) final;
virtual Value *visit(ASTNum &) override final; Value *visit(ASTNum &) final;
virtual Value *visit(ASTVarDeclaration &) override final; Value *visit(ASTVarDeclaration &) final;
virtual Value *visit(ASTFunDeclaration &) override final; Value *visit(ASTFunDeclaration &) final;
virtual Value *visit(ASTParam &) override final; Value *visit(ASTParam &) final;
virtual Value *visit(ASTCompoundStmt &) override final; Value *visit(ASTCompoundStmt &) final;
virtual Value *visit(ASTExpressionStmt &) override final; Value *visit(ASTExpressionStmt &) final;
virtual Value *visit(ASTSelectionStmt &) override final; Value *visit(ASTSelectionStmt &) final;
virtual Value *visit(ASTIterationStmt &) override final; Value *visit(ASTIterationStmt &) final;
virtual Value *visit(ASTReturnStmt &) override final; Value *visit(ASTReturnStmt &) final;
virtual Value *visit(ASTAssignExpression &) override final; Value *visit(ASTAssignExpression &) final;
virtual Value *visit(ASTSimpleExpression &) override final; Value *visit(ASTSimpleExpression &) final;
virtual Value *visit(ASTAdditiveExpression &) override final; Value *visit(ASTAdditiveExpression &) final;
virtual Value *visit(ASTVar &) override final; Value *visit(ASTVar &) final;
virtual Value *visit(ASTTerm &) override final; Value *visit(ASTTerm &) final;
virtual Value *visit(ASTCall &) override final; Value *visit(ASTCall &) final;
std::unique_ptr<IRBuilder> builder; IRBuilder* builder;
Scope scope; Scope scope;
std::unique_ptr<Module> module; Module* module;
struct { struct {
// whether require lvalue // whether require lvalue
......
...@@ -90,23 +90,23 @@ struct ASTNode { ...@@ -90,23 +90,23 @@ struct ASTNode {
}; };
struct ASTProgram : ASTNode { struct ASTProgram : ASTNode {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
virtual ~ASTProgram() = default; ~ASTProgram() override = default;
std::vector<std::shared_ptr<ASTDeclaration>> declarations; std::vector<std::shared_ptr<ASTDeclaration>> declarations;
}; };
struct ASTDeclaration : ASTNode { struct ASTDeclaration : ASTNode {
virtual ~ASTDeclaration() = default; ~ASTDeclaration() override = default;
CminusType type; CminusType type;
std::string id; std::string id;
}; };
struct ASTFactor : ASTNode { struct ASTFactor : ASTNode {
virtual ~ASTFactor() = default; ~ASTFactor() override = default;
}; };
struct ASTNum : ASTFactor { struct ASTNum : ASTFactor {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
CminusType type; CminusType type;
union { union {
int i_val; int i_val;
...@@ -115,18 +115,18 @@ struct ASTNum : ASTFactor { ...@@ -115,18 +115,18 @@ struct ASTNum : ASTFactor {
}; };
struct ASTVarDeclaration : ASTDeclaration { struct ASTVarDeclaration : ASTDeclaration {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTNum> num; std::shared_ptr<ASTNum> num;
}; };
struct ASTFunDeclaration : ASTDeclaration { struct ASTFunDeclaration : ASTDeclaration {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::vector<std::shared_ptr<ASTParam>> params; std::vector<std::shared_ptr<ASTParam>> params;
std::shared_ptr<ASTCompoundStmt> compound_stmt; std::shared_ptr<ASTCompoundStmt> compound_stmt;
}; };
struct ASTParam : ASTNode { struct ASTParam : ASTNode {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
CminusType type; CminusType type;
std::string id; std::string id;
// true if it is array param // true if it is array param
...@@ -134,22 +134,22 @@ struct ASTParam : ASTNode { ...@@ -134,22 +134,22 @@ struct ASTParam : ASTNode {
}; };
struct ASTStatement : ASTNode { struct ASTStatement : ASTNode {
virtual ~ASTStatement() = default; ~ASTStatement() override = default;
}; };
struct ASTCompoundStmt : ASTStatement { struct ASTCompoundStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::vector<std::shared_ptr<ASTVarDeclaration>> local_declarations; std::vector<std::shared_ptr<ASTVarDeclaration>> local_declarations;
std::vector<std::shared_ptr<ASTStatement>> statement_list; std::vector<std::shared_ptr<ASTStatement>> statement_list;
}; };
struct ASTExpressionStmt : ASTStatement { struct ASTExpressionStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTExpression> expression; std::shared_ptr<ASTExpression> expression;
}; };
struct ASTSelectionStmt : ASTStatement { struct ASTSelectionStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTExpression> expression; std::shared_ptr<ASTExpression> expression;
std::shared_ptr<ASTStatement> if_statement; std::shared_ptr<ASTStatement> if_statement;
// should be nullptr if no else structure exists // should be nullptr if no else structure exists
...@@ -157,13 +157,13 @@ struct ASTSelectionStmt : ASTStatement { ...@@ -157,13 +157,13 @@ struct ASTSelectionStmt : ASTStatement {
}; };
struct ASTIterationStmt : ASTStatement { struct ASTIterationStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTExpression> expression; std::shared_ptr<ASTExpression> expression;
std::shared_ptr<ASTStatement> statement; std::shared_ptr<ASTStatement> statement;
}; };
struct ASTReturnStmt : ASTStatement { struct ASTReturnStmt : ASTStatement {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
// should be nullptr if return void // should be nullptr if return void
std::shared_ptr<ASTExpression> expression; std::shared_ptr<ASTExpression> expression;
}; };
...@@ -171,41 +171,41 @@ struct ASTReturnStmt : ASTStatement { ...@@ -171,41 +171,41 @@ struct ASTReturnStmt : ASTStatement {
struct ASTExpression : ASTFactor {}; struct ASTExpression : ASTFactor {};
struct ASTAssignExpression : ASTExpression { struct ASTAssignExpression : ASTExpression {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTVar> var; std::shared_ptr<ASTVar> var;
std::shared_ptr<ASTExpression> expression; std::shared_ptr<ASTExpression> expression;
}; };
struct ASTSimpleExpression : ASTExpression { struct ASTSimpleExpression : ASTExpression {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTAdditiveExpression> additive_expression_l; std::shared_ptr<ASTAdditiveExpression> additive_expression_l;
std::shared_ptr<ASTAdditiveExpression> additive_expression_r; std::shared_ptr<ASTAdditiveExpression> additive_expression_r;
RelOp op; RelOp op;
}; };
struct ASTVar : ASTFactor { struct ASTVar : ASTFactor {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::string id; std::string id;
// nullptr if var is of int type // nullptr if var is of int type
std::shared_ptr<ASTExpression> expression; std::shared_ptr<ASTExpression> expression;
}; };
struct ASTAdditiveExpression : ASTNode { struct ASTAdditiveExpression : ASTNode {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTAdditiveExpression> additive_expression; std::shared_ptr<ASTAdditiveExpression> additive_expression;
AddOp op; AddOp op;
std::shared_ptr<ASTTerm> term; std::shared_ptr<ASTTerm> term;
}; };
struct ASTTerm : ASTNode { struct ASTTerm : ASTNode {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::shared_ptr<ASTTerm> term; std::shared_ptr<ASTTerm> term;
MulOp op; MulOp op;
std::shared_ptr<ASTFactor> factor; std::shared_ptr<ASTFactor> factor;
}; };
struct ASTCall : ASTFactor { struct ASTCall : ASTFactor {
virtual Value* accept(ASTVisitor &) override final; Value* accept(ASTVisitor &) final;
std::string id; std::string id;
std::vector<std::shared_ptr<ASTExpression>> args; std::vector<std::shared_ptr<ASTExpression>> args;
}; };
...@@ -228,26 +228,27 @@ class ASTVisitor { ...@@ -228,26 +228,27 @@ class ASTVisitor {
virtual Value* visit(ASTVar &) = 0; virtual Value* visit(ASTVar &) = 0;
virtual Value* visit(ASTTerm &) = 0; virtual Value* visit(ASTTerm &) = 0;
virtual Value* visit(ASTCall &) = 0; virtual Value* visit(ASTCall &) = 0;
virtual ~ASTVisitor();
}; };
class ASTPrinter : public ASTVisitor { class ASTPrinter : public ASTVisitor {
public: public:
virtual Value* visit(ASTProgram &) override final; Value* visit(ASTProgram &) final;
virtual Value* visit(ASTNum &) override final; Value* visit(ASTNum &) final;
virtual Value* visit(ASTVarDeclaration &) override final; Value* visit(ASTVarDeclaration &) final;
virtual Value* visit(ASTFunDeclaration &) override final; Value* visit(ASTFunDeclaration &) final;
virtual Value* visit(ASTParam &) override final; Value* visit(ASTParam &) final;
virtual Value* visit(ASTCompoundStmt &) override final; Value* visit(ASTCompoundStmt &) final;
virtual Value* visit(ASTExpressionStmt &) override final; Value* visit(ASTExpressionStmt &) final;
virtual Value* visit(ASTSelectionStmt &) override final; Value* visit(ASTSelectionStmt &) final;
virtual Value* visit(ASTIterationStmt &) override final; Value* visit(ASTIterationStmt &) final;
virtual Value* visit(ASTReturnStmt &) override final; Value* visit(ASTReturnStmt &) final;
virtual Value* visit(ASTAssignExpression &) override final; Value* visit(ASTAssignExpression &) final;
virtual Value* visit(ASTSimpleExpression &) override final; Value* visit(ASTSimpleExpression &) final;
virtual Value* visit(ASTAdditiveExpression &) override final; Value* visit(ASTAdditiveExpression &) final;
virtual Value* visit(ASTVar &) override final; Value* visit(ASTVar &) final;
virtual Value* visit(ASTTerm &) override final; Value* visit(ASTTerm &) final;
virtual Value* visit(ASTCall &) override final; Value* visit(ASTCall &) final;
void add_depth() { depth += 2; } void add_depth() { depth += 2; }
void remove_depth() { depth -= 2; } void remove_depth() { depth -= 2; }
......
...@@ -4,9 +4,6 @@ ...@@ -4,9 +4,6 @@
#include "Value.hpp" #include "Value.hpp"
#include <list> #include <list>
#include <llvm/ADT/ilist.h>
#include <llvm/ADT/ilist_node.h>
#include <set>
#include <string> #include <string>
class Function; class Function;
...@@ -15,7 +12,11 @@ class Module; ...@@ -15,7 +12,11 @@ class Module;
class BasicBlock : public Value { class BasicBlock : public Value {
public: public:
~BasicBlock() = default; BasicBlock(const BasicBlock& other) = delete;
BasicBlock(BasicBlock&& other) noexcept = delete;
BasicBlock& operator=(const BasicBlock& other) = delete;
BasicBlock& operator=(BasicBlock&& other) noexcept = delete;
~BasicBlock() override;
static BasicBlock *create(Module *m, const std::string &name, static BasicBlock *create(Module *m, const std::string &name,
Function *parent) { Function *parent) {
auto prefix = name.empty() ? "" : "label_"; auto prefix = name.empty() ? "" : "label_";
...@@ -30,39 +31,48 @@ class BasicBlock : public Value { ...@@ -30,39 +31,48 @@ class BasicBlock : public Value {
void add_succ_basic_block(BasicBlock *bb) { succ_bbs_.push_back(bb); } void add_succ_basic_block(BasicBlock *bb) { succ_bbs_.push_back(bb); }
void remove_pre_basic_block(BasicBlock *bb) { pre_bbs_.remove(bb); } void remove_pre_basic_block(BasicBlock *bb) { pre_bbs_.remove(bb); }
void remove_succ_basic_block(BasicBlock *bb) { succ_bbs_.remove(bb); } void remove_succ_basic_block(BasicBlock *bb) { succ_bbs_.remove(bb); }
BasicBlock* get_entry_block_of_same_function(); BasicBlock* get_entry_block_of_same_function() const;
// If the Block is terminated by ret/br // If the Block is terminated by ret/br
bool is_terminated() const; bool is_terminated() const;
// Get terminator, only accept valid case use // Get terminator, only accept valid case use
Instruction *get_terminator(); Instruction *get_terminator() const;
/****************api about Instruction****************/ /****************api about Instruction****************/
// 在指令表最后插入指令
// 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面。
// 因此,即使终止基本块也能插入 alloca 和 phi
void add_instruction(Instruction *instr); void add_instruction(Instruction *instr);
void add_instr_begin(Instruction *instr) { instr_list_.push_front(instr); } // 在指令链表最前面插入指令
void add_instr_before_end(Instruction *instr) { // 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面。
instr_list_.insert(std::prev(instr_list_.end()), instr); void add_instr_begin(Instruction* instr);
} // 绕过终止指令插入指令
void erase_instr(Instruction *instr) { instr_list_.erase(instr); } // 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面。
void add_instr_before_terminator(Instruction* instr);
// 从 BasicBlock 移除 Instruction,并 delete 这个 Instruction
void erase_instr(Instruction* instr) { instr_list_.remove(instr); delete instr; }
// 从 BasicBlock 移除 Instruction,你需要自己 delete 它
void remove_instr(Instruction *instr) { instr_list_.remove(instr); } void remove_instr(Instruction *instr) { instr_list_.remove(instr); }
llvm::ilist<Instruction> &get_instructions() { return instr_list_; } // 移除的 Instruction 需要自己 delete
std::list<Instruction*> &get_instructions() { return instr_list_; }
bool empty() const { return instr_list_.empty(); } bool empty() const { return instr_list_.empty(); }
int get_num_of_instr() const { return instr_list_.size(); } int get_num_of_instr() const { return static_cast<int>(instr_list_.size()); }
/****************api about accessing parent****************/ /****************api about accessing parent****************/
Function *get_parent() { return parent_; } Function *get_parent() const { return parent_; }
Module *get_module(); Module *get_module() const;
void erase_from_parent(); void erase_from_parent();
virtual std::string print() override; std::string print() override;
private: private:
BasicBlock(const BasicBlock &) = delete;
explicit BasicBlock(Module *m, const std::string &name, Function *parent); explicit BasicBlock(Module *m, const std::string &name, Function *parent);
std::list<BasicBlock *> pre_bbs_; std::list<BasicBlock *> pre_bbs_;
std::list<BasicBlock *> succ_bbs_; std::list<BasicBlock *> succ_bbs_;
llvm::ilist<Instruction> instr_list_; std::list<Instruction*> instr_list_;
Function *parent_; Function *parent_;
}; };
...@@ -5,11 +5,16 @@ ...@@ -5,11 +5,16 @@
#include "Value.hpp" #include "Value.hpp"
class Constant : public User { class Constant : public User {
private: private:
// int value; // int value;
public: public:
Constant(const Constant& other) = delete;
Constant(Constant&& other) noexcept = delete;
Constant& operator=(const Constant& other) = delete;
Constant& operator=(Constant&& other) noexcept = delete;
Constant(Type *ty, const std::string &name = "") : User(ty, name) {} Constant(Type *ty, const std::string &name = "") : User(ty, name) {}
~Constant() = default; ~Constant() override;
}; };
class ConstantInt : public Constant { class ConstantInt : public Constant {
...@@ -18,10 +23,10 @@ class ConstantInt : public Constant { ...@@ -18,10 +23,10 @@ class ConstantInt : public Constant {
ConstantInt(Type *ty, int val) : Constant(ty, ""), value_(val) {} ConstantInt(Type *ty, int val) : Constant(ty, ""), value_(val) {}
public: public:
int get_value() { return value_; } int get_value() const { return value_; }
static ConstantInt *get(int val, Module *m); static ConstantInt *get(int val, Module *m);
static ConstantInt *get(bool val, Module *m); static ConstantInt *get(bool val, Module *m);
virtual std::string print() override; std::string print() override;
}; };
class ConstantArray : public Constant { class ConstantArray : public Constant {
...@@ -31,16 +36,21 @@ class ConstantArray : public Constant { ...@@ -31,16 +36,21 @@ class ConstantArray : public Constant {
ConstantArray(ArrayType *ty, const std::vector<Constant *> &val); ConstantArray(ArrayType *ty, const std::vector<Constant *> &val);
public: public:
~ConstantArray() = default; ConstantArray(const ConstantArray& other) = delete;
ConstantArray(ConstantArray&& other) noexcept = delete;
ConstantArray& operator=(const ConstantArray& other) = delete;
ConstantArray& operator=(ConstantArray&& other) noexcept = delete;
~ConstantArray() override = default;
Constant *get_element_value(int index); Constant *get_element_value(int index) const;
unsigned get_size_of_array() { return const_array.size(); } int get_size_of_array() const { return static_cast<int>(const_array.size()); }
static ConstantArray *get(ArrayType *ty, static ConstantArray *get(ArrayType *ty,
const std::vector<Constant *> &val); const std::vector<Constant *> &val);
virtual std::string print() override; std::string print() override;
}; };
class ConstantZero : public Constant { class ConstantZero : public Constant {
...@@ -49,7 +59,7 @@ class ConstantZero : public Constant { ...@@ -49,7 +59,7 @@ class ConstantZero : public Constant {
public: public:
static ConstantZero *get(Type *ty, Module *m); static ConstantZero *get(Type *ty, Module *m);
virtual std::string print() override; std::string print() override;
}; };
class ConstantFP : public Constant { class ConstantFP : public Constant {
...@@ -59,6 +69,6 @@ class ConstantFP : public Constant { ...@@ -59,6 +69,6 @@ class ConstantFP : public Constant {
public: public:
static ConstantFP *get(float val, Module *m); static ConstantFP *get(float val, Module *m);
float get_value() { return val_; } float get_value() const { return val_; }
virtual std::string print() override; std::string print() override;
}; };
...@@ -2,15 +2,9 @@ ...@@ -2,15 +2,9 @@
#include "BasicBlock.hpp" #include "BasicBlock.hpp"
#include "Type.hpp" #include "Type.hpp"
#include "User.hpp"
#include <cassert> #include <cassert>
#include <cstddef>
#include <iterator>
#include <list> #include <list>
#include <llvm/ADT/ilist.h>
#include <llvm/ADT/ilist_node.h>
#include <map>
#include <memory> #include <memory>
class Module; class Module;
...@@ -18,11 +12,14 @@ class Argument; ...@@ -18,11 +12,14 @@ class Argument;
class Type; class Type;
class FunctionType; class FunctionType;
class Function : public Value, public llvm::ilist_node<Function> { class Function : public Value {
public: public:
Function(const Function &) = delete; Function(const Function& other) = delete;
Function(Function&& other) noexcept = delete;
Function& operator=(const Function& other) = delete;
Function& operator=(Function&& other) noexcept = delete;
Function(FunctionType *ty, const std::string &name, Module *parent); Function(FunctionType *ty, const std::string &name, Module *parent);
~Function() = default; ~Function() override;
static Function *create(FunctionType *ty, const std::string &name, static Function *create(FunctionType *ty, const std::string &name,
Module *parent); Module *parent);
...@@ -38,17 +35,17 @@ class Function : public Value, public llvm::ilist_node<Function> { ...@@ -38,17 +35,17 @@ class Function : public Value, public llvm::ilist_node<Function> {
// 此处 remove 的 BasicBlock, 需要手动 delete // 此处 remove 的 BasicBlock, 需要手动 delete
void remove(BasicBlock *bb); void remove(BasicBlock *bb);
BasicBlock *get_entry_block() { return basic_blocks_.front(); } BasicBlock *get_entry_block() const { return basic_blocks_.front(); }
std::list<BasicBlock*> &get_basic_blocks() { return basic_blocks_; } std::list<BasicBlock*> &get_basic_blocks() { return basic_blocks_; }
std::list<Argument> &get_args() { return arguments_; } std::list<Argument> &get_args() { return arguments_; }
bool is_declaration() { return basic_blocks_.empty(); } bool is_declaration() const { return basic_blocks_.empty(); }
void set_instr_name(); void set_instr_name();
std::string print(); std::string print() override;
// 用于检查函数的基本块是否存在问题 // 用于检查函数的基本块是否存在问题
void check_for_block_relation_error(); void check_for_block_relation_error() const;
private: private:
std::list<BasicBlock*> basic_blocks_; std::list<BasicBlock*> basic_blocks_;
...@@ -60,14 +57,17 @@ class Function : public Value, public llvm::ilist_node<Function> { ...@@ -60,14 +57,17 @@ class Function : public Value, public llvm::ilist_node<Function> {
// Argument of Function, does not contain actual value // Argument of Function, does not contain actual value
class Argument : public Value { class Argument : public Value {
public: public:
Argument(const Argument &) = delete; Argument(const Argument& other) = delete;
Argument(Argument&& other) noexcept = delete;
Argument& operator=(const Argument& other) = delete;
Argument& operator=(Argument&& other) noexcept = delete;
explicit Argument(Type *ty, const std::string &name = "", explicit Argument(Type *ty, const std::string &name = "",
Function *f = nullptr, unsigned arg_no = 0) Function *f = nullptr, unsigned arg_no = 0)
: Value(ty, name), parent_(f), arg_no_(arg_no) {} : Value(ty, name), parent_(f), arg_no_(arg_no) {}
virtual ~Argument() {} ~Argument() override = default;
inline const Function *get_parent() const { return parent_; } const Function *get_parent() const { return parent_; }
inline Function *get_parent() { return parent_; } Function *get_parent() { return parent_; }
/// For example in "void foo(int a, float b)" a is 0 and b is 1. /// For example in "void foo(int a, float b)" a is 0 and b is 1.
unsigned get_arg_no() const { unsigned get_arg_no() const {
...@@ -75,7 +75,7 @@ class Argument : public Value { ...@@ -75,7 +75,7 @@ class Argument : public Value {
return arg_no_; return arg_no_;
} }
virtual std::string print() override; std::string print() override;
private: private:
Function *parent_; Function *parent_;
......
...@@ -3,21 +3,24 @@ ...@@ -3,21 +3,24 @@
#include "Constant.hpp" #include "Constant.hpp"
#include "User.hpp" #include "User.hpp"
#include <llvm/ADT/ilist_node.h>
class Module; class Module;
class GlobalVariable : public User, public llvm::ilist_node<GlobalVariable> {
class GlobalVariable : public User {
private: private:
bool is_const_; bool is_const_;
Constant *init_val_; Constant *init_val_;
GlobalVariable(std::string name, Module *m, Type *ty, bool is_const, GlobalVariable(const std::string& name, Module *m, Type *ty, bool is_const,
Constant *init = nullptr); Constant *init = nullptr);
public: public:
GlobalVariable(const GlobalVariable &) = delete; GlobalVariable(const GlobalVariable& other) = delete;
static GlobalVariable *create(std::string name, Module *m, Type *ty, GlobalVariable(GlobalVariable&& other) noexcept = delete;
GlobalVariable& operator=(const GlobalVariable& other) = delete;
GlobalVariable& operator=(GlobalVariable&& other) noexcept = delete;
static GlobalVariable *create(const std::string& name, Module *m, Type *ty,
bool is_const, Constant *init); bool is_const, Constant *init);
virtual ~GlobalVariable() = default; ~GlobalVariable() override = default;
Constant *get_init() { return init_val_; } Constant *get_init() const { return init_val_; }
bool is_const() { return is_const_; } bool is_const() const { return is_const_; }
std::string print(); std::string print() override;
}; };
#pragma once #pragma once
#include "BasicBlock.hpp"
#include "Function.hpp" #include "Function.hpp"
#include "Instruction.hpp" #include "Instruction.hpp"
#include "Value.hpp" #include "Value.hpp"
...@@ -11,124 +10,159 @@ class IRBuilder { ...@@ -11,124 +10,159 @@ class IRBuilder {
Module *m_; Module *m_;
public: public:
IRBuilder(BasicBlock *bb, Module *m) : BB_(bb), m_(m){}; IRBuilder(const IRBuilder& other) = delete;
IRBuilder(IRBuilder&& other) noexcept = delete;
IRBuilder& operator=(const IRBuilder& other) = delete;
IRBuilder& operator=(IRBuilder&& other) noexcept = delete;
IRBuilder(BasicBlock *bb, Module *m) : BB_(bb), m_(m){}
~IRBuilder() = default; ~IRBuilder() = default;
Module *get_module() { return m_; } Module *get_module() const { return m_; }
BasicBlock *get_insert_block() { return this->BB_; } BasicBlock *get_insert_block() const { return this->BB_; }
void set_insert_point(BasicBlock *bb) { void set_insert_point(BasicBlock *bb) {
this->BB_ = bb; this->BB_ = bb;
} // 在某个基本块中插入指令 } // 在某个基本块中插入指令
IBinaryInst *create_iadd(Value *lhs, Value *rhs) { IBinaryInst *create_iadd(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_add(lhs, rhs, this->BB_); return IBinaryInst::create_add(lhs, rhs, this->BB_);
} // 创建加法指令(以及其他算术指令) } // 创建加法指令(以及其他算术指令)
IBinaryInst *create_isub(Value *lhs, Value *rhs) { IBinaryInst *create_isub(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_sub(lhs, rhs, this->BB_); return IBinaryInst::create_sub(lhs, rhs, this->BB_);
} }
IBinaryInst *create_imul(Value *lhs, Value *rhs) { IBinaryInst *create_imul(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_mul(lhs, rhs, this->BB_); return IBinaryInst::create_mul(lhs, rhs, this->BB_);
} }
IBinaryInst *create_isdiv(Value *lhs, Value *rhs) { IBinaryInst *create_isdiv(Value *lhs, Value *rhs) const
{
return IBinaryInst::create_sdiv(lhs, rhs, this->BB_); return IBinaryInst::create_sdiv(lhs, rhs, this->BB_);
} }
ICmpInst *create_icmp_eq(Value *lhs, Value *rhs) { ICmpInst *create_icmp_eq(Value *lhs, Value *rhs) const
{
return ICmpInst::create_eq(lhs, rhs, this->BB_); return ICmpInst::create_eq(lhs, rhs, this->BB_);
} }
ICmpInst *create_icmp_ne(Value *lhs, Value *rhs) { ICmpInst *create_icmp_ne(Value *lhs, Value *rhs) const
{
return ICmpInst::create_ne(lhs, rhs, this->BB_); return ICmpInst::create_ne(lhs, rhs, this->BB_);
} }
ICmpInst *create_icmp_gt(Value *lhs, Value *rhs) { ICmpInst *create_icmp_gt(Value *lhs, Value *rhs) const
{
return ICmpInst::create_gt(lhs, rhs, this->BB_); return ICmpInst::create_gt(lhs, rhs, this->BB_);
} }
ICmpInst *create_icmp_ge(Value *lhs, Value *rhs) { ICmpInst *create_icmp_ge(Value *lhs, Value *rhs) const
{
return ICmpInst::create_ge(lhs, rhs, this->BB_); return ICmpInst::create_ge(lhs, rhs, this->BB_);
} }
ICmpInst *create_icmp_lt(Value *lhs, Value *rhs) { ICmpInst *create_icmp_lt(Value *lhs, Value *rhs) const
{
return ICmpInst::create_lt(lhs, rhs, this->BB_); return ICmpInst::create_lt(lhs, rhs, this->BB_);
} }
ICmpInst *create_icmp_le(Value *lhs, Value *rhs) { ICmpInst *create_icmp_le(Value *lhs, Value *rhs) const
{
return ICmpInst::create_le(lhs, rhs, this->BB_); return ICmpInst::create_le(lhs, rhs, this->BB_);
} }
CallInst *create_call(Value *func, std::vector<Value *> args) { CallInst *create_call(Value *func, const std::vector<Value *>& args) const
return CallInst::create_call(static_cast<Function *>(func), args, {
return CallInst::create_call(dynamic_cast<Function *>(func), args,
this->BB_); this->BB_);
} }
BranchInst *create_br(BasicBlock *if_true) { BranchInst *create_br(BasicBlock *if_true) const
{
return BranchInst::create_br(if_true, this->BB_); return BranchInst::create_br(if_true, this->BB_);
} }
BranchInst *create_cond_br(Value *cond, BasicBlock *if_true, BranchInst *create_cond_br(Value *cond, BasicBlock *if_true,
BasicBlock *if_false) { BasicBlock *if_false) const
{
return BranchInst::create_cond_br(cond, if_true, if_false, this->BB_); return BranchInst::create_cond_br(cond, if_true, if_false, this->BB_);
} }
ReturnInst *create_ret(Value *val) { ReturnInst *create_ret(Value *val) const
{
return ReturnInst::create_ret(val, this->BB_); return ReturnInst::create_ret(val, this->BB_);
} }
ReturnInst *create_void_ret() { ReturnInst *create_void_ret() const
{
return ReturnInst::create_void_ret(this->BB_); return ReturnInst::create_void_ret(this->BB_);
} }
GetElementPtrInst *create_gep(Value *ptr, std::vector<Value *> idxs) { GetElementPtrInst *create_gep(Value *ptr, const std::vector<Value *>& idxs) const
{
return GetElementPtrInst::create_gep(ptr, idxs, this->BB_); return GetElementPtrInst::create_gep(ptr, idxs, this->BB_);
} }
StoreInst *create_store(Value *val, Value *ptr) { StoreInst *create_store(Value *val, Value *ptr) const
{
return StoreInst::create_store(val, ptr, this->BB_); return StoreInst::create_store(val, ptr, this->BB_);
} }
LoadInst *create_load(Value *ptr) { LoadInst *create_load(Value *ptr) const
{
assert(ptr->get_type()->is_pointer_type() && assert(ptr->get_type()->is_pointer_type() &&
"ptr must be pointer type"); "ptr must be pointer type");
return LoadInst::create_load(ptr, this->BB_); return LoadInst::create_load(ptr, this->BB_);
} }
AllocaInst *create_alloca(Type *ty) { AllocaInst *create_alloca(Type *ty) const
return AllocaInst::create_alloca(ty, this->BB_); {
} return AllocaInst::create_alloca(ty, this->BB_->get_entry_block_of_same_function());
AllocaInst *create_alloca_begin(Type *ty) {
return AllocaInst::create_alloca_begin(ty, this->BB_);
} }
ZextInst *create_zext(Value *val, Type *ty) {
ZextInst *create_zext(Value *val, Type *ty) const
{
return ZextInst::create_zext(val, ty, this->BB_); return ZextInst::create_zext(val, ty, this->BB_);
} }
SiToFpInst *create_sitofp(Value *val, Type *ty) { SiToFpInst *create_sitofp(Value *val, Type *ty) const
{
return SiToFpInst::create_sitofp(val, this->BB_); return SiToFpInst::create_sitofp(val, this->BB_);
} }
FpToSiInst *create_fptosi(Value *val, Type *ty) { FpToSiInst *create_fptosi(Value *val, Type *ty) const
{
return FpToSiInst::create_fptosi(val, ty, this->BB_); return FpToSiInst::create_fptosi(val, ty, this->BB_);
} }
FCmpInst *create_fcmp_ne(Value *lhs, Value *rhs) { FCmpInst *create_fcmp_ne(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fne(lhs, rhs, this->BB_); return FCmpInst::create_fne(lhs, rhs, this->BB_);
} }
FCmpInst *create_fcmp_lt(Value *lhs, Value *rhs) { FCmpInst *create_fcmp_lt(Value *lhs, Value *rhs) const
{
return FCmpInst::create_flt(lhs, rhs, this->BB_); return FCmpInst::create_flt(lhs, rhs, this->BB_);
} }
FCmpInst *create_fcmp_le(Value *lhs, Value *rhs) { FCmpInst *create_fcmp_le(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fle(lhs, rhs, this->BB_); return FCmpInst::create_fle(lhs, rhs, this->BB_);
} }
FCmpInst *create_fcmp_ge(Value *lhs, Value *rhs) { FCmpInst *create_fcmp_ge(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fge(lhs, rhs, this->BB_); return FCmpInst::create_fge(lhs, rhs, this->BB_);
} }
FCmpInst *create_fcmp_gt(Value *lhs, Value *rhs) { FCmpInst *create_fcmp_gt(Value *lhs, Value *rhs) const
{
return FCmpInst::create_fgt(lhs, rhs, this->BB_); return FCmpInst::create_fgt(lhs, rhs, this->BB_);
} }
FCmpInst *create_fcmp_eq(Value *lhs, Value *rhs) { FCmpInst *create_fcmp_eq(Value *lhs, Value *rhs) const
{
return FCmpInst::create_feq(lhs, rhs, this->BB_); return FCmpInst::create_feq(lhs, rhs, this->BB_);
} }
FBinaryInst *create_fadd(Value *lhs, Value *rhs) { FBinaryInst *create_fadd(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fadd(lhs, rhs, this->BB_); return FBinaryInst::create_fadd(lhs, rhs, this->BB_);
} }
FBinaryInst *create_fsub(Value *lhs, Value *rhs) { FBinaryInst *create_fsub(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fsub(lhs, rhs, this->BB_); return FBinaryInst::create_fsub(lhs, rhs, this->BB_);
} }
FBinaryInst *create_fmul(Value *lhs, Value *rhs) { FBinaryInst *create_fmul(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fmul(lhs, rhs, this->BB_); return FBinaryInst::create_fmul(lhs, rhs, this->BB_);
} }
FBinaryInst *create_fdiv(Value *lhs, Value *rhs) { FBinaryInst *create_fdiv(Value *lhs, Value *rhs) const
{
return FBinaryInst::create_fdiv(lhs, rhs, this->BB_); return FBinaryInst::create_fdiv(lhs, rhs, this->BB_);
} }
}; };
#pragma once #pragma once
#include "BasicBlock.hpp" #include "BasicBlock.hpp"
#include "Constant.hpp"
#include "Function.hpp"
#include "GlobalVariable.hpp"
#include "Instruction.hpp" #include "Instruction.hpp"
#include "Module.hpp"
#include "Type.hpp"
#include "User.hpp"
#include "Value.hpp" #include "Value.hpp"
std::string print_as_op(Value *v, bool print_ty); std::string print_as_op(Value *v, bool print_ty);
......
...@@ -3,15 +3,12 @@ ...@@ -3,15 +3,12 @@
#include "Type.hpp" #include "Type.hpp"
#include "User.hpp" #include "User.hpp"
#include <cstdint>
#include <llvm/ADT/ilist_node.h>
class BasicBlock; class BasicBlock;
class Function; class Function;
class Instruction : public User, public llvm::ilist_node<Instruction> { class Instruction : public User {
public: public:
enum OpID : uint32_t { enum OpID : uint8_t {
// Terminator Instructions // Terminator Instructions
ret, ret,
br, br,
...@@ -56,21 +53,26 @@ class Instruction : public User, public llvm::ilist_node<Instruction> { ...@@ -56,21 +53,26 @@ class Instruction : public User, public llvm::ilist_node<Instruction> {
/* @parent: if parent!=nullptr, auto insert to bb /* @parent: if parent!=nullptr, auto insert to bb
* @ty: result type */ * @ty: result type */
Instruction(Type *ty, OpID id, BasicBlock *parent = nullptr); Instruction(Type *ty, OpID id, BasicBlock *parent = nullptr);
Instruction(const Instruction &) = delete; ~Instruction() override;
virtual ~Instruction() = default;
Instruction(const Instruction& other) = delete;
Instruction(Instruction&& other) noexcept = delete;
Instruction& operator=(const Instruction& other) = delete;
Instruction& operator=(Instruction&& other) noexcept = delete;
BasicBlock *get_parent() { return parent_; } BasicBlock *get_parent() { return parent_; }
const BasicBlock *get_parent() const { return parent_; } const BasicBlock *get_parent() const { return parent_; }
void set_parent(BasicBlock *parent) { this->parent_ = parent; } void set_parent(BasicBlock *parent) { this->parent_ = parent; }
// Return the function this instruction belongs to. // Return the function this instruction belongs to.
Function *get_function(); Function *get_function() const;
Module *get_module(); Module *get_module() const;
OpID get_instr_type() const { return op_id_; } OpID get_instr_type() const { return op_id_; }
std::string get_instr_op_name() const; std::string get_instr_op_name() const;
bool is_void() { bool is_void() const
{
return ((op_id_ == ret) || (op_id_ == br) || (op_id_ == store) || return ((op_id_ == ret) || (op_id_ == br) || (op_id_ == store) ||
(op_id_ == call && this->get_type()->is_void_type())); (op_id_ == call && this->get_type()->is_void_type()));
} }
...@@ -136,7 +138,7 @@ class IBinaryInst : public BaseInst<IBinaryInst> { ...@@ -136,7 +138,7 @@ class IBinaryInst : public BaseInst<IBinaryInst> {
static IBinaryInst *create_mul(Value *v1, Value *v2, BasicBlock *bb); static IBinaryInst *create_mul(Value *v1, Value *v2, BasicBlock *bb);
static IBinaryInst *create_sdiv(Value *v1, Value *v2, BasicBlock *bb); static IBinaryInst *create_sdiv(Value *v1, Value *v2, BasicBlock *bb);
virtual std::string print() override; std::string print() override;
}; };
class FBinaryInst : public BaseInst<FBinaryInst> { class FBinaryInst : public BaseInst<FBinaryInst> {
...@@ -151,7 +153,7 @@ class FBinaryInst : public BaseInst<FBinaryInst> { ...@@ -151,7 +153,7 @@ class FBinaryInst : public BaseInst<FBinaryInst> {
static FBinaryInst *create_fmul(Value *v1, Value *v2, BasicBlock *bb); static FBinaryInst *create_fmul(Value *v1, Value *v2, BasicBlock *bb);
static FBinaryInst *create_fdiv(Value *v1, Value *v2, BasicBlock *bb); static FBinaryInst *create_fdiv(Value *v1, Value *v2, BasicBlock *bb);
virtual std::string print() override; std::string print() override;
}; };
class ICmpInst : public BaseInst<ICmpInst> { class ICmpInst : public BaseInst<ICmpInst> {
...@@ -168,7 +170,7 @@ class ICmpInst : public BaseInst<ICmpInst> { ...@@ -168,7 +170,7 @@ class ICmpInst : public BaseInst<ICmpInst> {
static ICmpInst *create_eq(Value *v1, Value *v2, BasicBlock *bb); static ICmpInst *create_eq(Value *v1, Value *v2, BasicBlock *bb);
static ICmpInst *create_ne(Value *v1, Value *v2, BasicBlock *bb); static ICmpInst *create_ne(Value *v1, Value *v2, BasicBlock *bb);
virtual std::string print() override; std::string print() override;
}; };
class FCmpInst : public BaseInst<FCmpInst> { class FCmpInst : public BaseInst<FCmpInst> {
...@@ -185,21 +187,21 @@ class FCmpInst : public BaseInst<FCmpInst> { ...@@ -185,21 +187,21 @@ class FCmpInst : public BaseInst<FCmpInst> {
static FCmpInst *create_feq(Value *v1, Value *v2, BasicBlock *bb); static FCmpInst *create_feq(Value *v1, Value *v2, BasicBlock *bb);
static FCmpInst *create_fne(Value *v1, Value *v2, BasicBlock *bb); static FCmpInst *create_fne(Value *v1, Value *v2, BasicBlock *bb);
virtual std::string print() override; std::string print() override;
}; };
class CallInst : public BaseInst<CallInst> { class CallInst : public BaseInst<CallInst> {
friend BaseInst<CallInst>; friend BaseInst<CallInst>;
protected: protected:
CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb); CallInst(Function *func, const std::vector<Value *>& args, BasicBlock *bb);
public: public:
static CallInst *create_call(Function *func, std::vector<Value *> args, static CallInst *create_call(Function *func, const std::vector<Value *>& args,
BasicBlock *bb); BasicBlock *bb);
FunctionType *get_function_type() const; FunctionType *get_function_type() const;
virtual std::string print() override; std::string print() override;
}; };
class BranchInst : public BaseInst<BranchInst> { class BranchInst : public BaseInst<BranchInst> {
...@@ -208,9 +210,14 @@ class BranchInst : public BaseInst<BranchInst> { ...@@ -208,9 +210,14 @@ class BranchInst : public BaseInst<BranchInst> {
private: private:
BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false,
BasicBlock *bb); BasicBlock *bb);
~BranchInst(); ~BranchInst() override;
public:
BranchInst(const BranchInst& other) = delete;
BranchInst(BranchInst&& other) noexcept = delete;
BranchInst& operator=(const BranchInst& other) = delete;
BranchInst& operator=(BranchInst&& other) noexcept = delete;
public:
static BranchInst *create_cond_br(Value *cond, BasicBlock *if_true, static BranchInst *create_cond_br(Value *cond, BasicBlock *if_true,
BasicBlock *if_false, BasicBlock *bb); BasicBlock *if_false, BasicBlock *bb);
static BranchInst *create_br(BasicBlock *if_true, BasicBlock *bb); static BranchInst *create_br(BasicBlock *if_true, BasicBlock *bb);
...@@ -219,7 +226,7 @@ class BranchInst : public BaseInst<BranchInst> { ...@@ -219,7 +226,7 @@ class BranchInst : public BaseInst<BranchInst> {
Value *get_condition() const { return get_operand(0); } Value *get_condition() const { return get_operand(0); }
virtual std::string print() override; std::string print() override;
}; };
class ReturnInst : public BaseInst<ReturnInst> { class ReturnInst : public BaseInst<ReturnInst> {
...@@ -233,69 +240,66 @@ class ReturnInst : public BaseInst<ReturnInst> { ...@@ -233,69 +240,66 @@ class ReturnInst : public BaseInst<ReturnInst> {
static ReturnInst *create_void_ret(BasicBlock *bb); static ReturnInst *create_void_ret(BasicBlock *bb);
bool is_void_ret() const; bool is_void_ret() const;
virtual std::string print() override; std::string print() override;
}; };
class GetElementPtrInst : public BaseInst<GetElementPtrInst> { class GetElementPtrInst : public BaseInst<GetElementPtrInst> {
friend BaseInst<GetElementPtrInst>; friend BaseInst<GetElementPtrInst>;
private: private:
GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, BasicBlock *bb); GetElementPtrInst(Value *ptr, const std::vector<Value *>& idxs, BasicBlock *bb);
public: public:
static Type *get_element_type(Value *ptr, std::vector<Value *> idxs); static Type *get_element_type(const Value *ptr, const std::vector<Value *>& idxs);
static GetElementPtrInst *create_gep(Value *ptr, std::vector<Value *> idxs, static GetElementPtrInst *create_gep(Value *ptr, const std::vector<Value *>& idxs,
BasicBlock *bb); BasicBlock *bb);
Type *get_element_type() const; Type *get_element_type() const;
virtual std::string print() override; std::string print() override;
}; };
class StoreInst : public BaseInst<StoreInst> { class StoreInst : public BaseInst<StoreInst> {
friend BaseInst<StoreInst>; friend BaseInst<StoreInst>;
private:
StoreInst(Value *val, Value *ptr, BasicBlock *bb); StoreInst(Value *val, Value *ptr, BasicBlock *bb);
public: public:
static StoreInst *create_store(Value *val, Value *ptr, BasicBlock *bb); static StoreInst *create_store(Value *val, Value *ptr, BasicBlock *bb);
Value *get_rval() { return this->get_operand(0); } Value *get_rval() const { return this->get_operand(0); }
Value *get_lval() { return this->get_operand(1); } Value *get_lval() const { return this->get_operand(1); }
virtual std::string print() override; std::string print() override;
}; };
class LoadInst : public BaseInst<LoadInst> { class LoadInst : public BaseInst<LoadInst> {
friend BaseInst<LoadInst>; friend BaseInst<LoadInst>;
private:
LoadInst(Value *ptr, BasicBlock *bb); LoadInst(Value *ptr, BasicBlock *bb);
public: public:
static LoadInst *create_load(Value *ptr, BasicBlock *bb); static LoadInst *create_load(Value *ptr, BasicBlock *bb);
Value *get_lval() const { return this->get_operand(0); } Value *get_lval() const { return this->get_operand(0); }
Type *get_load_type() const { return get_type(); }; Type *get_load_type() const { return get_type(); }
virtual std::string print() override; std::string print() override;
}; };
class AllocaInst : public BaseInst<AllocaInst> { class AllocaInst : public BaseInst<AllocaInst> {
friend BaseInst<AllocaInst>; friend BaseInst<AllocaInst>;
private:
AllocaInst(Type *ty, BasicBlock *bb); AllocaInst(Type *ty, BasicBlock *bb);
public: public:
// 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面
static AllocaInst *create_alloca(Type *ty, BasicBlock *bb); static AllocaInst *create_alloca(Type *ty, BasicBlock *bb);
static AllocaInst *create_alloca_begin(Type *ty, BasicBlock *bb);
Type *get_alloca_type() const { Type *get_alloca_type() const {
return get_type()->get_pointer_element_type(); return get_type()->get_pointer_element_type();
}; }
virtual std::string print() override; std::string print() override;
}; };
class ZextInst : public BaseInst<ZextInst> { class ZextInst : public BaseInst<ZextInst> {
...@@ -308,9 +312,9 @@ class ZextInst : public BaseInst<ZextInst> { ...@@ -308,9 +312,9 @@ class ZextInst : public BaseInst<ZextInst> {
static ZextInst *create_zext(Value *val, Type *ty, BasicBlock *bb); static ZextInst *create_zext(Value *val, Type *ty, BasicBlock *bb);
static ZextInst *create_zext_to_i32(Value *val, BasicBlock *bb); static ZextInst *create_zext_to_i32(Value *val, BasicBlock *bb);
Type *get_dest_type() const { return get_type(); }; Type *get_dest_type() const { return get_type(); }
virtual std::string print() override; std::string print() override;
}; };
class FpToSiInst : public BaseInst<FpToSiInst> { class FpToSiInst : public BaseInst<FpToSiInst> {
...@@ -323,9 +327,9 @@ class FpToSiInst : public BaseInst<FpToSiInst> { ...@@ -323,9 +327,9 @@ class FpToSiInst : public BaseInst<FpToSiInst> {
static FpToSiInst *create_fptosi(Value *val, Type *ty, BasicBlock *bb); static FpToSiInst *create_fptosi(Value *val, Type *ty, BasicBlock *bb);
static FpToSiInst *create_fptosi_to_i32(Value *val, BasicBlock *bb); static FpToSiInst *create_fptosi_to_i32(Value *val, BasicBlock *bb);
Type *get_dest_type() const { return get_type(); }; Type *get_dest_type() const { return get_type(); }
virtual std::string print() override; std::string print() override;
}; };
class SiToFpInst : public BaseInst<SiToFpInst> { class SiToFpInst : public BaseInst<SiToFpInst> {
...@@ -337,34 +341,38 @@ class SiToFpInst : public BaseInst<SiToFpInst> { ...@@ -337,34 +341,38 @@ class SiToFpInst : public BaseInst<SiToFpInst> {
public: public:
static SiToFpInst *create_sitofp(Value *val, BasicBlock *bb); static SiToFpInst *create_sitofp(Value *val, BasicBlock *bb);
Type *get_dest_type() const { return get_type(); }; Type *get_dest_type() const { return get_type(); }
virtual std::string print() override; std::string print() override;
}; };
class PhiInst : public BaseInst<PhiInst> { class PhiInst : public BaseInst<PhiInst> {
friend BaseInst<PhiInst>; friend BaseInst<PhiInst>;
private: private:
PhiInst(Type *ty, std::vector<Value *> vals, PhiInst(Type *ty, const std::vector<Value *>& vals,
std::vector<BasicBlock *> val_bbs, BasicBlock *bb); const std::vector<BasicBlock *>& val_bbs, BasicBlock *bb);
public: public:
// 新特性:指令表分为 {alloca, phi | other inst} 两段,创建和向基本块插入 alloca 和 phi,都只会插在第一段,它们在常规指令前面
static PhiInst *create_phi(Type *ty, BasicBlock *bb, static PhiInst *create_phi(Type *ty, BasicBlock *bb,
std::vector<Value *> vals = {}, const std::vector<Value *>& vals = {},
std::vector<BasicBlock *> val_bbs = {}); const std::vector<BasicBlock *>& val_bbs = {});
void add_phi_pair_operand(Value *val, Value *pre_bb) { void add_phi_pair_operand(Value *val, Value *pre_bb) {
this->add_operand(val); this->add_operand(val);
this->add_operand(pre_bb); this->add_operand(pre_bb);
} }
std::vector<std::pair<Value *, BasicBlock *>> get_phi_pairs() { std::vector<std::pair<Value *, BasicBlock *>> get_phi_pairs() const
{
std::vector<std::pair<Value *, BasicBlock *>> res; std::vector<std::pair<Value *, BasicBlock *>> res;
for (size_t i = 0; i < get_num_operand(); i += 2) { int ops = static_cast<int>(get_num_operand());
res.push_back({this->get_operand(i), for (int i = 0; i < ops; i += 2) {
this->get_operand(i + 1)->as<BasicBlock>()}); res.emplace_back(this->get_operand(i),
this->get_operand(i + 1)->as<BasicBlock>());
} }
return res; return res;
} }
virtual std::string print() override;
std::string print() override;
}; };
...@@ -2,13 +2,9 @@ ...@@ -2,13 +2,9 @@
#include "Function.hpp" #include "Function.hpp"
#include "GlobalVariable.hpp" #include "GlobalVariable.hpp"
#include "Instruction.hpp"
#include "Type.hpp" #include "Type.hpp"
#include "Value.hpp"
#include <list> #include <list>
#include <llvm/ADT/ilist.h>
#include <llvm/ADT/ilist_node.h>
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -18,14 +14,18 @@ class Function; ...@@ -18,14 +14,18 @@ class Function;
class Module { class Module {
public: public:
Module(); Module();
~Module() = default; ~Module();
Module(const Module& other) = delete;
Type *get_void_type(); Module(Module&& other) noexcept = delete;
Type *get_label_type(); Module& operator=(const Module& other) = delete;
IntegerType *get_int1_type(); Module& operator=(Module&& other) noexcept = delete;
IntegerType *get_int32_type();
Type *get_void_type() const;
Type *get_label_type() const;
IntegerType *get_int1_type() const;
IntegerType *get_int32_type() const;
PointerType *get_int32_ptr_type(); PointerType *get_int32_ptr_type();
FloatType *get_float_type(); FloatType *get_float_type() const;
PointerType *get_float_ptr_type(); PointerType *get_float_ptr_type();
PointerType *get_pointer_type(Type *contained); PointerType *get_pointer_type(Type *contained);
...@@ -33,27 +33,27 @@ class Module { ...@@ -33,27 +33,27 @@ class Module {
FunctionType *get_function_type(Type *retty, std::vector<Type *> &args); FunctionType *get_function_type(Type *retty, std::vector<Type *> &args);
void add_function(Function *f); void add_function(Function *f);
llvm::ilist<Function> &get_functions(); std::list<Function*> &get_functions();
void add_global_variable(GlobalVariable *g); void add_global_variable(GlobalVariable *g);
llvm::ilist<GlobalVariable> &get_global_variable(); std::list<GlobalVariable*> &get_global_variable();
void set_print_name(); void set_print_name();
std::string print(); std::string print();
private: private:
// The global variables in the module // The global variables in the module
llvm::ilist<GlobalVariable> global_list_; std::list<GlobalVariable*> global_list_;
// The functions in the module // The functions in the module
llvm::ilist<Function> function_list_; std::list<Function*> function_list_;
std::unique_ptr<IntegerType> int1_ty_; IntegerType* int1_ty_;
std::unique_ptr<IntegerType> int32_ty_; IntegerType* int32_ty_;
std::unique_ptr<Type> label_ty_; Type* label_ty_;
std::unique_ptr<Type> void_ty_; Type* void_ty_;
std::unique_ptr<FloatType> float32_ty_; FloatType* float32_ty_;
std::map<Type *, std::unique_ptr<PointerType>> pointer_map_; std::map<Type *, PointerType*> pointer_map_;
std::map<std::pair<Type *, int>, std::unique_ptr<ArrayType>> array_map_; std::map<std::pair<Type *, int>, ArrayType*> array_map_;
std::map<std::pair<Type *, std::vector<Type *>>, std::map<std::pair<Type *, std::vector<Type *>>,
std::unique_ptr<FunctionType>> FunctionType*>
function_map_; function_map_;
}; };
...@@ -12,7 +12,12 @@ class FloatType; ...@@ -12,7 +12,12 @@ class FloatType;
class Type { class Type {
public: public:
enum TypeID { Type(const Type& other) = delete;
Type(Type&& other) noexcept = delete;
Type& operator=(const Type& other) = delete;
Type& operator=(Type&& other) noexcept = delete;
enum TypeID: uint8_t {
VoidTyID, // Void VoidTyID, // Void
LabelTyID, // Labels, e.g., BasicBlock LabelTyID, // Labels, e.g., BasicBlock
IntegerTyID, // Integers, include 32 bits and 1 bit IntegerTyID, // Integers, include 32 bits and 1 bit
...@@ -25,11 +30,7 @@ class Type { ...@@ -25,11 +30,7 @@ class Type {
explicit Type(TypeID tid, Module *m); explicit Type(TypeID tid, Module *m);
// 生成 vptr, 使调试时能够显示 IntegerType 等 Type 的信息 // 生成 vptr, 使调试时能够显示 IntegerType 等 Type 的信息
#ifdef DEBUG_BUILD virtual ~Type();
virtual
#endif
~Type() = default;
TypeID get_type_id() const { return tid_; } TypeID get_type_id() const { return tid_; }
...@@ -59,7 +60,13 @@ class Type { ...@@ -59,7 +60,13 @@ class Type {
class IntegerType : public Type { class IntegerType : public Type {
public: public:
IntegerType(const IntegerType& other) = delete;
IntegerType(IntegerType&& other) noexcept = delete;
IntegerType& operator=(const IntegerType& other) = delete;
IntegerType& operator=(IntegerType&& other) noexcept = delete;
explicit IntegerType(unsigned num_bits, Module *m); explicit IntegerType(unsigned num_bits, Module *m);
~IntegerType() override;
unsigned get_num_bits() const; unsigned get_num_bits() const;
...@@ -69,10 +76,16 @@ class IntegerType : public Type { ...@@ -69,10 +76,16 @@ class IntegerType : public Type {
class FunctionType : public Type { class FunctionType : public Type {
public: public:
FunctionType(Type *result, std::vector<Type *> params); FunctionType(const FunctionType& other) = delete;
FunctionType(FunctionType&& other) noexcept = delete;
FunctionType& operator=(const FunctionType& other) = delete;
FunctionType& operator=(FunctionType&& other) noexcept = delete;
static bool is_valid_return_type(Type *ty); FunctionType(Type *result, const std::vector<Type *>& params);
static bool is_valid_argument_type(Type *ty); ~FunctionType() override;
static bool is_valid_return_type(const Type *ty);
static bool is_valid_argument_type(const Type *ty);
static FunctionType *get(Type *result, std::vector<Type *> params); static FunctionType *get(Type *result, std::vector<Type *> params);
...@@ -90,9 +103,15 @@ class FunctionType : public Type { ...@@ -90,9 +103,15 @@ class FunctionType : public Type {
class ArrayType : public Type { class ArrayType : public Type {
public: public:
ArrayType(const ArrayType& other) = delete;
ArrayType(ArrayType&& other) noexcept = delete;
ArrayType& operator=(const ArrayType& other) = delete;
ArrayType& operator=(ArrayType&& other) noexcept = delete;
ArrayType(Type *contained, unsigned num_elements); ArrayType(Type *contained, unsigned num_elements);
~ArrayType() override;
static bool is_valid_element_type(Type *ty); static bool is_valid_element_type(const Type *ty);
static ArrayType *get(Type *contained, unsigned num_elements); static ArrayType *get(Type *contained, unsigned num_elements);
...@@ -106,7 +125,13 @@ class ArrayType : public Type { ...@@ -106,7 +125,13 @@ class ArrayType : public Type {
class PointerType : public Type { class PointerType : public Type {
public: public:
PointerType(const PointerType& other) = delete;
PointerType(PointerType&& other) noexcept = delete;
PointerType& operator=(const PointerType& other) = delete;
PointerType& operator=(PointerType&& other) noexcept = delete;
PointerType(Type *contained); PointerType(Type *contained);
~PointerType() override;
Type *get_element_type() const { return contained_; } Type *get_element_type() const { return contained_; }
static PointerType *get(Type *contained); static PointerType *get(Type *contained);
...@@ -117,8 +142,12 @@ class PointerType : public Type { ...@@ -117,8 +142,12 @@ class PointerType : public Type {
class FloatType : public Type { class FloatType : public Type {
public: public:
FloatType(const FloatType& other) = delete;
FloatType(FloatType&& other) noexcept = delete;
FloatType& operator=(const FloatType& other) = delete;
FloatType& operator=(FloatType&& other) noexcept = delete;
FloatType(Module *m); FloatType(Module *m);
~FloatType() override;
static FloatType *get(Module *m); static FloatType *get(Module *m);
private:
}; };
...@@ -6,14 +6,18 @@ ...@@ -6,14 +6,18 @@
class User : public Value { class User : public Value {
public: public:
User(Type *ty, const std::string &name = "") : Value(ty, name){}; User(const User& other) = delete;
virtual ~User() { remove_all_operands(); } User(User&& other) noexcept = delete;
User& operator=(const User& other) = delete;
User& operator=(User&& other) noexcept = delete;
User(Type *ty, const std::string &name = "") : Value(ty, name){}
~User() override;
const std::vector<Value *> &get_operands() const { return operands_; } const std::vector<Value *> &get_operands() const { return operands_; }
unsigned get_num_operand() const { return operands_.size(); } unsigned get_num_operand() const { return static_cast<unsigned>(operands_.size()); }
// start from 0 // start from 0
Value *get_operand(unsigned i) const { return operands_.at(i); }; Value *get_operand(unsigned i) const { return operands_.at(i); }
// start from 0 // start from 0
void set_operand(unsigned i, Value *v); void set_operand(unsigned i, Value *v);
void add_operand(Value *v); void add_operand(Value *v);
......
#pragma once #pragma once
#include <functional> #include <functional>
#include <iostream>
#include <list> #include <list>
#include <string> #include <string>
#include <cassert> #include <cassert>
...@@ -13,43 +12,48 @@ struct Use; ...@@ -13,43 +12,48 @@ struct Use;
class Value { class Value {
public: public:
explicit Value(Type *ty, const std::string &name = "") Value(const Value& other) = delete;
: type_(ty), name_(name){}; Value(Value&& other) noexcept = delete;
virtual ~Value() { replace_all_use_with(nullptr); } Value& operator=(const Value& other) = delete;
Value& operator=(Value&& other) noexcept = delete;
std::string get_name() const { return name_; }; explicit Value(Type *ty, std::string name = "")
: type_(ty), name_(std::move(name)){}
virtual ~Value();
std::string get_name() const { return name_; }
Type *get_type() const { return type_; } Type *get_type() const { return type_; }
const std::list<Use> &get_use_list() const { return use_list_; } const std::list<Use> &get_use_list() const { return use_list_; }
bool set_name(std::string name); bool set_name(const std::string& name);
void add_use(User *user, unsigned arg_no); void add_use(User *user, unsigned arg_no);
void remove_use(User *user, unsigned arg_no); void remove_use(User *user, unsigned arg_no);
void replace_all_use_with(Value *new_val); void replace_all_use_with(Value *new_val) const;
void replace_use_with_if(Value *new_val, std::function<bool(Use *)> pred); void replace_use_with_if(Value *new_val, const std::function<bool(Use *)>& should_replace);
virtual std::string print() = 0; virtual std::string print() = 0;
template<typename T> template<typename T>
T *as() T *as()
{ {
static_assert(std::is_base_of<Value, T>::value, "T must be a subclass of Value"); static_assert(std::is_base_of_v<Value, T>, "T must be a subclass of Value");
const auto ptr = dynamic_cast<T*>(this); const auto ptr = dynamic_cast<T*>(this);
assert(ptr && "dynamic_cast failed"); assert(ptr && "dynamic_cast failed");
return ptr; return ptr;
} }
template<typename T> template<typename T>
[[nodiscard]] const T* as() const { const T* as() const {
static_assert(std::is_base_of<Value, T>::value, "T must be a subclass of Value"); static_assert(std::is_base_of_v<Value, T>, "T must be a subclass of Value");
const auto ptr = dynamic_cast<const T*>(this); const auto ptr = dynamic_cast<const T*>(this);
assert(ptr); assert(ptr);
return ptr; return ptr;
} }
// is 接口 // is 接口
template <typename T> template <typename T>
[[nodiscard]] bool is() const { bool is() const {
static_assert(std::is_base_of<Value, T>::value, "T must be a subclass of Value"); static_assert(std::is_base_of_v<Value, T>, "T must be a subclass of Value");
return dynamic_cast<const T*>(this); return dynamic_cast<const T*>(this);
} }
......
...@@ -5,21 +5,24 @@ ...@@ -5,21 +5,24 @@
#include "cminusf_builder.hpp" #include "cminusf_builder.hpp"
#define CONST_FP(num) ConstantFP::get((float)num, module.get()) #define CONST_FP(num) ConstantFP::get((float)(num), module)
#define CONST_INT(num) ConstantInt::get(num, module.get()) #define CONST_INT(num) ConstantInt::get((num), module)
// types // types
Type *VOID_T; namespace
Type *INT1_T; {
Type *INT32_T; Type* VOID_T;
Type *INT32PTR_T; Type* INT1_T;
Type *FLOAT_T; Type* INT32_T;
Type *FLOATPTR_T; Type* INT32PTR_T;
Type* FLOAT_T;
bool promote(IRBuilder *builder, Value **l_val_p, Value **r_val_p) { Type* FLOATPTR_T;
}
static bool promote(const IRBuilder *builder, Value **l_val_p, Value **r_val_p) {
bool is_int = false; bool is_int = false;
auto &l_val = *l_val_p; auto& l_val = *l_val_p;
auto &r_val = *r_val_p; auto& r_val = *r_val_p;
if (l_val->get_type() == r_val->get_type()) { if (l_val->get_type() == r_val->get_type()) {
is_int = l_val->get_type()->is_integer_type(); is_int = l_val->get_type()->is_integer_type();
} else { } else {
...@@ -63,7 +66,7 @@ Value* CminusfBuilder::visit(ASTNum &node) { ...@@ -63,7 +66,7 @@ Value* CminusfBuilder::visit(ASTNum &node) {
} }
Value* CminusfBuilder::visit(ASTVarDeclaration &node) { Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
Type *var_type = nullptr; Type *var_type;
if (node.type == TYPE_INT) { if (node.type == TYPE_INT) {
var_type = module->get_int32_type(); var_type = module->get_int32_type();
} else { } else {
...@@ -72,14 +75,14 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) { ...@@ -72,14 +75,14 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
if (scope.in_global()) { if (scope.in_global()) {
if (node.num == nullptr) { if (node.num == nullptr) {
auto *initializer = ConstantZero::get(var_type, module.get()); auto *initializer = ConstantZero::get(var_type, module);
auto *var = GlobalVariable::create(node.id, module.get(), var_type, auto *var = GlobalVariable::create(node.id, module, var_type,
false, initializer); false, initializer);
scope.push(node.id, var); scope.push(node.id, var);
} else { } else {
auto *array_type = ArrayType::get(var_type, node.num->i_val); auto *array_type = ArrayType::get(var_type, node.num->i_val);
auto *initializer = ConstantZero::get(array_type, module.get()); auto *initializer = ConstantZero::get(array_type, module);
auto *var = GlobalVariable::create(node.id, module.get(), auto *var = GlobalVariable::create(node.id, module,
array_type, false, initializer); array_type, false, initializer);
scope.push(node.id, var); scope.push(node.id, var);
} }
...@@ -88,7 +91,7 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) { ...@@ -88,7 +91,7 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
auto nowBB = builder->get_insert_block(); auto nowBB = builder->get_insert_block();
auto entryBB = context.func->get_entry_block(); auto entryBB = context.func->get_entry_block();
builder->set_insert_point(entryBB); builder->set_insert_point(entryBB);
auto *var = builder->create_alloca_begin(var_type); auto *var = builder->create_alloca(var_type);
builder->set_insert_point(nowBB); builder->set_insert_point(nowBB);
scope.push(node.id, var); scope.push(node.id, var);
} else { } else {
...@@ -96,7 +99,7 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) { ...@@ -96,7 +99,7 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
auto entryBB = context.func->get_entry_block(); auto entryBB = context.func->get_entry_block();
builder->set_insert_point(entryBB); builder->set_insert_point(entryBB);
auto *array_type = ArrayType::get(var_type, node.num->i_val); auto *array_type = ArrayType::get(var_type, node.num->i_val);
auto *var = builder->create_alloca_begin(array_type); auto *var = builder->create_alloca(array_type);
builder->set_insert_point(nowBB); builder->set_insert_point(nowBB);
scope.push(node.id, var); scope.push(node.id, var);
} }
...@@ -105,7 +108,6 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) { ...@@ -105,7 +108,6 @@ Value* CminusfBuilder::visit(ASTVarDeclaration &node) {
} }
Value* CminusfBuilder::visit(ASTFunDeclaration &node) { Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
FunctionType *fun_type;
Type *ret_type; Type *ret_type;
std::vector<Type *> param_types; std::vector<Type *> param_types;
if (node.type == TYPE_INT) if (node.type == TYPE_INT)
...@@ -131,11 +133,11 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) { ...@@ -131,11 +133,11 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
} }
} }
fun_type = FunctionType::get(ret_type, param_types); FunctionType* fun_type = FunctionType::get(ret_type, param_types);
auto func = Function::create(fun_type, node.id, module.get()); auto func = Function::create(fun_type, node.id, module);
scope.push(node.id, func); scope.push(node.id, func);
context.func = func; context.func = func;
auto funBB = BasicBlock::create(module.get(), "entry", func); auto funBB = BasicBlock::create(module, "entry", func);
builder->set_insert_point(funBB); builder->set_insert_point(funBB);
scope.enter(); scope.enter();
context.pre_enter_scope = true; context.pre_enter_scope = true;
...@@ -145,7 +147,7 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) { ...@@ -145,7 +147,7 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
} }
for (unsigned int i = 0; i < node.params.size(); ++i) { for (unsigned int i = 0; i < node.params.size(); ++i) {
if (node.params[i]->isarray) { if (node.params[i]->isarray) {
Value *array_alloc = nullptr; Value *array_alloc;
if (node.params[i]->type == TYPE_INT) { if (node.params[i]->type == TYPE_INT) {
array_alloc = builder->create_alloca(INT32PTR_T); array_alloc = builder->create_alloca(INT32PTR_T);
} else { } else {
...@@ -154,7 +156,7 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) { ...@@ -154,7 +156,7 @@ Value* CminusfBuilder::visit(ASTFunDeclaration &node) {
builder->create_store(args[i], array_alloc); builder->create_store(args[i], array_alloc);
scope.push(node.params[i]->id, array_alloc); scope.push(node.params[i]->id, array_alloc);
} else { } else {
Value *alloc = nullptr; Value *alloc;
if (node.params[i]->type == TYPE_INT) { if (node.params[i]->type == TYPE_INT) {
alloc = builder->create_alloca(INT32_T); alloc = builder->create_alloca(INT32_T);
} else { } else {
...@@ -216,10 +218,10 @@ Value* CminusfBuilder::visit(ASTExpressionStmt &node) { ...@@ -216,10 +218,10 @@ Value* CminusfBuilder::visit(ASTExpressionStmt &node) {
Value* CminusfBuilder::visit(ASTSelectionStmt &node) { Value* CminusfBuilder::visit(ASTSelectionStmt &node) {
auto *ret_val = node.expression->accept(*this); auto *ret_val = node.expression->accept(*this);
auto *trueBB = BasicBlock::create(module.get(), "", context.func); auto *trueBB = BasicBlock::create(module, "", context.func);
BasicBlock *falseBB{}; BasicBlock *falseBB{};
auto *contBB = BasicBlock::create(module.get(), "", context.func); auto *contBB = BasicBlock::create(module, "", context.func);
Value *cond_val = nullptr; Value *cond_val;
if (ret_val->get_type()->is_integer_type()) { if (ret_val->get_type()->is_integer_type()) {
cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0)); cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
} else { } else {
...@@ -229,7 +231,7 @@ Value* CminusfBuilder::visit(ASTSelectionStmt &node) { ...@@ -229,7 +231,7 @@ Value* CminusfBuilder::visit(ASTSelectionStmt &node) {
if (node.else_statement == nullptr) { if (node.else_statement == nullptr) {
builder->create_cond_br(cond_val, trueBB, contBB); builder->create_cond_br(cond_val, trueBB, contBB);
} else { } else {
falseBB = BasicBlock::create(module.get(), "", context.func); falseBB = BasicBlock::create(module, "", context.func);
builder->create_cond_br(cond_val, trueBB, falseBB); builder->create_cond_br(cond_val, trueBB, falseBB);
} }
builder->set_insert_point(trueBB); builder->set_insert_point(trueBB);
...@@ -254,16 +256,16 @@ Value* CminusfBuilder::visit(ASTSelectionStmt &node) { ...@@ -254,16 +256,16 @@ Value* CminusfBuilder::visit(ASTSelectionStmt &node) {
} }
Value* CminusfBuilder::visit(ASTIterationStmt &node) { Value* CminusfBuilder::visit(ASTIterationStmt &node) {
auto *exprBB = BasicBlock::create(module.get(), "", context.func); auto *exprBB = BasicBlock::create(module, "", context.func);
if (not builder->get_insert_block()->is_terminated()) { if (not builder->get_insert_block()->is_terminated()) {
builder->create_br(exprBB); builder->create_br(exprBB);
} }
builder->set_insert_point(exprBB); builder->set_insert_point(exprBB);
auto *ret_val = node.expression->accept(*this); auto *ret_val = node.expression->accept(*this);
auto *trueBB = BasicBlock::create(module.get(), "", context.func); auto *trueBB = BasicBlock::create(module, "", context.func);
auto *contBB = BasicBlock::create(module.get(), "", context.func); auto *contBB = BasicBlock::create(module, "", context.func);
Value *cond_val = nullptr; Value *cond_val;
if (ret_val->get_type()->is_integer_type()) { if (ret_val->get_type()->is_integer_type()) {
cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0)); cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
} else { } else {
...@@ -312,7 +314,7 @@ Value* CminusfBuilder::visit(ASTVar &node) { ...@@ -312,7 +314,7 @@ Value* CminusfBuilder::visit(ASTVar &node) {
var->get_type()->get_pointer_element_type()->is_pointer_type(); var->get_type()->get_pointer_element_type()->is_pointer_type();
bool should_return_lvalue = context.require_lvalue; bool should_return_lvalue = context.require_lvalue;
context.require_lvalue = false; context.require_lvalue = false;
Value *ret_val = nullptr; Value *ret_val;
if (node.expression == nullptr) { if (node.expression == nullptr) {
if (should_return_lvalue) { if (should_return_lvalue) {
ret_val = var; ret_val = var;
...@@ -327,14 +329,13 @@ Value* CminusfBuilder::visit(ASTVar &node) { ...@@ -327,14 +329,13 @@ Value* CminusfBuilder::visit(ASTVar &node) {
} }
} else { } else {
auto *val = node.expression->accept(*this); auto *val = node.expression->accept(*this);
Value *is_neg = nullptr; auto *exceptBB = BasicBlock::create(module, "", context.func);
auto *exceptBB = BasicBlock::create(module.get(), "", context.func); auto *contBB = BasicBlock::create(module, "", context.func);
auto *contBB = BasicBlock::create(module.get(), "", context.func);
if (val->get_type()->is_float_type()) { if (val->get_type()->is_float_type()) {
val = builder->create_fptosi(val, INT32_T); val = builder->create_fptosi(val, INT32_T);
} }
is_neg = builder->create_icmp_lt(val, CONST_INT(0)); Value* is_neg = builder->create_icmp_lt(val, CONST_INT(0));
builder->create_cond_br(is_neg, exceptBB, contBB); builder->create_cond_br(is_neg, exceptBB, contBB);
builder->set_insert_point(exceptBB); builder->set_insert_point(exceptBB);
...@@ -349,7 +350,7 @@ Value* CminusfBuilder::visit(ASTVar &node) { ...@@ -349,7 +350,7 @@ Value* CminusfBuilder::visit(ASTVar &node) {
} }
builder->set_insert_point(contBB); builder->set_insert_point(contBB);
Value *tmp_ptr = nullptr; Value *tmp_ptr;
if (is_int || is_float) { if (is_int || is_float) {
tmp_ptr = builder->create_gep(var, {val}); tmp_ptr = builder->create_gep(var, {val});
} else if (is_ptr) { } else if (is_ptr) {
...@@ -391,7 +392,7 @@ Value* CminusfBuilder::visit(ASTSimpleExpression &node) { ...@@ -391,7 +392,7 @@ Value* CminusfBuilder::visit(ASTSimpleExpression &node) {
auto *l_val = node.additive_expression_l->accept(*this); auto *l_val = node.additive_expression_l->accept(*this);
auto *r_val = node.additive_expression_r->accept(*this); auto *r_val = node.additive_expression_r->accept(*this);
bool is_int = promote(&*builder, &l_val, &r_val); bool is_int = promote(builder, &l_val, &r_val);
Value *cmp = nullptr; Value *cmp = nullptr;
switch (node.op) { switch (node.op) {
case OP_LT: case OP_LT:
...@@ -448,7 +449,7 @@ Value* CminusfBuilder::visit(ASTAdditiveExpression &node) { ...@@ -448,7 +449,7 @@ Value* CminusfBuilder::visit(ASTAdditiveExpression &node) {
auto *l_val = node.additive_expression->accept(*this); auto *l_val = node.additive_expression->accept(*this);
auto *r_val = node.term->accept(*this); auto *r_val = node.term->accept(*this);
bool is_int = promote(&*builder, &l_val, &r_val); bool is_int = promote(builder, &l_val, &r_val);
Value *ret_val = nullptr; Value *ret_val = nullptr;
switch (node.op) { switch (node.op) {
case OP_PLUS: case OP_PLUS:
...@@ -476,7 +477,7 @@ Value* CminusfBuilder::visit(ASTTerm &node) { ...@@ -476,7 +477,7 @@ Value* CminusfBuilder::visit(ASTTerm &node) {
auto *l_val = node.term->accept(*this); auto *l_val = node.term->accept(*this);
auto *r_val = node.factor->accept(*this); auto *r_val = node.factor->accept(*this);
bool is_int = promote(&*builder, &l_val, &r_val); bool is_int = promote(builder, &l_val, &r_val);
Value *ret_val = nullptr; Value *ret_val = nullptr;
switch (node.op) { switch (node.op) {
...@@ -513,7 +514,7 @@ Value* CminusfBuilder::visit(ASTCall &node) { ...@@ -513,7 +514,7 @@ Value* CminusfBuilder::visit(ASTCall &node) {
} }
} }
args.push_back(arg_val); args.push_back(arg_val);
param_type++; ++param_type;
} }
return builder->create_call(static_cast<Function *>(func), args); return builder->create_call(static_cast<Function *>(func), args);
......
...@@ -46,7 +46,7 @@ int main(int argc, char **argv) { ...@@ -46,7 +46,7 @@ int main(int argc, char **argv) {
ASTPrinter printer; ASTPrinter printer;
ast.run_visitor(printer); ast.run_visitor(printer);
} else { } else {
std::unique_ptr<Module> m; Module* m;
CminusfBuilder builder; CminusfBuilder builder;
ast.run_visitor(builder); ast.run_visitor(builder);
m = builder.getModule(); m = builder.getModule();
...@@ -58,11 +58,13 @@ int main(int argc, char **argv) { ...@@ -58,11 +58,13 @@ int main(int argc, char **argv) {
output_stream << "source_filename = " << abs_path << "\n\n"; output_stream << "source_filename = " << abs_path << "\n\n";
output_stream << m->print(); output_stream << m->print();
} else if (config.emitasm) { } else if (config.emitasm) {
CodeGen codegen(m.get()); CodeGen codegen(m);
codegen.run(); codegen.run();
output_stream << codegen.print(); output_stream << codegen.print();
} }
delete m;
// TODO: lab4 (IR optimization or codegen) // TODO: lab4 (IR optimization or codegen)
} }
......
...@@ -15,16 +15,16 @@ void CodeGen::allocate() { ...@@ -15,16 +15,16 @@ void CodeGen::allocate() {
// 为指令结果分配栈空间 // 为指令结果分配栈空间
for (auto bb : context.func->get_basic_blocks()) { for (auto bb : context.func->get_basic_blocks()) {
for (auto& instr : bb->get_instructions()) { for (auto instr : bb->get_instructions()) {
// 每个非 void 的定值都分配栈空间 // 每个非 void 的定值都分配栈空间
if (not instr.is_void()) { if (not instr->is_void()) {
auto size = instr.get_type()->get_size(); auto size = instr->get_type()->get_size();
offset = offset + size; offset = offset + size;
context.offset_map[&instr] = -static_cast<int>(offset); context.offset_map[instr] = -static_cast<int>(offset);
} }
// alloca 的副作用:分配额外空间 // alloca 的副作用:分配额外空间
if (instr.is_alloca()) { if (instr->is_alloca()) {
auto *alloca_inst = static_cast<AllocaInst *>(&instr); auto *alloca_inst = dynamic_cast<AllocaInst *>(instr);
auto alloc_size = alloca_inst->get_alloca_type()->get_size(); auto alloc_size = alloca_inst->get_alloca_type()->get_size();
offset += alloc_size; offset += alloc_size;
} }
...@@ -36,20 +36,20 @@ void CodeGen::allocate() { ...@@ -36,20 +36,20 @@ void CodeGen::allocate() {
} }
void CodeGen::copy_stmt() { void CodeGen::copy_stmt() {
for (auto &succ : context.bb->get_succ_basic_blocks()) { for (auto succ : context.bb->get_succ_basic_blocks()) {
for (auto &inst : succ->get_instructions()) { for (auto inst : succ->get_instructions()) {
if (inst.is_phi()) { if (inst->is_phi()) {
// 遍历后继块中 phi 的定值 bb // 遍历后继块中 phi 的定值 bb
for (unsigned i = 1; i < inst.get_operands().size(); i += 2) { for (unsigned i = 1; i < inst->get_operands().size(); i += 2) {
// phi 的定值 bb 是当前翻译块 // phi 的定值 bb 是当前翻译块
if (inst.get_operand(i) == context.bb) { if (inst->get_operand(i) == context.bb) {
auto *lvalue = inst.get_operand(i - 1); auto *lvalue = inst->get_operand(i - 1);
if (lvalue->get_type()->is_float_type()) { if (lvalue->get_type()->is_float_type()) {
load_to_freg(lvalue, FReg::fa(0)); load_to_freg(lvalue, FReg::fa(0));
store_from_freg(&inst, FReg::fa(0)); store_from_freg(inst, FReg::fa(0));
} else { } else {
load_to_greg(lvalue, Reg::a(0)); load_to_greg(lvalue, Reg::a(0));
store_from_greg(&inst, Reg::a(0)); store_from_greg(inst, Reg::a(0));
} }
break; break;
} }
...@@ -373,16 +373,16 @@ void CodeGen::run() { ...@@ -373,16 +373,16 @@ void CodeGen::run() {
append_inst(".text", ASMInstruction::Atrribute); append_inst(".text", ASMInstruction::Atrribute);
append_inst(".section", {".bss", "\"aw\"", "@nobits"}, append_inst(".section", {".bss", "\"aw\"", "@nobits"},
ASMInstruction::Atrribute); ASMInstruction::Atrribute);
for (auto &global : m->get_global_variable()) { for (auto global : m->get_global_variable()) {
auto size = auto size =
global.get_type()->get_pointer_element_type()->get_size(); global->get_type()->get_pointer_element_type()->get_size();
append_inst(".globl", {global.get_name()}, append_inst(".globl", {global->get_name()},
ASMInstruction::Atrribute); ASMInstruction::Atrribute);
append_inst(".type", {global.get_name(), "@object"}, append_inst(".type", {global->get_name(), "@object"},
ASMInstruction::Atrribute); ASMInstruction::Atrribute);
append_inst(".size", {global.get_name(), std::to_string(size)}, append_inst(".size", {global->get_name(), std::to_string(size)},
ASMInstruction::Atrribute); ASMInstruction::Atrribute);
append_inst(global.get_name(), ASMInstruction::Label); append_inst(global->get_name(), ASMInstruction::Label);
append_inst(".space", {std::to_string(size)}, append_inst(".space", {std::to_string(size)},
ASMInstruction::Atrribute); ASMInstruction::Atrribute);
} }
...@@ -390,31 +390,31 @@ void CodeGen::run() { ...@@ -390,31 +390,31 @@ void CodeGen::run() {
// 函数代码段 // 函数代码段
output.emplace_back(".text", ASMInstruction::Atrribute); output.emplace_back(".text", ASMInstruction::Atrribute);
for (auto &func : m->get_functions()) { for (auto func : m->get_functions()) {
if (not func.is_declaration()) { if (not func->is_declaration()) {
// 更新 context // 更新 context
context.clear(); context.clear();
context.func = &func; context.func = func;
// 函数信息 // 函数信息
append_inst(".globl", {func.get_name()}, ASMInstruction::Atrribute); append_inst(".globl", {func->get_name()}, ASMInstruction::Atrribute);
append_inst(".type", {func.get_name(), "@function"}, append_inst(".type", {func->get_name(), "@function"},
ASMInstruction::Atrribute); ASMInstruction::Atrribute);
append_inst(func.get_name(), ASMInstruction::Label); append_inst(func->get_name(), ASMInstruction::Label);
// 分配函数栈帧 // 分配函数栈帧
allocate(); allocate();
// 生成 prologue // 生成 prologue
gen_prologue(); gen_prologue();
for (auto bb : func.get_basic_blocks()) { for (auto bb : func->get_basic_blocks()) {
context.bb = bb; context.bb = bb;
append_inst(label_name(context.bb), ASMInstruction::Label); append_inst(label_name(context.bb), ASMInstruction::Label);
for (auto &instr : bb->get_instructions()) { for (auto instr : bb->get_instructions()) {
// For debug // For debug
append_inst(instr.print(), ASMInstruction::Comment); append_inst(instr->print(), ASMInstruction::Comment);
context.inst = &instr; // 更新 context context.inst = instr; // 更新 context
switch (instr.get_instr_type()) { switch (instr->get_instr_type()) {
case Instruction::ret: case Instruction::ret:
gen_ret(); gen_ret();
break; break;
......
...@@ -396,11 +396,13 @@ Value* ASTCall::accept(ASTVisitor &visitor) { return visitor.visit(*this); } ...@@ -396,11 +396,13 @@ Value* ASTCall::accept(ASTVisitor &visitor) { return visitor.visit(*this); }
#define _DEBUG_PRINT_N_(N) \ #define _DEBUG_PRINT_N_(N) \
{ std::cout << std::string(N, '-'); } { std::cout << std::string(N, '-'); }
ASTVisitor::~ASTVisitor() = default;
Value* ASTPrinter::visit(ASTProgram &node) { Value* ASTPrinter::visit(ASTProgram &node) {
_DEBUG_PRINT_N_(depth); _DEBUG_PRINT_N_(depth);
std::cout << "program" << std::endl; std::cout << "program" << std::endl;
add_depth(); add_depth();
for (auto decl : node.declarations) { for (auto &decl : node.declarations) {
decl->accept(*this); decl->accept(*this);
} }
remove_depth(); remove_depth();
...@@ -437,7 +439,7 @@ Value* ASTPrinter::visit(ASTFunDeclaration &node) { ...@@ -437,7 +439,7 @@ Value* ASTPrinter::visit(ASTFunDeclaration &node) {
_DEBUG_PRINT_N_(depth); _DEBUG_PRINT_N_(depth);
std::cout << "fun-declaration: " << node.id << std::endl; std::cout << "fun-declaration: " << node.id << std::endl;
add_depth(); add_depth();
for (auto param : node.params) { for (auto &param : node.params) {
param->accept(*this); param->accept(*this);
} }
...@@ -459,11 +461,11 @@ Value* ASTPrinter::visit(ASTCompoundStmt &node) { ...@@ -459,11 +461,11 @@ Value* ASTPrinter::visit(ASTCompoundStmt &node) {
_DEBUG_PRINT_N_(depth); _DEBUG_PRINT_N_(depth);
std::cout << "compound-stmt" << std::endl; std::cout << "compound-stmt" << std::endl;
add_depth(); add_depth();
for (auto decl : node.local_declarations) { for (auto &decl : node.local_declarations) {
decl->accept(*this); decl->accept(*this);
} }
for (auto stmt : node.statement_list) { for (auto &stmt : node.statement_list) {
stmt->accept(*this); stmt->accept(*this);
} }
remove_depth(); remove_depth();
...@@ -625,7 +627,7 @@ Value* ASTPrinter::visit(ASTCall &node) { ...@@ -625,7 +627,7 @@ Value* ASTPrinter::visit(ASTCall &node) {
_DEBUG_PRINT_N_(depth); _DEBUG_PRINT_N_(depth);
std::cout << "call: " << node.id << "()" << std::endl; std::cout << "call: " << node.id << "()" << std::endl;
add_depth(); add_depth();
for (auto arg : node.args) { for (auto &arg : node.args) {
arg->accept(*this); arg->accept(*this);
} }
remove_depth(); remove_depth();
......
...@@ -13,14 +13,14 @@ BasicBlock::BasicBlock(Module *m, const std::string &name = "", ...@@ -13,14 +13,14 @@ BasicBlock::BasicBlock(Module *m, const std::string &name = "",
parent_->add_basic_block(this); parent_->add_basic_block(this);
} }
Module *BasicBlock::get_module() { return get_parent()->get_parent(); } Module *BasicBlock::get_module() const { return get_parent()->get_parent(); }
void BasicBlock::erase_from_parent() { this->get_parent()->remove(this); } void BasicBlock::erase_from_parent() { this->get_parent()->remove(this); }
bool BasicBlock::is_terminated() const { bool BasicBlock::is_terminated() const {
if (instr_list_.empty()) { if (instr_list_.empty()) {
return false; return false;
} }
switch (instr_list_.back().get_instr_type()) { switch (instr_list_.back()->get_instr_type()) {
case Instruction::ret: case Instruction::ret:
case Instruction::br: case Instruction::br:
return true; return true;
...@@ -29,17 +29,49 @@ bool BasicBlock::is_terminated() const { ...@@ -29,17 +29,49 @@ bool BasicBlock::is_terminated() const {
} }
} }
Instruction *BasicBlock::get_terminator() { Instruction *BasicBlock::get_terminator() const
{
assert(is_terminated() && assert(is_terminated() &&
"Trying to get terminator from an bb which is not terminated"); "Trying to get terminator from an bb which is not terminated");
return &instr_list_.back(); return instr_list_.back();
} }
void BasicBlock::add_instruction(Instruction *instr) { void BasicBlock::add_instruction(Instruction *instr) {
if (instr->is_alloca() || instr->is_phi())
{
auto it = instr_list_.begin();
for (; it != instr_list_.end() && ((*it)->is_alloca() || (*it)->is_phi()); ++it){}
instr_list_.emplace(it, instr);
return;
}
assert(not is_terminated() && "Inserting instruction to terminated bb"); assert(not is_terminated() && "Inserting instruction to terminated bb");
instr_list_.push_back(instr); instr_list_.push_back(instr);
} }
void BasicBlock::add_instr_begin(Instruction* instr)
{
if (instr->is_alloca() || instr->is_phi())
instr_list_.push_front(instr);
else
{
auto it = instr_list_.begin();
for (; it != instr_list_.end() && ((*it)->is_alloca() || (*it)->is_phi()); ++it) {}
instr_list_.emplace(it, instr);
}
}
void BasicBlock::add_instr_before_terminator(Instruction* instr) {
if (instr->is_alloca() || instr->is_phi())
{
auto it = instr_list_.begin();
for (; it != instr_list_.end() && ((*it)->is_alloca() || (*it)->is_phi()); ++it) {}
instr_list_.emplace(it, instr);
return;
}
if (!is_terminated()) instr_list_.emplace_back(instr);
else instr_list_.insert(std::prev(instr_list_.end()), instr);
}
std::string BasicBlock::print() { std::string BasicBlock::print() {
std::string bb_ir; std::string bb_ir;
bb_ir += this->get_name(); bb_ir += this->get_name();
...@@ -61,16 +93,22 @@ std::string BasicBlock::print() { ...@@ -61,16 +93,22 @@ std::string BasicBlock::print() {
bb_ir += "; Error: Block without parent!"; bb_ir += "; Error: Block without parent!";
} }
bb_ir += "\n"; bb_ir += "\n";
for (auto &instr : this->get_instructions()) { for (auto instr : this->get_instructions()) {
bb_ir += " "; bb_ir += " ";
bb_ir += instr.print(); bb_ir += instr->print();
bb_ir += "\n"; bb_ir += "\n";
} }
return bb_ir; return bb_ir;
} }
BasicBlock* BasicBlock::get_entry_block_of_same_function(){ BasicBlock::~BasicBlock()
assert((not (parent_ == nullptr)) && "bb have no parent function"); {
for (auto inst : instr_list_) delete inst;
}
BasicBlock* BasicBlock::get_entry_block_of_same_function() const
{
assert(parent_ != nullptr && "bb have no parent function");
return parent_->get_entry_block(); return parent_->get_entry_block();
} }
...@@ -14,5 +14,4 @@ add_library( ...@@ -14,5 +14,4 @@ add_library(
target_link_libraries( target_link_libraries(
IR_lib IR_lib
LLVMSupport
) )
#include "Constant.hpp" #include "Constant.hpp"
#include <cstring>
#include "Module.hpp" #include "Module.hpp"
#include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <unordered_map> #include <unordered_map>
...@@ -27,6 +29,8 @@ static std::unordered_map<std::pair<float, Module *>, ...@@ -27,6 +29,8 @@ static std::unordered_map<std::pair<float, Module *>,
cached_float; cached_float;
static std::unordered_map<Type *, std::unique_ptr<ConstantZero>> cached_zero; static std::unordered_map<Type *, std::unique_ptr<ConstantZero>> cached_zero;
Constant::~Constant() = default;
ConstantInt *ConstantInt::get(int val, Module *m) { ConstantInt *ConstantInt::get(int val, Module *m) {
if (cached_int.find(std::make_pair(val, m)) != cached_int.end()) if (cached_int.find(std::make_pair(val, m)) != cached_int.end())
return cached_int[std::make_pair(val, m)].get(); return cached_int[std::make_pair(val, m)].get();
...@@ -62,7 +66,8 @@ ConstantArray::ConstantArray(ArrayType *ty, const std::vector<Constant *> &val) ...@@ -62,7 +66,8 @@ ConstantArray::ConstantArray(ArrayType *ty, const std::vector<Constant *> &val)
this->const_array.assign(val.begin(), val.end()); this->const_array.assign(val.begin(), val.end());
} }
Constant *ConstantArray::get_element_value(int index) { Constant *ConstantArray::get_element_value(int index) const
{
return this->const_array[index]; return this->const_array[index];
} }
...@@ -76,7 +81,7 @@ std::string ConstantArray::print() { ...@@ -76,7 +81,7 @@ std::string ConstantArray::print() {
const_ir += this->get_type()->print(); const_ir += this->get_type()->print();
const_ir += " "; const_ir += " ";
const_ir += "["; const_ir += "[";
for (unsigned i = 0; i < this->get_size_of_array(); i++) { for (int i = 0; i < this->get_size_of_array(); i++) {
Constant *element = get_element_value(i); Constant *element = get_element_value(i);
if (!dynamic_cast<ConstantArray *>(get_element_value(i))) { if (!dynamic_cast<ConstantArray *>(get_element_value(i))) {
const_ir += element->get_type()->print(); const_ir += element->get_type()->print();
...@@ -102,7 +107,9 @@ std::string ConstantFP::print() { ...@@ -102,7 +107,9 @@ std::string ConstantFP::print() {
std::stringstream fp_ir_ss; std::stringstream fp_ir_ss;
std::string fp_ir; std::string fp_ir;
double val = this->get_value(); double val = this->get_value();
fp_ir_ss << "0x" << std::hex << *(uint64_t *)&val << std::endl; uint64_t uval;
memcpy(&uval, &val, sizeof(double));
fp_ir_ss << "0x" << std::hex << uval << '\n';
fp_ir_ss >> fp_ir; fp_ir_ss >> fp_ir;
return fp_ir; return fp_ir;
} }
......
...@@ -15,13 +15,19 @@ Function::Function(FunctionType *ty, const std::string &name, Module *parent) ...@@ -15,13 +15,19 @@ Function::Function(FunctionType *ty, const std::string &name, Module *parent)
arguments_.emplace_back(ty->get_param_type(i), "", this, i); arguments_.emplace_back(ty->get_param_type(i), "", this, i);
} }
} }
Function::~Function()
{
for (auto bb : basic_blocks_) delete bb;
}
Function *Function::create(FunctionType *ty, const std::string &name, Function *Function::create(FunctionType *ty, const std::string &name,
Module *parent) { Module *parent) {
return new Function(ty, name, parent); return new Function(ty, name, parent);
} }
FunctionType *Function::get_function_type() const { FunctionType *Function::get_function_type() const {
return static_cast<FunctionType *>(get_type()); return dynamic_cast<FunctionType *>(get_type());
} }
Type *Function::get_return_type() const { Type *Function::get_return_type() const {
...@@ -32,7 +38,7 @@ unsigned Function::get_num_of_args() const { ...@@ -32,7 +38,7 @@ unsigned Function::get_num_of_args() const {
return get_function_type()->get_num_of_args(); return get_function_type()->get_num_of_args();
} }
unsigned Function::get_num_basic_blocks() const { return basic_blocks_.size(); } unsigned Function::get_num_basic_blocks() const { return static_cast<unsigned>(basic_blocks_.size()); }
Module *Function::get_parent() const { return parent_; } Module *Function::get_parent() const { return parent_; }
...@@ -65,11 +71,11 @@ void Function::set_instr_name() { ...@@ -65,11 +71,11 @@ void Function::set_instr_name() {
seq.insert({bb, seq_num}); seq.insert({bb, seq_num});
} }
} }
for (auto &instr : bb->get_instructions()) { for (auto instr : bb->get_instructions()) {
if (!instr.is_void() && seq.find(&instr) == seq.end()) { if (!instr->is_void() && seq.find(instr) == seq.end()) {
auto seq_num = seq.size() + seq_cnt_; auto seq_num = seq.size() + seq_cnt_;
if (instr.set_name("op" + std::to_string(seq_num))) { if (instr->set_name("op" + std::to_string(seq_num))) {
seq.insert({&instr, seq_num}); seq.insert({instr, seq_num});
} }
} }
} }
...@@ -96,7 +102,7 @@ std::string Function::print() { ...@@ -96,7 +102,7 @@ std::string Function::print() {
for (unsigned i = 0; i < this->get_num_of_args(); i++) { for (unsigned i = 0; i < this->get_num_of_args(); i++) {
if (i) if (i)
func_ir += ", "; func_ir += ", ";
func_ir += static_cast<FunctionType *>(this->get_type()) func_ir += dynamic_cast<FunctionType *>(this->get_type())
->get_param_type(i) ->get_param_type(i)
->print(); ->print();
} }
...@@ -115,7 +121,7 @@ std::string Function::print() { ...@@ -115,7 +121,7 @@ std::string Function::print() {
} else { } else {
func_ir += " {"; func_ir += " {";
func_ir += "\n"; func_ir += "\n";
for (auto &bb : this->get_basic_blocks()) { for (auto bb : this->get_basic_blocks()) {
func_ir += bb->print(); func_ir += bb->print();
} }
func_ir += "}"; func_ir += "}";
...@@ -133,7 +139,7 @@ std::string Argument::print() { ...@@ -133,7 +139,7 @@ std::string Argument::print() {
} }
void Function::check_for_block_relation_error() void Function::check_for_block_relation_error() const
{ {
// 检查函数的基本块表是否包含所有且仅包含 get_parent 是本函数的基本块 // 检查函数的基本块表是否包含所有且仅包含 get_parent 是本函数的基本块
std::unordered_set<BasicBlock*> bbs; std::unordered_set<BasicBlock*> bbs;
...@@ -172,7 +178,7 @@ void Function::check_for_block_relation_error() ...@@ -172,7 +178,7 @@ void Function::check_for_block_relation_error()
for (auto bb : basic_blocks_) for (auto bb : basic_blocks_)
{ {
assert((!bb->get_instructions().empty()) && "发现了空基本块"); assert((!bb->get_instructions().empty()) && "发现了空基本块");
auto b = &bb->get_instructions().back(); auto b = bb->get_instructions().back();
assert((b->is_br() || b->is_ret()) && "发现了无 terminator 基本块"); assert((b->is_br() || b->is_ret()) && "发现了无 terminator 基本块");
assert((b->is_br() || bb->get_succ_basic_blocks().empty()) && "某基本块末尾是 ret 指令但是后继块表不是空的, 或者末尾是 br 但后继表为空"); assert((b->is_br() || bb->get_succ_basic_blocks().empty()) && "某基本块末尾是 ret 指令但是后继块表不是空的, 或者末尾是 br 但后继表为空");
} }
...@@ -184,12 +190,14 @@ void Function::check_for_block_relation_error() ...@@ -184,12 +190,14 @@ void Function::check_for_block_relation_error()
std::unordered_set<BasicBlock*> br_get; std::unordered_set<BasicBlock*> br_get;
for (auto suc : bb->get_succ_basic_blocks()) for (auto suc : bb->get_succ_basic_blocks())
suc_table.emplace(suc); suc_table.emplace(suc);
auto& ops = bb->get_instructions().back().get_operands(); auto& ops = bb->get_instructions().back()->get_operands();
for (auto i : ops) for (auto i : ops)
{ {
auto bb2 = dynamic_cast<BasicBlock*>(i); auto bb2 = dynamic_cast<BasicBlock*>(i);
if(bb2 != nullptr) br_get.emplace(bb2); if(bb2 != nullptr) br_get.emplace(bb2);
} }
// 这三个检查保证有问题会报错,但不保证每次报错的基本块都相同
// 例如 A, B 两个基本块都存在问题,可能有时 A 报错有时 B 报错
for(auto i : suc_table) for(auto i : suc_table)
assert(br_get.count(i) && "基本块 A 的后继块有 B,但 B 并未在 A 的 br 指令中出现"); assert(br_get.count(i) && "基本块 A 的后继块有 B,但 B 并未在 A 的 br 指令中出现");
for(auto i : br_get) for(auto i : br_get)
...@@ -207,7 +215,7 @@ void Function::check_for_block_relation_error() ...@@ -207,7 +215,7 @@ void Function::check_for_block_relation_error()
} }
} }
// 检查基本块前驱后继关系是否是与 branch 指令对应 // 检查基本块前驱后继关系是否是与 branch 指令对应
for (auto& bb : basic_blocks_) for (auto bb : basic_blocks_)
{ {
for (auto pre : bb->get_pre_basic_blocks()) for (auto pre : bb->get_pre_basic_blocks())
{ {
...@@ -225,11 +233,11 @@ void Function::check_for_block_relation_error() ...@@ -225,11 +233,11 @@ void Function::check_for_block_relation_error()
} }
} }
// 检查指令 parent 设置 // 检查指令 parent 设置
for (auto& bb : basic_blocks_) for (auto bb : basic_blocks_)
{ {
for (auto& inst : bb->get_instructions()) for (auto inst : bb->get_instructions())
{ {
assert ((inst.get_parent() == bb) && "基本块 A 指令表包含指令 b, 但是 b 的 get_parent 函数不返回 A"); assert ((inst->get_parent() == bb) && "基本块 A 指令表包含指令 b, 但是 b 的 get_parent 函数不返回 A");
} }
} }
} }
#include "GlobalVariable.hpp" #include "GlobalVariable.hpp"
#include "IRprinter.hpp" #include "IRprinter.hpp"
#include "Module.hpp"
GlobalVariable::GlobalVariable(std::string name, Module *m, Type *ty, GlobalVariable::GlobalVariable(const std::string& name, Module *m, Type *ty,
bool is_const, Constant *init) bool is_const, Constant *init)
: User(ty, name), is_const_(is_const), init_val_(init) { : User(ty, name), is_const_(is_const), init_val_(init) {
m->add_global_variable(this); m->add_global_variable(this);
...@@ -10,7 +11,7 @@ GlobalVariable::GlobalVariable(std::string name, Module *m, Type *ty, ...@@ -10,7 +11,7 @@ GlobalVariable::GlobalVariable(std::string name, Module *m, Type *ty,
} }
} // global操作数为initval } // global操作数为initval
GlobalVariable *GlobalVariable::create(std::string name, Module *m, Type *ty, GlobalVariable *GlobalVariable::create(const std::string& name, Module *m, Type *ty,
bool is_const, bool is_const,
Constant *init = nullptr) { Constant *init = nullptr) {
return new GlobalVariable(name, m, PointerType::get(ty), is_const, init); return new GlobalVariable(name, m, PointerType::get(ty), is_const, init);
......
#include "IRprinter.hpp" #include "IRprinter.hpp"
#include "Instruction.hpp" #include "Instruction.hpp"
#include <cassert> #include <cassert>
#include <type_traits> #include "Constant.hpp"
#include "Function.hpp"
#include "GlobalVariable.hpp"
std::string print_as_op(Value *v, bool print_ty) { std::string print_as_op(Value *v, bool print_ty) {
std::string op_ir; std::string op_ir;
...@@ -10,9 +12,7 @@ std::string print_as_op(Value *v, bool print_ty) { ...@@ -10,9 +12,7 @@ std::string print_as_op(Value *v, bool print_ty) {
op_ir += " "; op_ir += " ";
} }
if (dynamic_cast<GlobalVariable *>(v)) { if (dynamic_cast<GlobalVariable *>(v) || dynamic_cast<Function*>(v)) {
op_ir += "@" + v->get_name();
} else if (dynamic_cast<Function *>(v)) {
op_ir += "@" + v->get_name(); op_ir += "@" + v->get_name();
} else if (dynamic_cast<Constant *>(v)) { } else if (dynamic_cast<Constant *>(v)) {
op_ir += v->print(); op_ir += v->print();
...@@ -91,7 +91,8 @@ std::string print_instr_op_name(Instruction::OpID id) { ...@@ -91,7 +91,8 @@ std::string print_instr_op_name(Instruction::OpID id) {
assert(false && "Must be bug"); assert(false && "Must be bug");
} }
template <class BinInst> std::string print_binary_inst(const BinInst &inst) { template <class BinInst>
static std::string print_binary_inst(const BinInst &inst) {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += inst.get_name(); instr_ir += inst.get_name();
...@@ -112,7 +113,8 @@ template <class BinInst> std::string print_binary_inst(const BinInst &inst) { ...@@ -112,7 +113,8 @@ template <class BinInst> std::string print_binary_inst(const BinInst &inst) {
std::string IBinaryInst::print() { return print_binary_inst(*this); } std::string IBinaryInst::print() { return print_binary_inst(*this); }
std::string FBinaryInst::print() { return print_binary_inst(*this); } std::string FBinaryInst::print() { return print_binary_inst(*this); }
template <class CMP> std::string print_cmp_inst(const CMP &inst) { template <class CMP>
static std::string print_cmp_inst(const CMP &inst) {
std::string cmp_type; std::string cmp_type;
if (inst.is_cmp()) if (inst.is_cmp())
cmp_type = "icmp"; cmp_type = "icmp";
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <stdexcept>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -17,8 +16,10 @@ Instruction::Instruction(Type *ty, OpID id, BasicBlock *parent) ...@@ -17,8 +16,10 @@ Instruction::Instruction(Type *ty, OpID id, BasicBlock *parent)
parent->add_instruction(this); parent->add_instruction(this);
} }
Function *Instruction::get_function() { return parent_->get_parent(); } Instruction::~Instruction() = default;
Module *Instruction::get_module() { return parent_->get_module(); }
Function *Instruction::get_function() const { return parent_->get_parent(); }
Module *Instruction::get_module() const { return parent_->get_module(); }
std::string Instruction::get_instr_op_name() const { std::string Instruction::get_instr_op_name() const {
return print_instr_op_name(op_id_); return print_instr_op_name(op_id_);
...@@ -122,12 +123,12 @@ FCmpInst *FCmpInst::create_fne(Value *v1, Value *v2, BasicBlock *bb) { ...@@ -122,12 +123,12 @@ FCmpInst *FCmpInst::create_fne(Value *v1, Value *v2, BasicBlock *bb) {
return create(fne, v1, v2, bb); return create(fne, v1, v2, bb);
} }
CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb) CallInst::CallInst(Function *func, const std::vector<Value *>& args, BasicBlock *bb)
: BaseInst<CallInst>(func->get_return_type(), call, bb) { : BaseInst<CallInst>(func->get_return_type(), call, bb) {
assert(func->get_type()->is_function_type() && "Not a function"); assert(func->get_type()->is_function_type() && "Not a function");
assert((func->get_num_of_args() == args.size()) && "Wrong number of args"); assert((func->get_num_of_args() == args.size()) && "Wrong number of args");
add_operand(func); add_operand(func);
auto func_type = static_cast<FunctionType *>(func->get_type()); auto func_type = dynamic_cast<FunctionType *>(func->get_type());
for (unsigned i = 0; i < args.size(); i++) { for (unsigned i = 0; i < args.size(); i++) {
assert(func_type->get_param_type(i) == args[i]->get_type() && assert(func_type->get_param_type(i) == args[i]->get_type() &&
"CallInst: Wrong arg type"); "CallInst: Wrong arg type");
...@@ -135,13 +136,13 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb) ...@@ -135,13 +136,13 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
} }
} }
CallInst *CallInst::create_call(Function *func, std::vector<Value *> args, CallInst *CallInst::create_call(Function *func, const std::vector<Value *>& args,
BasicBlock *bb) { BasicBlock *bb) {
return create(func, args, bb); return create(func, args, bb);
} }
FunctionType *CallInst::get_function_type() const { FunctionType *CallInst::get_function_type() const {
return static_cast<FunctionType *>(get_operand(0)->get_type()); return dynamic_cast<FunctionType *>(get_operand(0)->get_type());
} }
BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false,
...@@ -170,10 +171,10 @@ BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, ...@@ -170,10 +171,10 @@ BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false,
BranchInst::~BranchInst() { BranchInst::~BranchInst() {
std::list<BasicBlock *> succs; std::list<BasicBlock *> succs;
if (is_cond_br()) { if (is_cond_br()) {
succs.push_back(static_cast<BasicBlock *>(get_operand(1))); succs.push_back(dynamic_cast<BasicBlock *>(get_operand(1)));
succs.push_back(static_cast<BasicBlock *>(get_operand(2))); succs.push_back(dynamic_cast<BasicBlock *>(get_operand(2)));
} else { } else {
succs.push_back(static_cast<BasicBlock *>(get_operand(0))); succs.push_back(dynamic_cast<BasicBlock *>(get_operand(0)));
} }
for (auto succ_bb : succs) { for (auto succ_bb : succs) {
if (succ_bb) { if (succ_bb) {
...@@ -214,20 +215,20 @@ ReturnInst *ReturnInst::create_void_ret(BasicBlock *bb) { ...@@ -214,20 +215,20 @@ ReturnInst *ReturnInst::create_void_ret(BasicBlock *bb) {
bool ReturnInst::is_void_ret() const { return get_num_operand() == 0; } bool ReturnInst::is_void_ret() const { return get_num_operand() == 0; }
GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, GetElementPtrInst::GetElementPtrInst(Value *ptr, const std::vector<Value *>& idxs,
BasicBlock *bb) BasicBlock *bb)
: BaseInst<GetElementPtrInst>(PointerType::get(get_element_type(ptr, idxs)), : BaseInst<GetElementPtrInst>(PointerType::get(get_element_type(ptr, idxs)),
getelementptr, bb) { getelementptr, bb) {
add_operand(ptr); add_operand(ptr);
for (unsigned i = 0; i < idxs.size(); i++) { for (auto idx : idxs)
Value *idx = idxs[i]; {
assert(idx->get_type()->is_integer_type() && "Index is not integer"); assert(idx->get_type()->is_integer_type() && "Index is not integer");
add_operand(idx); add_operand(idx);
} }
} }
Type *GetElementPtrInst::get_element_type(Value *ptr, Type *GetElementPtrInst::get_element_type(const Value *ptr,
std::vector<Value *> idxs) { const std::vector<Value *>& idxs) {
assert(ptr->get_type()->is_pointer_type() && assert(ptr->get_type()->is_pointer_type() &&
"GetElementPtrInst ptr is not a pointer"); "GetElementPtrInst ptr is not a pointer");
...@@ -236,14 +237,14 @@ Type *GetElementPtrInst::get_element_type(Value *ptr, ...@@ -236,14 +237,14 @@ Type *GetElementPtrInst::get_element_type(Value *ptr,
"GetElementPtrInst ptr is wrong type" && "GetElementPtrInst ptr is wrong type" &&
(ty->is_array_type() || ty->is_integer_type() || ty->is_float_type())); (ty->is_array_type() || ty->is_integer_type() || ty->is_float_type()));
if (ty->is_array_type()) { if (ty->is_array_type()) {
ArrayType *arr_ty = static_cast<ArrayType *>(ty); ArrayType *arr_ty = dynamic_cast<ArrayType *>(ty);
for (unsigned i = 1; i < idxs.size(); i++) { for (unsigned i = 1; i < idxs.size(); i++) {
ty = arr_ty->get_element_type(); ty = arr_ty->get_element_type();
if (i < idxs.size() - 1) { if (i < idxs.size() - 1) {
assert(ty->is_array_type() && "Index error!"); assert(ty->is_array_type() && "Index error!");
} }
if (ty->is_array_type()) { if (ty->is_array_type()) {
arr_ty = static_cast<ArrayType *>(ty); arr_ty = dynamic_cast<ArrayType *>(ty);
} }
} }
} }
...@@ -255,7 +256,7 @@ Type *GetElementPtrInst::get_element_type() const { ...@@ -255,7 +256,7 @@ Type *GetElementPtrInst::get_element_type() const {
} }
GetElementPtrInst *GetElementPtrInst::create_gep(Value *ptr, GetElementPtrInst *GetElementPtrInst::create_gep(Value *ptr,
std::vector<Value *> idxs, const std::vector<Value *>& idxs,
BasicBlock *bb) { BasicBlock *bb) {
return create(ptr, idxs, bb); return create(ptr, idxs, bb);
} }
...@@ -287,7 +288,7 @@ LoadInst *LoadInst::create_load(Value *ptr, BasicBlock *bb) { ...@@ -287,7 +288,7 @@ LoadInst *LoadInst::create_load(Value *ptr, BasicBlock *bb) {
AllocaInst::AllocaInst(Type *ty, BasicBlock *bb) AllocaInst::AllocaInst(Type *ty, BasicBlock *bb)
: BaseInst<AllocaInst>(PointerType::get(ty), alloca, bb) { : BaseInst<AllocaInst>(PointerType::get(ty), alloca, bb) {
static const std::array allowed_alloc_type = { static constexpr std::array allowed_alloc_type = {
Type::IntegerTyID, Type::FloatTyID, Type::ArrayTyID, Type::PointerTyID}; Type::IntegerTyID, Type::FloatTyID, Type::ArrayTyID, Type::PointerTyID};
assert(std::find(allowed_alloc_type.begin(), allowed_alloc_type.end(), assert(std::find(allowed_alloc_type.begin(), allowed_alloc_type.end(),
ty->get_type_id()) != allowed_alloc_type.end() && ty->get_type_id()) != allowed_alloc_type.end() &&
...@@ -298,26 +299,13 @@ AllocaInst *AllocaInst::create_alloca(Type *ty, BasicBlock *bb) { ...@@ -298,26 +299,13 @@ AllocaInst *AllocaInst::create_alloca(Type *ty, BasicBlock *bb) {
return create(ty, bb); return create(ty, bb);
} }
AllocaInst *AllocaInst::create_alloca_begin(Type *ty, BasicBlock *bb) {
auto ret = create(ty, nullptr);
if(bb != nullptr)
{
ret->set_parent(bb);
if (bb->is_terminated())
bb->add_instr_before_end(ret);
else
bb->add_instruction(ret);
}
return ret;
}
ZextInst::ZextInst(Value *val, Type *ty, BasicBlock *bb) ZextInst::ZextInst(Value *val, Type *ty, BasicBlock *bb)
: BaseInst<ZextInst>(ty, zext, bb) { : BaseInst<ZextInst>(ty, zext, bb) {
assert(val->get_type()->is_integer_type() && assert(val->get_type()->is_integer_type() &&
"ZextInst operand is not integer"); "ZextInst operand is not integer");
assert(ty->is_integer_type() && "ZextInst destination type is not integer"); assert(ty->is_integer_type() && "ZextInst destination type is not integer");
assert((static_cast<IntegerType *>(val->get_type())->get_num_bits() < assert((dynamic_cast<IntegerType *>(val->get_type())->get_num_bits() <
static_cast<IntegerType *>(ty)->get_num_bits()) && dynamic_cast<IntegerType *>(ty)->get_num_bits()) &&
"ZextInst operand bit size is not smaller than destination type bit " "ZextInst operand bit size is not smaller than destination type bit "
"size"); "size");
add_operand(val); add_operand(val);
...@@ -358,8 +346,8 @@ SiToFpInst *SiToFpInst::create_sitofp(Value *val, BasicBlock *bb) { ...@@ -358,8 +346,8 @@ SiToFpInst *SiToFpInst::create_sitofp(Value *val, BasicBlock *bb) {
return create(val, bb->get_module()->get_float_type(), bb); return create(val, bb->get_module()->get_float_type(), bb);
} }
PhiInst::PhiInst(Type *ty, std::vector<Value *> vals, PhiInst::PhiInst(Type *ty, const std::vector<Value *>& vals,
std::vector<BasicBlock *> val_bbs, BasicBlock *bb) const std::vector<BasicBlock *>& val_bbs, BasicBlock *bb)
: BaseInst<PhiInst>(ty, phi) { : BaseInst<PhiInst>(ty, phi) {
assert(vals.size() == val_bbs.size() && "Unmatched vals and bbs"); assert(vals.size() == val_bbs.size() && "Unmatched vals and bbs");
for (unsigned i = 0; i < vals.size(); i++) { for (unsigned i = 0; i < vals.size(); i++) {
...@@ -371,7 +359,7 @@ PhiInst::PhiInst(Type *ty, std::vector<Value *> vals, ...@@ -371,7 +359,7 @@ PhiInst::PhiInst(Type *ty, std::vector<Value *> vals,
} }
PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb, PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb,
std::vector<Value *> vals, const std::vector<Value *>& vals,
std::vector<BasicBlock *> val_bbs) { const std::vector<BasicBlock *>& val_bbs) {
return create(ty, vals, val_bbs, bb); return create(ty, vals, val_bbs, bb);
} }
...@@ -6,61 +6,75 @@ ...@@ -6,61 +6,75 @@
#include <string> #include <string>
Module::Module() { Module::Module() {
void_ty_ = std::make_unique<Type>(Type::VoidTyID, this); void_ty_ = new Type(Type::VoidTyID, this);
label_ty_ = std::make_unique<Type>(Type::LabelTyID, this); label_ty_ = new Type(Type::LabelTyID, this);
int1_ty_ = std::make_unique<IntegerType>(1, this); int1_ty_ = new IntegerType(1, this);
int32_ty_ = std::make_unique<IntegerType>(32, this); int32_ty_ = new IntegerType(32, this);
float32_ty_ = std::make_unique<FloatType>(this); float32_ty_ = new FloatType(this);
} }
Type *Module::get_void_type() { return void_ty_.get(); } Module::~Module()
Type *Module::get_label_type() { return label_ty_.get(); } {
IntegerType *Module::get_int1_type() { return int1_ty_.get(); } delete void_ty_;
IntegerType *Module::get_int32_type() { return int32_ty_.get(); } delete label_ty_;
FloatType *Module::get_float_type() { return float32_ty_.get(); } delete int1_ty_;
delete int32_ty_;
delete float32_ty_;
for (auto& i : pointer_map_) delete i.second;
for (auto& i : array_map_) delete i.second;
for (auto& i : function_map_) delete i.second;
for (auto i : function_list_) delete i;
for (auto i : global_list_) delete i;
}
Type *Module::get_void_type() const { return void_ty_; }
Type *Module::get_label_type() const { return label_ty_; }
IntegerType *Module::get_int1_type() const { return int1_ty_; }
IntegerType *Module::get_int32_type() const { return int32_ty_; }
FloatType *Module::get_float_type() const { return float32_ty_; }
PointerType *Module::get_int32_ptr_type() { PointerType *Module::get_int32_ptr_type() {
return get_pointer_type(int32_ty_.get()); return get_pointer_type(int32_ty_);
} }
PointerType *Module::get_float_ptr_type() { PointerType *Module::get_float_ptr_type() {
return get_pointer_type(float32_ty_.get()); return get_pointer_type(float32_ty_);
} }
PointerType *Module::get_pointer_type(Type *contained) { PointerType *Module::get_pointer_type(Type *contained) {
if (pointer_map_.find(contained) == pointer_map_.end()) { if (pointer_map_.find(contained) == pointer_map_.end()) {
pointer_map_[contained] = std::make_unique<PointerType>(contained); pointer_map_[contained] = new PointerType(contained);
} }
return pointer_map_[contained].get(); return pointer_map_[contained];
} }
ArrayType *Module::get_array_type(Type *contained, unsigned num_elements) { ArrayType *Module::get_array_type(Type *contained, unsigned num_elements) {
if (array_map_.find({contained, num_elements}) == array_map_.end()) { if (array_map_.find({contained, num_elements}) == array_map_.end()) {
array_map_[{contained, num_elements}] = array_map_[{contained, num_elements}] =
std::make_unique<ArrayType>(contained, num_elements); new ArrayType(contained, num_elements);
} }
return array_map_[{contained, num_elements}].get(); return array_map_[{contained, num_elements}];
} }
FunctionType *Module::get_function_type(Type *retty, FunctionType *Module::get_function_type(Type *retty,
std::vector<Type *> &args) { std::vector<Type *> &args) {
if (not function_map_.count({retty, args})) { if (not function_map_.count({retty, args})) {
function_map_[{retty, args}] = function_map_[{retty, args}] =
std::make_unique<FunctionType>(retty, args); new FunctionType(retty, args);
} }
return function_map_[{retty, args}].get(); return function_map_[{retty, args}];
} }
void Module::add_function(Function *f) { function_list_.push_back(f); } void Module::add_function(Function *f) { function_list_.push_back(f); }
llvm::ilist<Function> &Module::get_functions() { return function_list_; } std::list<Function*> &Module::get_functions() { return function_list_; }
void Module::add_global_variable(GlobalVariable *g) { void Module::add_global_variable(GlobalVariable *g) {
global_list_.push_back(g); global_list_.push_back(g);
} }
llvm::ilist<GlobalVariable> &Module::get_global_variable() { std::list<GlobalVariable*> &Module::get_global_variable() {
return global_list_; return global_list_;
} }
void Module::set_print_name() { void Module::set_print_name() {
for (auto &func : this->get_functions()) { for (auto func : this->get_functions()) {
func.set_instr_name(); func->set_instr_name();
} }
return; return;
} }
...@@ -68,12 +82,12 @@ void Module::set_print_name() { ...@@ -68,12 +82,12 @@ void Module::set_print_name() {
std::string Module::print() { std::string Module::print() {
set_print_name(); set_print_name();
std::string module_ir; std::string module_ir;
for (auto &global_val : this->global_list_) { for (auto global_val : this->global_list_) {
module_ir += global_val.print(); module_ir += global_val->print();
module_ir += "\n"; module_ir += "\n";
} }
for (auto &func : this->function_list_) { for (auto func : this->function_list_) {
module_ir += func.print(); module_ir += func->print();
module_ir += "\n"; module_ir += "\n";
} }
return module_ir; return module_ir;
......
...@@ -3,31 +3,32 @@ ...@@ -3,31 +3,32 @@
#include <array> #include <array>
#include <cassert> #include <cassert>
#include <stdexcept>
Type::Type(TypeID tid, Module *m) { Type::Type(TypeID tid, Module *m) {
tid_ = tid; tid_ = tid;
m_ = m; m_ = m;
} }
Type::~Type() = default;
bool Type::is_int1_type() const { bool Type::is_int1_type() const {
return is_integer_type() and return is_integer_type() and
static_cast<const IntegerType *>(this)->get_num_bits() == 1; dynamic_cast<const IntegerType *>(this)->get_num_bits() == 1;
} }
bool Type::is_int32_type() const { bool Type::is_int32_type() const {
return is_integer_type() and return is_integer_type() and
static_cast<const IntegerType *>(this)->get_num_bits() == 32; dynamic_cast<const IntegerType *>(this)->get_num_bits() == 32;
} }
Type *Type::get_pointer_element_type() const { Type *Type::get_pointer_element_type() const {
if (this->is_pointer_type()) if (this->is_pointer_type())
return static_cast<const PointerType *>(this)->get_element_type(); return dynamic_cast<const PointerType *>(this)->get_element_type();
assert(false and "get_pointer_element_type() called on non-pointer type"); assert(false and "get_pointer_element_type() called on non-pointer type");
} }
Type *Type::get_array_element_type() const { Type *Type::get_array_element_type() const {
if (this->is_array_type()) if (this->is_array_type())
return static_cast<const ArrayType *>(this)->get_element_type(); return dynamic_cast<const ArrayType *>(this)->get_element_type();
assert(false and "get_array_element_type() called on non-array type"); assert(false and "get_array_element_type() called on non-array type");
} }
...@@ -42,7 +43,7 @@ unsigned Type::get_size() const { ...@@ -42,7 +43,7 @@ unsigned Type::get_size() const {
assert(false && "Type::get_size(): unexpected int type bits"); assert(false && "Type::get_size(): unexpected int type bits");
} }
case ArrayTyID: { case ArrayTyID: {
auto array_type = static_cast<const ArrayType *>(this); auto array_type = dynamic_cast<const ArrayType *>(this);
auto element_size = array_type->get_element_type()->get_size(); auto element_size = array_type->get_element_type()->get_size();
auto num_elements = array_type->get_num_of_elements(); auto num_elements = array_type->get_num_of_elements();
return element_size * num_elements; return element_size * num_elements;
...@@ -71,18 +72,18 @@ std::string Type::print() const { ...@@ -71,18 +72,18 @@ std::string Type::print() const {
case IntegerTyID: case IntegerTyID:
type_ir += "i"; type_ir += "i";
type_ir += std::to_string( type_ir += std::to_string(
static_cast<const IntegerType *>(this)->get_num_bits()); dynamic_cast<const IntegerType *>(this)->get_num_bits());
break; break;
case FunctionTyID: case FunctionTyID:
type_ir += type_ir +=
static_cast<const FunctionType *>(this)->get_return_type()->print(); dynamic_cast<const FunctionType *>(this)->get_return_type()->print();
type_ir += " ("; type_ir += " (";
for (unsigned i = 0; for (unsigned i = 0;
i < static_cast<const FunctionType *>(this)->get_num_of_args(); i < dynamic_cast<const FunctionType *>(this)->get_num_of_args();
i++) { i++) {
if (i) if (i)
type_ir += ", "; type_ir += ", ";
type_ir += static_cast<const FunctionType *>(this) type_ir += dynamic_cast<const FunctionType *>(this)
->get_param_type(i) ->get_param_type(i)
->print(); ->print();
} }
...@@ -95,10 +96,10 @@ std::string Type::print() const { ...@@ -95,10 +96,10 @@ std::string Type::print() const {
case ArrayTyID: case ArrayTyID:
type_ir += "["; type_ir += "[";
type_ir += std::to_string( type_ir += std::to_string(
static_cast<const ArrayType *>(this)->get_num_of_elements()); dynamic_cast<const ArrayType *>(this)->get_num_of_elements());
type_ir += " x "; type_ir += " x ";
type_ir += type_ir +=
static_cast<const ArrayType *>(this)->get_element_type()->print(); dynamic_cast<const ArrayType *>(this)->get_element_type()->print();
type_ir += "]"; type_ir += "]";
break; break;
case FloatTyID: case FloatTyID:
...@@ -113,9 +114,11 @@ std::string Type::print() const { ...@@ -113,9 +114,11 @@ std::string Type::print() const {
IntegerType::IntegerType(unsigned num_bits, Module *m) IntegerType::IntegerType(unsigned num_bits, Module *m)
: Type(Type::IntegerTyID, m), num_bits_(num_bits) {} : Type(Type::IntegerTyID, m), num_bits_(num_bits) {}
IntegerType::~IntegerType() = default;
unsigned IntegerType::get_num_bits() const { return num_bits_; } unsigned IntegerType::get_num_bits() const { return num_bits_; }
FunctionType::FunctionType(Type *result, std::vector<Type *> params) FunctionType::FunctionType(Type *result, const std::vector<Type *>& params)
: Type(Type::FunctionTyID, nullptr) { : Type(Type::FunctionTyID, nullptr) {
assert(is_valid_return_type(result) && "Invalid return type for function!"); assert(is_valid_return_type(result) && "Invalid return type for function!");
result_ = result; result_ = result;
...@@ -127,11 +130,13 @@ FunctionType::FunctionType(Type *result, std::vector<Type *> params) ...@@ -127,11 +130,13 @@ FunctionType::FunctionType(Type *result, std::vector<Type *> params)
} }
} }
bool FunctionType::is_valid_return_type(Type *ty) { FunctionType::~FunctionType() = default;
bool FunctionType::is_valid_return_type(const Type *ty) {
return ty->is_integer_type() || ty->is_void_type() || ty->is_float_type(); return ty->is_integer_type() || ty->is_void_type() || ty->is_float_type();
} }
bool FunctionType::is_valid_argument_type(Type *ty) { bool FunctionType::is_valid_argument_type(const Type *ty) {
return ty->is_integer_type() || ty->is_pointer_type() || return ty->is_integer_type() || ty->is_pointer_type() ||
ty->is_float_type(); ty->is_float_type();
} }
...@@ -140,7 +145,7 @@ FunctionType *FunctionType::get(Type *result, std::vector<Type *> params) { ...@@ -140,7 +145,7 @@ FunctionType *FunctionType::get(Type *result, std::vector<Type *> params) {
return result->get_module()->get_function_type(result, params); return result->get_module()->get_function_type(result, params);
} }
unsigned FunctionType::get_num_of_args() const { return args_.size(); } unsigned FunctionType::get_num_of_args() const { return static_cast<unsigned>(args_.size()); }
Type *FunctionType::get_param_type(unsigned i) const { return args_[i]; } Type *FunctionType::get_param_type(unsigned i) const { return args_[i]; }
...@@ -154,7 +159,9 @@ ArrayType::ArrayType(Type *contained, unsigned num_elements) ...@@ -154,7 +159,9 @@ ArrayType::ArrayType(Type *contained, unsigned num_elements)
contained_ = contained; contained_ = contained;
} }
bool ArrayType::is_valid_element_type(Type *ty) { ArrayType::~ArrayType() = default;
bool ArrayType::is_valid_element_type(const Type *ty) {
return ty->is_integer_type() || ty->is_array_type() || ty->is_float_type(); return ty->is_integer_type() || ty->is_array_type() || ty->is_float_type();
} }
...@@ -164,7 +171,7 @@ ArrayType *ArrayType::get(Type *contained, unsigned num_elements) { ...@@ -164,7 +171,7 @@ ArrayType *ArrayType::get(Type *contained, unsigned num_elements) {
PointerType::PointerType(Type *contained) PointerType::PointerType(Type *contained)
: Type(Type::PointerTyID, contained->get_module()), contained_(contained) { : Type(Type::PointerTyID, contained->get_module()), contained_(contained) {
static const std::array allowed_elem_type = { static constexpr std::array allowed_elem_type = {
Type::IntegerTyID, Type::FloatTyID, Type::ArrayTyID, Type::PointerTyID}; Type::IntegerTyID, Type::FloatTyID, Type::ArrayTyID, Type::PointerTyID};
auto elem_type_id = contained->get_type_id(); auto elem_type_id = contained->get_type_id();
assert(std::find(allowed_elem_type.begin(), allowed_elem_type.end(), assert(std::find(allowed_elem_type.begin(), allowed_elem_type.end(),
...@@ -172,10 +179,14 @@ PointerType::PointerType(Type *contained) ...@@ -172,10 +179,14 @@ PointerType::PointerType(Type *contained)
"Not allowed type for pointer"); "Not allowed type for pointer");
} }
PointerType::~PointerType() = default;
PointerType *PointerType::get(Type *contained) { PointerType *PointerType::get(Type *contained) {
return contained->get_module()->get_pointer_type(contained); return contained->get_module()->get_pointer_type(contained);
} }
FloatType::FloatType(Module *m) : Type(Type::FloatTyID, m) {} FloatType::FloatType(Module *m) : Type(Type::FloatTyID, m) {}
FloatType::~FloatType() = default;
FloatType *FloatType::get(Module *m) { return m->get_float_type(); } FloatType *FloatType::get(Module *m) { return m->get_float_type(); }
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include <cassert> #include <cassert>
User::~User() { remove_all_operands(); }
void User::set_operand(unsigned i, Value *v) { void User::set_operand(unsigned i, Value *v) {
assert(i < operands_.size() && "set_operand out of index"); assert(i < operands_.size() && "set_operand out of index");
if (operands_[i]) { // old operand if (operands_[i]) { // old operand
...@@ -15,7 +17,7 @@ void User::set_operand(unsigned i, Value *v) { ...@@ -15,7 +17,7 @@ void User::set_operand(unsigned i, Value *v) {
void User::add_operand(Value *v) { void User::add_operand(Value *v) {
assert(v != nullptr && "bad use: add_operand(nullptr)"); assert(v != nullptr && "bad use: add_operand(nullptr)");
v->add_use(this, operands_.size()); v->add_use(this, static_cast<unsigned>(operands_.size()));
operands_.push_back(v); operands_.push_back(v);
} }
......
#include "Value.hpp" #include "Value.hpp"
#include "Type.hpp"
#include "User.hpp" #include "User.hpp"
#include <cassert>
bool Value::set_name(std::string name) { Value::~Value() { replace_all_use_with(nullptr); }
if (name_ == "") {
bool Value::set_name(const std::string& name) {
if (name_.empty()) {
name_ = name; name_ = name;
return true; return true;
} }
...@@ -21,21 +21,22 @@ void Value::remove_use(User *user, unsigned arg_no) { ...@@ -21,21 +21,22 @@ void Value::remove_use(User *user, unsigned arg_no) {
use_list_.remove_if([&](const Use &use) { return use == target_use; }); use_list_.remove_if([&](const Use &use) { return use == target_use; });
} }
void Value::replace_all_use_with(Value *new_val) { void Value::replace_all_use_with(Value *new_val) const
{
if (this == new_val) if (this == new_val)
return; return;
while (use_list_.size()) { while (!use_list_.empty()) {
auto use = use_list_.begin(); auto use = use_list_.begin();
use->val_->set_operand(use->arg_no_, new_val); use->val_->set_operand(use->arg_no_, new_val);
} }
} }
void Value::replace_use_with_if(Value *new_val, void Value::replace_use_with_if(Value *new_val,
std::function<bool(Use *)> should_replace) { const std::function<bool(Use *)>& should_replace) {
if (this == new_val) if (this == new_val)
return; return;
for (auto iter = use_list_.begin(); iter != use_list_.end();) { for (auto iter = use_list_.begin(); iter != use_list_.end();) {
auto use = *iter++; auto &use = *iter++;
if (not should_replace(&use)) if (not should_replace(&use))
continue; continue;
use.val_->set_operand(use.arg_no_, new_val); use.val_->set_operand(use.arg_no_, new_val);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment