Commit f26d91aa authored by 李晓奇's avatar 李晓奇

finish a lot... bugs

parent efbf4233
...@@ -8,12 +8,265 @@ ...@@ -8,12 +8,265 @@
## 实验难点 ## 实验难点
实验中遇到哪些挑战 ### 对于函数`detectEquivalences(G)`
```cpp
detectEquivalences(G)
PIN1 = {} // “1” is the first statement in the program
POUT1 = transferFunction(PIN1)
for each statement s other than the first statement in the program
POUTs = Top
while changes to any POUT occur // i.e. changes in equivalences
for each statement s other than the first statement in the program
if s appears in block b that has two predecessors
then
PINs = Join(POUTs1, POUTs2) // s1 and s2 are last statements in respective predecessors
else
PINs = POUTs3 // s3 the statement just before s
POUTs = transferFunction(PINs) // apply transferFunction on each statement in the block
```
- POUT、PIN是对于语句的的map,但是实现中是对于基本块的map,怎么处理区别?
注意到基本块的pin、pout对应着首条语句的pin和末尾语句的pout,所以没有本质上的不同,主要是怎么补全中间语句的pin、pout。
中间语句的pin、pout应该不会被显式保存,只在对基本块执行转移函数时,逐条计算。所以应该没啥问题。
- 怎么设计top?
尽管算法伪代码中,top是对每一条语句的赋值的,但是任意一条语句的都有所属基本块,只要对基本块的pout标注top、就能保证基本块的pin合乎逻辑,逐条执行转移函数后,每条语句的pin都合法。``
- `while changes to any POUT occur`
这句话如此轻松,但是在实现中出现了问题:这涉及到怎么判断分区相等,做不好会出现死循环。
如果单纯比较index恐怕不行:因为我目前的`transferFunction`会生成一个新的等价类,赋予一个新的编号。
所以我使用对等价类中的`members`做相等的判断,如果两个分区中的等价类都能找到相等的,那么就判断分区相等。
> 但是这样依赖于set的排序,
>
> - 对于分区,这个排序依赖于等价类中index的值,侥幸来说每条指令按顺序执行,即使index不一致(每次转移函数分配最新的index),大小顺序也是一致的?
>
> - 对于等价类,里边是`Value*`指针,这个指针的等价就是语义上的等价,没问题。
改过来后发现还是死循环,DEBUG很久意识到我的分区相等写的有问题:`return members_ == other.members_`,这样会挨个儿比较`members_`指针的值,而不是我预想的比较指针指向的等价类。所以要手动遍历一下。
这个遍历也不是很直观的,因为要做两层解引用:
```cpp
for (; it1 != p1.end(); ++it1, ++it2)
if (not(**it1 == **it2))
return false;
```
`*it`是等价类的指针,`**it`才是等价类。
除此之外,我还忘了`++it2`,淦。
改完这些,终于能跑通第一版本了……
### 等价类设计
> ```cpp
> shared_ptr<Expression>
> GVN::valueExpr(Instruction *instr);
> ```
等价类涉及到针对指令设计、常量折叠、递归计算等,很复杂,需要拎出来单独讨论。
首先注意到传给传进去的参数是指令类型(指针)。需要`transferFunction`处理的指令,一次赋值的右式自然可以由指令类型得到,所以没有毛病。
指令有多种类型,也即lightIR中有多种赋值方式,而非单纯的伪代码中的二元运算。但是即使是已经在算法中讨论过的二元运算,写起来也不容易。
#### 生成等价类的逻辑
首先讨论生成等价类的逻辑:
- 二元运算`y op z`
如果y、z都是常量,那么可以进行常量折叠,得到**一个**`ConstantExpression`类型。
否则返回的是一个`BinaryExpression`类型,这个类型需要左右操作数,都是`Expression`类型,如果操作数是常量,使用`ConstantExpression`创建新的值表达式,否则递归调用`valueExpr`。
> 这里为什么递归调用不会走到未定义?
- phi函数
我们逻辑上把其改变成了copy,`transferFunction`应该维护这一点,所以认为没有phi函数传入。
具体是`transferFunction`函数中的这几句话:
```cpp
GVN::partitions
GVN::transferFunction(Instruction *x, Value *e, partitions pin)
{
auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) &&
"A value must be from an instruction or constant");
// ……
if (e) {
if (e_const)
ve = ConstantExpression::create(e_const);
else
ve = valueExpr(e_instr, &pin);
} else
ve = valueExpr(x, &pin);
}
```
- 比较函数:和二元运算一致。
- 类型转换:添加新的表达式类型`CastExpression`,处理和二元操作类似。
- 其他产生赋值的指令,也会由`transferFunction`传递给`ValueExpr`,这些指令是:`load`、`alloca`和返回值非空的`call`。
这里设计一种值表达式类型:`UniqueExpression`,这个表达式和任何其他都不等价,表现为`equiv`直接返回false。
#### 若干需要梳理的问题
除此之外,关于等价类有几个命题一定要说明,这些是理解GVN的关键。
- 等价类中的ve一定存在吗?
根据等价类的设计,ve一定是存在的,且等价类中`members_`中一定存在元素。
- ve一定唯一吗?
从逻辑上,等价类集合中的元素享有共同的ve产生的结果,ve是唯一的
从实现上,添加
- ve或者vpf能够作为代表吗?
这个问题源于`transferFunction`的一句话:
> if ve or vpf is in a class Ci in POUTs
- 代表元如何选取?什么时候选取?
- 怎么判断等价类相等
根据论坛上这个[例子](http://cscourse.ustc.edu.cn/forum/thread.jsp?forum=91&thread=68),我发现值编号会有非必要的增长,即最终收敛时用到的编号是1、5、6,其中2、3、4都是迭代中用到的但是最后都没有体现。
从这个例子中可以发现,迭代收敛时,是pout中对应分区的`members_`不变,值编号还是会变化的,所以判断等价类相等,应该比较`members_`集合。
- 什么时候给新的值编号?
1. 执行转移函数发现没有可以归属的等价类时
2. 求交巴拉巴拉还没搞清楚那里
- 等价类为空的判定,可以用`members_`做判断吗?
我认为是可以的,存在情况:两个基本块汇合时,交出一个集合,它的members为空,但是ve不空:
```
BB0: y =
     z =
; pout: [{v1,y}, {v2,z}]
BB1: x1 = y + z ; preds = BB0
; pout: [{v1,y}, {v2,z}, {v3, x1, v1+v2}]
BB2: x2 = y + z ; preds = BB0
; pout: [{v1,y}, {v2,z}, {v4, x2, v1+v2}]
BB3: ... ; preds = BB1, BB2
```
进入BB3时,对BB1和BB2的pout中等价类分别取交,执行`Intersect(Ci={v3, x1, v1+v2}, Cj={v4, x2, v1+v2}`,得到的是`{v1+v2}`,`members_`中没有成员,但是ve存在,我认为这种情况也应该判定结果为空集。
因为最终反映到IR上,我们直接关心的都是`members_`中的成员,所以用`members_`的空代表等价类的空我觉得是合适的。
| 类型 | 处理 |
| ---------------- | ------------------------- |
| 二元运算 | 依伪代码 |
| 没有返回值的 | 不处理 |
| phi函数 | 本块中的不管,后继块的做考虑 |
| 函数调用 | 有返回值的考虑等价类,**注意后边纯函数的处理** |
| 产生指针(GEP、alloca) | 先为返回值单独新建等价类 |
| 类型转换及0扩展 | 单独新建等价类 |
| 比较 | 根据op和操作数递归判断,应该和二元运算类似 |
| load | 单独新建等价类 |
### 对于函数`Intersect(Ci, Cj )`
伪代码中,一个等价类是一个集合,求交后,v~k~是自然而然的在或不在新集合中,但是实现中并不这么简单,如何对这样的两个结构体进行逻辑上的取交:
```cpp
struct CongruenceClass {
size_t index_;
Value *leader_;
std::shared_ptr<GVNExpression::Expression> value_expr_;
std::shared_ptr<GVNExpression::PhiExpression> value_phi_;
std::set<Value *> members_;
}
```
- 首先是Intersect的伪代码描述中,下边这句怎么在实现中体现?
> C~k~ does not have value number
解决:index不同的自然就交不出v~i~,相同的时候才有v~i~即value number.
隐患:transferFunction要对v~i~有继承性,即不能每次涉及一个等价类都新开一个value number
- 根据上一条讨论,得出一个观点:在伪代码中,一个等价类集合会出现的内容,在实现中的对应如下:
- 普通变量:`members_`
- 值表达式:`value_expr_`
- phi函数:`value_phi_`
- 值编号v~i~:`index`
因此,伪代码中对集合的求交,对应到实现上,就是对以上四个域求交。
### 对于函数`valuePhiFunc(ve, P)`
- 输入就一个分区,实际上要用到两个前驱的分区,怎么处理?
传入基本块做参数。
- 二元运算的顺序颠倒怎么办?
> 如phi(x1,y1)+phi(x2,y2)和phi(x1,y1)+phi(y2,x2)
>
> 对于第二种,单纯寻找x1+y2和y1+x2肯定不合逻辑
应该不会有这种情况:phi表达式,追根溯源是在join时产生的:
```cpp
GVN::join(const partitions &P1, const partitions &P2);
```
这里的分区是严格按照`Basicblock`的方法`get_pre_basic_blocks()`的返回顺序传入的,所以不会有担心的情况发生。
### 所见非所得
两方面:
- 论文中伪代码对于基本块的关注不是很多,但是我们的实现一定要围绕基本块进行,这个的解决对策说到了,其实和单条语句没有本质差异。
- 我们针对lightir优化,经常深入API中陷入到实现细节,还容易被C++的语言特性绊住,但是应该尝试从直观上理解:我们究竟要实现什么样的效果。这个就得去看ll文件找灵感。然后这还不够,回过头来看代码又呆住了,因为ll语句和那些类型对应不上,这个就是容易忽视的一点:看lightir类型的print函数,这个能帮我们直观地串联起实现细节和表象的ll文件。
## 实验设计 ## 实验设计
### join_helper
### _TOP
### transferFunction(Basicblock)
###
实现思路,相应代码,优化前后的IR对比(举一个例子)并辅以简单说明 实现思路,相应代码,优化前后的IR对比(举一个例子)并辅以简单说明
### 思考题 ### 思考题
1. 请简要分析你的算法复杂度 1. 请简要分析你的算法复杂度
2. `std::shared_ptr`如果存在环形引用,则无法正确释放内存,你的 Expression 类是否存在 circular reference? 2. `std::shared_ptr`如果存在环形引用,则无法正确释放内存,你的 Expression 类是否存在 circular reference?
3. 尽管本次实验已经写了很多代码,但是在算法上和工程上仍然可以对 GVN 进行改进,请简述你的 GVN 实现可以改进的地方 3. 尽管本次实验已经写了很多代码,但是在算法上和工程上仍然可以对 GVN 进行改进,请简述你的 GVN 实现可以改进的地方
...@@ -25,4 +278,3 @@ ...@@ -25,4 +278,3 @@
## 实验反馈(可选 不会评分) ## 实验反馈(可选 不会评分)
对本次实验的建议 对本次实验的建议
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
class BasicBlock; class BasicBlock;
class Function; class Function;
class Instruction : public User, public llvm::ilist_node<Instruction> class Instruction : public User, public llvm::ilist_node<Instruction> {
{
public: public:
enum OpID { enum OpID {
// Terminator Instructions // Terminator Instructions
...@@ -49,11 +48,11 @@ class Instruction : public User, public llvm::ilist_node<Instruction> ...@@ -49,11 +48,11 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
Instruction(const Instruction &) = delete; Instruction(const Instruction &) = delete;
virtual ~Instruction() = default; virtual ~Instruction() = default;
inline const BasicBlock *get_parent() const { return parent_; } inline const BasicBlock *get_parent() const { return parent_; }
inline BasicBlock *get_parent() { return parent_; } inline BasicBlock *get_parent() { return parent_; }
void set_parent(BasicBlock *parent) { this->parent_ = parent; } void set_parent(BasicBlock *parent) { this->parent_ = parent; }
// Return the function this instruction belongs to. // Return the function this instruction belongs to.
Function *get_function(); Function *get_function();
Module *get_module(); Module *get_module();
OpID get_instr_type() const { return op_id_; } OpID get_instr_type() const { return op_id_; }
// clang-format off // clang-format off
...@@ -86,8 +85,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction> ...@@ -86,8 +85,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
// clang-format on // clang-format on
std::string get_instr_op_name() { return get_instr_op_name(op_id_); } std::string get_instr_op_name() { return get_instr_op_name(op_id_); }
bool is_void() bool is_void() {
{
return ((op_id_ == ret) || (op_id_ == br) || (op_id_ == store) || return ((op_id_ == ret) || (op_id_ == br) || (op_id_ == store) ||
(op_id_ == call && this->get_type()->is_void_type())); (op_id_ == call && this->get_type()->is_void_type()));
} }
...@@ -108,6 +106,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction> ...@@ -108,6 +106,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool is_fsub() { return op_id_ == fsub; } bool is_fsub() { return op_id_ == fsub; }
bool is_fmul() { return op_id_ == fmul; } bool is_fmul() { return op_id_ == fmul; }
bool is_fdiv() { return op_id_ == fdiv; } bool is_fdiv() { return op_id_ == fdiv; }
bool is_fp2si() { return op_id_ == fptosi; } bool is_fp2si() { return op_id_ == fptosi; }
bool is_si2fp() { return op_id_ == sitofp; } bool is_si2fp() { return op_id_ == sitofp; }
...@@ -118,8 +117,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction> ...@@ -118,8 +117,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool is_gep() { return op_id_ == getelementptr; } bool is_gep() { return op_id_ == getelementptr; }
bool is_zext() { return op_id_ == zext; } bool is_zext() { return op_id_ == zext; }
bool isBinary() bool isBinary() {
{
return (is_add() || is_sub() || is_mul() || is_div() || is_fadd() || return (is_add() || is_sub() || is_mul() || is_div() || is_fadd() ||
is_fsub() || is_fmul() || is_fdiv()) && is_fsub() || is_fmul() || is_fdiv()) &&
(get_num_operand() == 2); (get_num_operand() == 2);
...@@ -128,39 +126,34 @@ class Instruction : public User, public llvm::ilist_node<Instruction> ...@@ -128,39 +126,34 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool isTerminator() { return is_br() || is_ret(); } bool isTerminator() { return is_br() || is_ret(); }
private: private:
OpID op_id_; OpID op_id_;
unsigned num_ops_; unsigned num_ops_;
BasicBlock *parent_; BasicBlock *parent_;
}; };
namespace detail namespace detail {
{ template <typename T>
template <typename T> struct tag {
struct tag using type = T;
{ };
using type = T; template <typename... Ts>
}; struct select_last {
template <typename... Ts> // Use a fold-expression to fold the comma operator over the parameter
struct select_last // pack.
{ using type = typename decltype((tag<Ts>{}, ...))::type;
// Use a fold-expression to fold the comma operator over the parameter };
// pack. template <typename... Ts>
using type = typename decltype((tag<Ts>{}, ...))::type; using select_last_t = typename select_last<Ts...>::type;
};
template <typename... Ts>
using select_last_t = typename select_last<Ts...>::type;
}; // namespace detail }; // namespace detail
template <class> template <class>
inline constexpr bool always_false_v = false; inline constexpr bool always_false_v = false;
template <typename Inst> template <typename Inst>
class BaseInst : public Instruction class BaseInst : public Instruction {
{
protected: protected:
template <typename... Args> template <typename... Args>
static Inst *create(Args &&...args) static Inst *create(Args &&...args) {
{
if constexpr (std::is_same_v< if constexpr (std::is_same_v<
std::decay_t<detail::select_last_t<Args...>>, std::decay_t<detail::select_last_t<Args...>>,
BasicBlock *>) { BasicBlock *>) {
...@@ -171,13 +164,10 @@ class BaseInst : public Instruction ...@@ -171,13 +164,10 @@ class BaseInst : public Instruction
} }
template <typename... Args> template <typename... Args>
BaseInst(Args &&...args) : Instruction(std::forward<Args>(args)...) BaseInst(Args &&...args) : Instruction(std::forward<Args>(args)...) {}
{
}
}; };
class BinaryInst : public BaseInst<BinaryInst> class BinaryInst : public BaseInst<BinaryInst> {
{
friend BaseInst<BinaryInst>; friend BaseInst<BinaryInst>;
private: private:
...@@ -185,35 +175,51 @@ class BinaryInst : public BaseInst<BinaryInst> ...@@ -185,35 +175,51 @@ class BinaryInst : public BaseInst<BinaryInst>
public: public:
// create add instruction, auto insert to bb // create add instruction, auto insert to bb
static BinaryInst *create_add(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_add(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
// create sub instruction, auto insert to bb // create sub instruction, auto insert to bb
static BinaryInst *create_sub(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_sub(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
// create mul instruction, auto insert to bb // create mul instruction, auto insert to bb
static BinaryInst *create_mul(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_mul(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
// create Div instruction, auto insert to bb // create Div instruction, auto insert to bb
static BinaryInst *create_sdiv(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_sdiv(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
// create fadd instruction, auto insert to bb // create fadd instruction, auto insert to bb
static BinaryInst *create_fadd(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_fadd(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
// create fsub instruction, auto insert to bb // create fsub instruction, auto insert to bb
static BinaryInst *create_fsub(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_fsub(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
// create fmul instruction, auto insert to bb // create fmul instruction, auto insert to bb
static BinaryInst *create_fmul(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_fmul(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
// create fDiv instruction, auto insert to bb // create fDiv instruction, auto insert to bb
static BinaryInst *create_fdiv(Value *v1, Value *v2, BasicBlock *bb, static BinaryInst *create_fdiv(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m); Module *m);
virtual std::string print() override; virtual std::string print() override;
...@@ -222,8 +228,7 @@ class BinaryInst : public BaseInst<BinaryInst> ...@@ -222,8 +228,7 @@ class BinaryInst : public BaseInst<BinaryInst>
void assertValid(); void assertValid();
}; };
class CmpInst : public BaseInst<CmpInst> class CmpInst : public BaseInst<CmpInst> {
{
friend BaseInst<CmpInst>; friend BaseInst<CmpInst>;
public: public:
...@@ -240,7 +245,10 @@ class CmpInst : public BaseInst<CmpInst> ...@@ -240,7 +245,10 @@ class CmpInst : public BaseInst<CmpInst>
CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb); CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb);
public: public:
static CmpInst *create_cmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, static CmpInst *create_cmp(CmpOp op,
Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m); Module *m);
CmpOp get_cmp_op() { return cmp_op_; } CmpOp get_cmp_op() { return cmp_op_; }
...@@ -253,8 +261,7 @@ class CmpInst : public BaseInst<CmpInst> ...@@ -253,8 +261,7 @@ class CmpInst : public BaseInst<CmpInst>
void assertValid(); void assertValid();
}; };
class FCmpInst : public BaseInst<FCmpInst> class FCmpInst : public BaseInst<FCmpInst> {
{
friend BaseInst<FCmpInst>; friend BaseInst<FCmpInst>;
public: public:
...@@ -271,8 +278,11 @@ class FCmpInst : public BaseInst<FCmpInst> ...@@ -271,8 +278,11 @@ class FCmpInst : public BaseInst<FCmpInst>
FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb); FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb);
public: public:
static FCmpInst *create_fcmp(CmpOp op, Value *lhs, Value *rhs, static FCmpInst *create_fcmp(CmpOp op,
BasicBlock *bb, Module *m); Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m);
CmpOp get_cmp_op() { return cmp_op_; } CmpOp get_cmp_op() { return cmp_op_; }
...@@ -284,33 +294,36 @@ class FCmpInst : public BaseInst<FCmpInst> ...@@ -284,33 +294,36 @@ class FCmpInst : public BaseInst<FCmpInst>
void assert_valid(); void assert_valid();
}; };
class CallInst : public BaseInst<CallInst> class CallInst : public BaseInst<CallInst> {
{
friend BaseInst<CallInst>; friend BaseInst<CallInst>;
protected: protected:
CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb); CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb);
public: public:
static CallInst *create(Function *func, std::vector<Value *> args, static CallInst *create(Function *func,
std::vector<Value *> args,
BasicBlock *bb); BasicBlock *bb);
FunctionType *get_function_type() const; FunctionType *get_function_type() const;
virtual std::string print() override; virtual std::string print() override;
}; };
class BranchInst : public BaseInst<BranchInst> class BranchInst : public BaseInst<BranchInst> {
{
friend BaseInst<BranchInst>; friend BaseInst<BranchInst>;
private: private:
BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BranchInst(Value *cond,
BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb); BasicBlock *bb);
BranchInst(BasicBlock *if_true, BasicBlock *bb); BranchInst(BasicBlock *if_true, BasicBlock *bb);
public: public:
static BranchInst *create_cond_br(Value *cond, BasicBlock *if_true, static BranchInst *create_cond_br(Value *cond,
BasicBlock *if_false, BasicBlock *bb); BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb);
static BranchInst *create_br(BasicBlock *if_true, BasicBlock *bb); static BranchInst *create_br(BasicBlock *if_true, BasicBlock *bb);
bool is_cond_br() const; bool is_cond_br() const;
...@@ -318,8 +331,7 @@ class BranchInst : public BaseInst<BranchInst> ...@@ -318,8 +331,7 @@ class BranchInst : public BaseInst<BranchInst>
virtual std::string print() override; virtual std::string print() override;
}; };
class ReturnInst : public BaseInst<ReturnInst> class ReturnInst : public BaseInst<ReturnInst> {
{
friend BaseInst<ReturnInst>; friend BaseInst<ReturnInst>;
private: private:
...@@ -329,13 +341,12 @@ class ReturnInst : public BaseInst<ReturnInst> ...@@ -329,13 +341,12 @@ class ReturnInst : public BaseInst<ReturnInst>
public: public:
static ReturnInst *create_ret(Value *val, BasicBlock *bb); static ReturnInst *create_ret(Value *val, BasicBlock *bb);
static ReturnInst *create_void_ret(BasicBlock *bb); static ReturnInst *create_void_ret(BasicBlock *bb);
bool is_void_ret() const; bool is_void_ret() const;
virtual std::string print() override; virtual std::string print() override;
}; };
class GetElementPtrInst : public BaseInst<GetElementPtrInst> class GetElementPtrInst : public BaseInst<GetElementPtrInst> {
{
friend BaseInst<GetElementPtrInst>; friend BaseInst<GetElementPtrInst>;
private: private:
...@@ -343,9 +354,10 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst> ...@@ -343,9 +354,10 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
public: public:
static Type *get_element_type(Value *ptr, std::vector<Value *> idxs); static Type *get_element_type(Value *ptr, std::vector<Value *> idxs);
static GetElementPtrInst *create_gep(Value *ptr, std::vector<Value *> idxs, static GetElementPtrInst *create_gep(Value *ptr,
std::vector<Value *> idxs,
BasicBlock *bb); BasicBlock *bb);
Type *get_element_type() const; Type *get_element_type() const;
virtual std::string print() override; virtual std::string print() override;
...@@ -353,8 +365,7 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst> ...@@ -353,8 +365,7 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
Type *element_ty_; Type *element_ty_;
}; };
class StoreInst : public BaseInst<StoreInst> class StoreInst : public BaseInst<StoreInst> {
{
friend BaseInst<StoreInst>; friend BaseInst<StoreInst>;
private: private:
...@@ -369,8 +380,7 @@ class StoreInst : public BaseInst<StoreInst> ...@@ -369,8 +380,7 @@ class StoreInst : public BaseInst<StoreInst>
virtual std::string print() override; virtual std::string print() override;
}; };
class LoadInst : public BaseInst<LoadInst> class LoadInst : public BaseInst<LoadInst> {
{
friend BaseInst<LoadInst>; friend BaseInst<LoadInst>;
private: private:
...@@ -378,15 +388,14 @@ class LoadInst : public BaseInst<LoadInst> ...@@ -378,15 +388,14 @@ class LoadInst : public BaseInst<LoadInst>
public: public:
static LoadInst *create_load(Type *ty, Value *ptr, BasicBlock *bb); static LoadInst *create_load(Type *ty, Value *ptr, BasicBlock *bb);
Value *get_lval() { return this->get_operand(0); } Value *get_lval() { return this->get_operand(0); }
Type *get_load_type() const; Type *get_load_type() const;
virtual std::string print() override; virtual std::string print() override;
}; };
class AllocaInst : public BaseInst<AllocaInst> class AllocaInst : public BaseInst<AllocaInst> {
{
friend BaseInst<AllocaInst>; friend BaseInst<AllocaInst>;
private: private:
...@@ -403,8 +412,7 @@ class AllocaInst : public BaseInst<AllocaInst> ...@@ -403,8 +412,7 @@ class AllocaInst : public BaseInst<AllocaInst>
Type *alloca_ty_; Type *alloca_ty_;
}; };
class ZextInst : public BaseInst<ZextInst> class ZextInst : public BaseInst<ZextInst> {
{
friend BaseInst<ZextInst>; friend BaseInst<ZextInst>;
private: private:
...@@ -421,8 +429,7 @@ class ZextInst : public BaseInst<ZextInst> ...@@ -421,8 +429,7 @@ class ZextInst : public BaseInst<ZextInst>
Type *dest_ty_; Type *dest_ty_;
}; };
class FpToSiInst : public BaseInst<FpToSiInst> class FpToSiInst : public BaseInst<FpToSiInst> {
{
friend BaseInst<FpToSiInst>; friend BaseInst<FpToSiInst>;
private: private:
...@@ -439,8 +446,7 @@ class FpToSiInst : public BaseInst<FpToSiInst> ...@@ -439,8 +446,7 @@ class FpToSiInst : public BaseInst<FpToSiInst>
Type *dest_ty_; Type *dest_ty_;
}; };
class SiToFpInst : public BaseInst<SiToFpInst> class SiToFpInst : public BaseInst<SiToFpInst> {
{
friend BaseInst<SiToFpInst>; friend BaseInst<SiToFpInst>;
private: private:
...@@ -457,25 +463,24 @@ class SiToFpInst : public BaseInst<SiToFpInst> ...@@ -457,25 +463,24 @@ class SiToFpInst : public BaseInst<SiToFpInst>
Type *dest_ty_; Type *dest_ty_;
}; };
class PhiInst : public BaseInst<PhiInst> class PhiInst : public BaseInst<PhiInst> {
{
friend BaseInst<PhiInst>; friend BaseInst<PhiInst>;
private: private:
PhiInst(OpID op, std::vector<Value *> vals, PhiInst(OpID op,
std::vector<BasicBlock *> val_bbs, Type *ty, BasicBlock *bb); std::vector<Value *> vals,
std::vector<BasicBlock *> val_bbs,
Type *ty,
BasicBlock *bb);
PhiInst(Type *ty, OpID op, unsigned num_ops, BasicBlock *bb) PhiInst(Type *ty, OpID op, unsigned num_ops, BasicBlock *bb)
: BaseInst<PhiInst>(ty, op, num_ops, bb) : BaseInst<PhiInst>(ty, op, num_ops, bb) {}
{
}
Value *l_val_; Value *l_val_;
public: public:
static PhiInst *create_phi(Type *ty, BasicBlock *bb); static PhiInst *create_phi(Type *ty, BasicBlock *bb);
Value *get_lval() { return l_val_; } Value *get_lval() { return l_val_; }
void set_lval(Value *l_val) { l_val_ = l_val; } void set_lval(Value *l_val) { l_val_ = l_val; }
void add_phi_pair_operand(Value *val, Value *pre_bb) void add_phi_pair_operand(Value *val, Value *pre_bb) {
{
this->add_operand(val); this->add_operand(val);
this->add_operand(pre_bb); this->add_operand(pre_bb);
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
class GVN;
namespace GVNExpression { namespace GVNExpression {
// fold the constant value // fold the constant value
...@@ -35,12 +36,13 @@ class ConstFolder { ...@@ -35,12 +36,13 @@ class ConstFolder {
/** /**
* for constructor of class derived from `Expression`, we make it public * for constructor of class derived from `Expression`, we make it public
* because `std::make_shared` needs the constructor to be publicly available, * because `std::make_shared` needs the constructor to be publicly available,
* but you should call the static factory method `create` instead the constructor itself to get the desired data * but you should call the static factory method `create` instead the
* constructor itself to get the desired data
*/ */
class Expression { class Expression {
public: public:
// TODO: you need to extend expression types according to testcases // TODO: you need to extend expression types according to testcases
enum gvn_expr_t { e_constant, e_bin, e_phi }; enum gvn_expr_t { e_constant, e_bin, e_phi, e_cast, e_gep, e_unique };
Expression(gvn_expr_t t) : expr_type(t) {} Expression(gvn_expr_t t) : expr_type(t) {}
virtual ~Expression() = default; virtual ~Expression() = default;
virtual std::string print() = 0; virtual std::string print() = 0;
...@@ -50,15 +52,21 @@ class Expression { ...@@ -50,15 +52,21 @@ class Expression {
gvn_expr_t expr_type; gvn_expr_t expr_type;
}; };
bool operator==(const std::shared_ptr<Expression> &lhs, const std::shared_ptr<Expression> &rhs); bool operator==(const std::shared_ptr<Expression> &lhs,
bool operator==(const GVNExpression::Expression &lhs, const GVNExpression::Expression &rhs); const std::shared_ptr<Expression> &rhs);
bool operator==(const GVNExpression::Expression &lhs,
const GVNExpression::Expression &rhs);
class ConstantExpression : public Expression { class ConstantExpression : public Expression {
public: public:
static std::shared_ptr<ConstantExpression> create(Constant *c) { return std::make_shared<ConstantExpression>(c); } static std::shared_ptr<ConstantExpression> create(Constant *c) {
return std::make_shared<ConstantExpression>(c);
}
virtual std::string print() { return c_->print(); } virtual std::string print() { return c_->print(); }
// we leverage the fact that constants in lightIR have unique addresses // we leverage the fact that constants in lightIR have unique addresses
bool equiv(const ConstantExpression *other) const { return c_ == other->c_; } bool equiv(const ConstantExpression *other) const {
return c_ == other->c_;
}
ConstantExpression(Constant *c) : Expression(e_constant), c_(c) {} ConstantExpression(Constant *c) : Expression(e_constant), c_(c) {}
private: private:
...@@ -67,24 +75,31 @@ class ConstantExpression : public Expression { ...@@ -67,24 +75,31 @@ class ConstantExpression : public Expression {
// arithmetic expression // arithmetic expression
class BinaryExpression : public Expression { class BinaryExpression : public Expression {
friend class ::GVN;
public: public:
static std::shared_ptr<BinaryExpression> create(Instruction::OpID op, static std::shared_ptr<BinaryExpression> create(
std::shared_ptr<Expression> lhs, Instruction::OpID op,
std::shared_ptr<Expression> rhs) { std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs) {
return std::make_shared<BinaryExpression>(op, lhs, rhs); return std::make_shared<BinaryExpression>(op, lhs, rhs);
} }
virtual std::string print() { virtual std::string print() {
return "(" + Instruction::get_instr_op_name(op_) + " " + lhs_->print() + " " + rhs_->print() + ")"; return "(" + Instruction::get_instr_op_name(op_) + " " + lhs_->print() +
" " + rhs_->print() + ")";
} }
bool equiv(const BinaryExpression *other) const { bool equiv(const BinaryExpression *other) const {
if (op_ == other->op_ and *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_) if (op_ == other->op_ and *lhs_ == *other->lhs_ and
*rhs_ == *other->rhs_)
return true; return true;
else else
return false; return false;
} }
BinaryExpression(Instruction::OpID op, std::shared_ptr<Expression> lhs, std::shared_ptr<Expression> rhs) BinaryExpression(Instruction::OpID op,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs)
: Expression(e_bin), op_(op), lhs_(lhs), rhs_(rhs) {} : Expression(e_bin), op_(op), lhs_(lhs), rhs_(rhs) {}
private: private:
...@@ -93,33 +108,122 @@ class BinaryExpression : public Expression { ...@@ -93,33 +108,122 @@ class BinaryExpression : public Expression {
}; };
class PhiExpression : public Expression { class PhiExpression : public Expression {
friend class ::GVN;
public: public:
static std::shared_ptr<PhiExpression> create(std::shared_ptr<Expression> lhs, std::shared_ptr<Expression> rhs) { static std::shared_ptr<PhiExpression> create(
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs) {
return std::make_shared<PhiExpression>(lhs, rhs); return std::make_shared<PhiExpression>(lhs, rhs);
} }
virtual std::string print() { return "(phi " + lhs_->print() + " " + rhs_->print() + ")"; } virtual std::string print() {
return "(phi " + lhs_->print() + " " + rhs_->print() + ")";
}
bool equiv(const PhiExpression *other) const { bool equiv(const PhiExpression *other) const {
if (*lhs_ == *other->lhs_ and *rhs_ == *other->rhs_) if (*lhs_ == *other->lhs_ and *rhs_ == *other->rhs_)
return true; return true;
else else
return false; return false;
} }
PhiExpression(std::shared_ptr<Expression> lhs, std::shared_ptr<Expression> rhs) PhiExpression(std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs)
: Expression(e_phi), lhs_(lhs), rhs_(rhs) {} : Expression(e_phi), lhs_(lhs), rhs_(rhs) {}
private: private:
std::shared_ptr<Expression> lhs_, rhs_; std::shared_ptr<Expression> lhs_, rhs_;
}; };
// type cast expression
class CastExpression : public Expression {
public:
static std::shared_ptr<CastExpression> create(
Instruction::OpID op,
std::shared_ptr<Expression> src,
Type *dest_type) {
return std::make_shared<CastExpression>(op, src, dest_type);
}
virtual std::string print() {
return "(" + dest_ty_->print() + " " +
Instruction::get_instr_op_name(op_) + " " + src_->print() + ")";
}
bool equiv(const CastExpression *other) const {
return op_ == other->op_ and src_ == other->src_ and
dest_ty_ == other->dest_ty_;
}
CastExpression(Instruction::OpID op,
std::shared_ptr<Expression> src,
Type *dest_type)
: Expression(e_cast), op_(op), src_(src), dest_ty_(dest_type) {}
private:
Instruction::OpID op_;
std::shared_ptr<Expression> src_;
Type *dest_ty_;
};
// type cast expression
class GEPExpression : public Expression {
public:
static std::shared_ptr<GEPExpression> create(
std::shared_ptr<Expression> ptr,
std::vector<std::shared_ptr<Expression>> &idxs) {
return std::make_shared<GEPExpression>(ptr, idxs);
}
virtual std::string print() {
std::string ret = "(GEP " + ptr_->print();
for (auto idx : idxs_)
ret += " " + idx->print();
return ret + ")";
}
bool equiv(const GEPExpression *other) const {
if (idxs_.size() != other->idxs_.size())
return false;
for (int i = 0; i != idxs_.size(); ++i)
if (not(idxs_[i] == other->idxs_[i]))
return false;
return ptr_ == other->ptr_;
}
GEPExpression(std::shared_ptr<Expression> ptr,
std::vector<std::shared_ptr<Expression>> &idxs)
: Expression(e_gep), ptr_(ptr), idxs_(idxs) {}
private:
std::shared_ptr<Expression> ptr_;
std::vector<std::shared_ptr<Expression>> idxs_;
};
// unique expression: not equal to any one else
class UniqueExpression : public Expression {
public:
static std::shared_ptr<UniqueExpression> create(Instruction *instr) {
return std::make_shared<UniqueExpression>(instr);
}
virtual std::string print() { return "(UNIQUE " + instr_->print() + ")"; }
bool equiv(const UniqueExpression *other) const { return false; }
UniqueExpression(Instruction *instr)
: Expression(e_unique), instr_(instr) {}
private:
std::shared_ptr<Instruction> instr_;
};
} // namespace GVNExpression } // namespace GVNExpression
/** /**
* Congruence class in each partitions * Congruence class in each partitions
* note: for constant propagation, you might need to add other fields * note: for constant propagation, you might need to add other fields
* and for load/store redundancy detection, you most certainly need to modify the class * and for load/store redundancy detection, you most certainly need to modify
* the class
*/ */
struct CongruenceClass { struct CongruenceClass {
size_t index_; size_t index_;
// representative of the congruence class, used to replace all the members (except itself) when analysis is done // representative of the congruence class, used to replace all the members
// (except itself) when analysis is done
Value *leader_; Value *leader_;
// value expression in congruence class // value expression in congruence class
std::shared_ptr<GVNExpression::Expression> value_expr_; std::shared_ptr<GVNExpression::Expression> value_expr_;
...@@ -128,17 +232,22 @@ struct CongruenceClass { ...@@ -128,17 +232,22 @@ struct CongruenceClass {
// equivalent variables in one congruence class // equivalent variables in one congruence class
std::set<Value *> members_; std::set<Value *> members_;
CongruenceClass(size_t index) : index_(index), leader_{}, value_expr_{}, value_phi_{}, members_{} {} CongruenceClass(size_t index)
: index_(index), leader_{}, value_expr_{}, value_phi_{}, members_{} {}
bool operator<(const CongruenceClass &other) const { return this->index_ < other.index_; } bool operator<(const CongruenceClass &other) const {
return this->index_ < other.index_;
}
bool operator==(const CongruenceClass &other) const; bool operator==(const CongruenceClass &other) const;
}; };
namespace std { namespace std {
template <> template <>
// overload std::less for std::shared_ptr<CongruenceClass>, i.e. how to sort the congruence classes // overload std::less for std::shared_ptr<CongruenceClass>, i.e. how to sort the
// congruence classes
struct less<std::shared_ptr<CongruenceClass>> { struct less<std::shared_ptr<CongruenceClass>> {
bool operator()(const std::shared_ptr<CongruenceClass> &a, const std::shared_ptr<CongruenceClass> &b) const { bool operator()(const std::shared_ptr<CongruenceClass> &a,
const std::shared_ptr<CongruenceClass> &b) const {
// nullptrs should never appear in partitions, so we just dereference it // nullptrs should never appear in partitions, so we just dereference it
return *a < *b; return *a < *b;
} }
...@@ -154,17 +263,24 @@ class GVN : public Pass { ...@@ -154,17 +263,24 @@ class GVN : public Pass {
// init for pass metadata; // init for pass metadata;
void initPerFunction(); void initPerFunction();
// fill the following functions according to Pseudocode, **you might need to add more arguments** // fill the following functions according to Pseudocode, **you might need to
// add more arguments**
void detectEquivalences(); void detectEquivalences();
partitions join(const partitions &P1, const partitions &P2); partitions join(const partitions &P1, const partitions &P2);
std::shared_ptr<CongruenceClass> intersect(std::shared_ptr<CongruenceClass>, std::shared_ptr<CongruenceClass>); std::shared_ptr<CongruenceClass> intersect(
std::shared_ptr<CongruenceClass>,
std::shared_ptr<CongruenceClass>);
partitions transferFunction(Instruction *x, Value *e, partitions pin); partitions transferFunction(Instruction *x, Value *e, partitions pin);
partitions transferFunction(BasicBlock *bb); partitions transferFunction(BasicBlock *bb);
std::shared_ptr<GVNExpression::PhiExpression> valuePhiFunc(std::shared_ptr<GVNExpression::Expression>, std::shared_ptr<GVNExpression::PhiExpression> valuePhiFunc(
const partitions &); std::shared_ptr<GVNExpression::Expression>,
std::shared_ptr<GVNExpression::Expression> valueExpr(Instruction *instr); BasicBlock *bb);
std::shared_ptr<GVNExpression::Expression> getVN(const partitions &pout, std::shared_ptr<GVNExpression::Expression> valueExpr(
std::shared_ptr<GVNExpression::Expression> ve); Instruction *instr,
partitions *part = nullptr);
std::shared_ptr<GVNExpression::Expression> getVN(
const partitions &pout,
std::shared_ptr<GVNExpression::Expression> ve);
// replace cc members with leader // replace cc members with leader
void replace_cc_members(); void replace_cc_members();
...@@ -180,6 +296,7 @@ class GVN : public Pass { ...@@ -180,6 +296,7 @@ class GVN : public Pass {
// self add // self add
// //
std::uint64_t new_number() { return next_value_number_++; } std::uint64_t new_number() { return next_value_number_++; }
static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb);
private: private:
bool dump_json_; bool dump_json_;
...@@ -191,7 +308,9 @@ class GVN : public Pass { ...@@ -191,7 +308,9 @@ class GVN : public Pass {
std::unique_ptr<DeadCode> dce_; std::unique_ptr<DeadCode> dce_;
// self add // self add
std::map<Instruction*, bool> _TOP; std::map<BasicBlock *, bool> _TOP;
partitions join_helper(BasicBlock *pre1, BasicBlock *pre2);
BasicBlock* curr_bb;
}; };
bool operator==(const GVN::partitions &p1, const GVN::partitions &p2); bool operator==(const GVN::partitions &p1, const GVN::partitions &p2);
cminusf_builder_stu.cpp
...@@ -5,22 +5,20 @@ ...@@ -5,22 +5,20 @@
#include "cminusf_builder.hpp" #include "cminusf_builder.hpp"
#include "logging.hpp"
#define CONST_FP(num) ConstantFP::get((float)num, module.get()) #define CONST_FP(num) ConstantFP::get((float)num, module.get())
#define CONST_INT(num) ConstantInt::get(num, module.get()) #define CONST_INT(num) ConstantInt::get(num, module.get())
// TODO: Global Variable Declarations
// You can define global variables here // You can define global variables here
// to store state. You can expand these // to store state
// definitions if you need to.
// the latest return value // store temporary value
Value *cur_value = nullptr; Value *tmp_val = nullptr;
// if var is assignment's left part, LV is true // whether require lvalue
bool LV = false; bool require_lvalue = false;
// function that is being built // function that is being built
Function *cur_fun = nullptr; Function *cur_fun = nullptr;
// detect scope pre-enter (for elegance only)
bool pre_enter_scope = false;
// types // types
Type *VOID_T; Type *VOID_T;
...@@ -30,9 +28,22 @@ Type *INT32PTR_T; ...@@ -30,9 +28,22 @@ Type *INT32PTR_T;
Type *FLOAT_T; Type *FLOAT_T;
Type *FLOATPTR_T; Type *FLOATPTR_T;
// initializer bool
ConstantZero *I32Initializer; promote(IRBuilder *builder, Value **l_val_p, Value **r_val_p) {
ConstantZero *FloatInitializer; bool is_int;
auto &l_val = *l_val_p;
auto &r_val = *r_val_p;
if (l_val->get_type() == r_val->get_type()) {
is_int = l_val->get_type()->is_integer_type();
} else {
is_int = false;
if (l_val->get_type()->is_integer_type())
l_val = builder->create_sitofp(l_val, FLOAT_T);
else
r_val = builder->create_sitofp(r_val, FLOAT_T);
}
return is_int;
}
/* /*
* use CMinusfBuilder::Scope to construct scopes * use CMinusfBuilder::Scope to construct scopes
...@@ -42,182 +53,61 @@ ConstantZero *FloatInitializer; ...@@ -42,182 +53,61 @@ ConstantZero *FloatInitializer;
* scope.find: find and return the value bound to the name * scope.find: find and return the value bound to the name
*/ */
void error_exit(std::string s) { void
LOG_ERROR << s; CminusfBuilder::visit(ASTProgram &node) {
std::abort();
}
// This function makes sure that
// 1. 2 values have same type
// 2. type is either i32 or float
void CminusfBuilder::biop_type_check(Value *&lvalue, Value *&rvalue, std::string util) {
if (Type::is_eq_type(lvalue->get_type(), rvalue->get_type())) {
if (lvalue->get_type()->is_integer_type() or lvalue->get_type()->is_float_type()) {
// check for i1
if (Type::is_eq_type(lvalue->get_type(), INT1_T)) {
lvalue = builder->create_zext(lvalue, INT32_T);
rvalue = builder->create_zext(rvalue, INT32_T);
}
} else
error_exit("not supported type cast for " + util);
return;
}
// only support cast between int and float: i32, i1, float
//
// case that integer and float is mixed, directly cast integer to float
if (lvalue->get_type()->is_integer_type() and rvalue->get_type()->is_float_type())
lvalue = builder->create_sitofp(lvalue, FLOAT_T);
else if (lvalue->get_type()->is_float_type() and rvalue->get_type()->is_integer_type())
rvalue = builder->create_sitofp(rvalue, FLOAT_T);
else if (lvalue->get_type()->is_integer_type() and rvalue->get_type()->is_integer_type()) {
// case that I32 and I1 mixed
if (Type::is_eq_type(lvalue->get_type(), INT1_T))
lvalue = builder->create_zext(lvalue, INT32_T);
else
rvalue = builder->create_zext(rvalue, INT32_T);
} else { // we only support computing among i1, i32 and float
error_exit("not supported type cast for " + util);
}
}
// this function makes sure value is a bool type
void CminusfBuilder::cast_to_i1(Value *&value) {
assert(value->get_type()->is_integer_type() or value->get_type()->is_float_type());
if (value->get_type()->is_float_type())
// value = builder->create_fptosi(value, INT1_T);
value = builder->create_fcmp_ne(value, CONST_FP(0));
else if (Type::is_eq_type(value->get_type(), INT32_T))
value = builder->create_icmp_ne(value, CONST_INT(0));
}
void CminusfBuilder::visit(ASTProgram &node) {
VOID_T = Type::get_void_type(module.get()); VOID_T = Type::get_void_type(module.get());
INT1_T = Type::get_int1_type(module.get()); INT1_T = Type::get_int1_type(module.get());
INT32_T = Type::get_int32_type(module.get()); INT32_T = Type::get_int32_type(module.get());
INT32PTR_T = Type::get_int32_ptr_type(module.get()); INT32PTR_T = Type::get_int32_ptr_type(module.get());
FLOAT_T = Type::get_float_type(module.get()); FLOAT_T = Type::get_float_type(module.get());
FLOATPTR_T = Type::get_float_ptr_type(module.get()); FLOATPTR_T = Type::get_float_ptr_type(module.get());
I32Initializer = ConstantZero::get(INT32_T, builder->get_module());
FloatInitializer = ConstantZero::get(FLOAT_T, builder->get_module());
for (auto decl : node.declarations) { for (auto decl : node.declarations) {
decl->accept(*this); decl->accept(*this);
} }
} }
// Done void
void CminusfBuilder::visit(ASTNum &node) { CminusfBuilder::visit(ASTNum &node) {
//!TODO: This function is empty now. if (node.type == TYPE_INT)
// Add some code here. tmp_val = CONST_INT(node.i_val);
// else
switch (node.type) { tmp_val = CONST_FP(node.f_val);
case TYPE_INT:
cur_value = CONST_INT(node.i_val);
return;
case TYPE_FLOAT:
cur_value = CONST_FP(node.f_val);
return;
default:
error_exit("ASTNum is not int or float");
}
} }
// Done void
void CminusfBuilder::visit(ASTVarDeclaration &node) { CminusfBuilder::visit(ASTVarDeclaration &node) {
//!TODO: This function is empty now. Type *var_type;
// Add some code here. if (node.type == TYPE_INT)
bool global = (builder->get_insert_block() == nullptr); var_type = Type::get_int32_type(module.get());
if (node.num) { else
// declares an array var_type = Type::get_float_type(module.get());
// if (node.num == nullptr) {
// get array size if (scope.in_global()) {
node.num->accept(*this); auto initializer = ConstantZero::get(var_type, module.get());
// auto var = GlobalVariable::create(
// !no type cast here! node.id, module.get(), var_type, false, initializer);
if (not(node.num->type == TYPE_INT)) scope.push(node.id, var);
error_exit("size of array has non-integer type"); } else {
auto var = builder->create_alloca(var_type);
int size = node.num->i_val; scope.push(node.id, var);
if (size <= 0)
error_exit("array size[" + std::to_string(size) + "] <= 0");
switch (node.type) {
case TYPE_INT: {
auto I32Array_T = Type::get_array_type(INT32_T, size);
if (global)
cur_value =
GlobalVariable::create(node.id, builder->get_module(), I32Array_T, false, I32Initializer);
else
cur_value = builder->create_alloca(I32Array_T);
break;
}
case TYPE_FLOAT: {
auto FloatArray_T = Type::get_array_type(FLOAT_T, size);
if (global)
cur_value =
GlobalVariable::create(node.id, builder->get_module(), FloatArray_T, false, FloatInitializer);
else
cur_value = builder->create_alloca(FloatArray_T);
break;
}
default:
error_exit("Variable type(not array) is not int or float");
} }
assert(cur_value->get_type()->is_pointer_type() && "IF SEE THIS: API ERROR");
} else { } else {
// flat int or float type auto *array_type = ArrayType::get(var_type, node.num->i_val);
switch (node.type) { if (scope.in_global()) {
case TYPE_INT: auto initializer = ConstantZero::get(array_type, module.get());
if (global) auto var = GlobalVariable::create(
cur_value = GlobalVariable::create(node.id, builder->get_module(), INT32_T, false, I32Initializer); node.id, module.get(), array_type, false, initializer);
else scope.push(node.id, var);
cur_value = builder->create_alloca(INT32_T); } else {
break; auto var = builder->create_alloca(array_type);
scope.push(node.id, var);
case TYPE_FLOAT:
if (global)
cur_value =
GlobalVariable::create(node.id, builder->get_module(), FLOAT_T, false, FloatInitializer);
else {
/* Beautiful is better than ugly.
* Explicit is better than implicit.
* Simple is better than complex.
* Complex is better than complicated.
* Flat is better than nested.
* Sparse is better than dense.
* Readability counts.
* Special cases aren't special enough to break the rules.
* Although practicality beats purity.
* Errors should never pass silently.
* Unless explicitly silenced.
* In the face of ambiguity, refuse the temptation to guess.
* There should be one-- and preferably only one --obvious way to do it.
* Although that way may not be obvious at first unless you're Dutch.
* Now is better than never.
* Although never is often better than *right* now.
* If the implementation is hard to explain, it's a bad idea.
* If the implementation is easy to explain, it may be a good idea.
* Namespaces are one honking great idea -- let's do more of those! */
// cur_value = builder->create_alloca(INT32_T);
cur_value = builder->create_alloca(FLOAT_T);
}
break;
default:
error_exit("Variable type(not array) is not int or float");
} }
} }
if (not scope.push(node.id, cur_value))
error_exit("variable redefined: " + node.id);
LOG_DEBUG << "add entry: " << node.id << " " << cur_value;
} }
// Done void
void CminusfBuilder::visit(ASTFunDeclaration &node) { CminusfBuilder::visit(ASTFunDeclaration &node) {
FunctionType *fun_type; FunctionType *fun_type;
Type *ret_type; Type *ret_type;
std::vector<Type *> param_types; std::vector<Type *> param_types;
...@@ -229,46 +119,54 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) { ...@@ -229,46 +119,54 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
ret_type = VOID_T; ret_type = VOID_T;
for (auto &param : node.params) { for (auto &param : node.params) {
//!TODO: Please accomplish param_types. if (param->type == TYPE_INT) {
// if (param->isarray) {
// First make function BB, which needs this param type, param_types.push_back(INT32PTR_T);
// then set_insert_point, we can call accept to gen code, } else {
switch (param->type) { param_types.push_back(INT32_T);
case TYPE_INT: }
param_types.push_back(param->isarray ? INT32PTR_T : INT32_T); } else {
break; if (param->isarray) {
case TYPE_FLOAT: param_types.push_back(FLOATPTR_T);
param_types.push_back(param->isarray ? FLOATPTR_T : FLOAT_T); } else {
break; param_types.push_back(FLOAT_T);
case TYPE_VOID: }
if (not param_types.empty())
error_exit("function parameters weird");
break;
} }
} }
fun_type = FunctionType::get(ret_type, param_types); fun_type = FunctionType::get(ret_type, param_types);
auto fun = Function::create(fun_type, node.id, module.get()); auto fun = Function::create(fun_type, node.id, module.get());
cur_fun = fun;
scope.push(node.id, fun); scope.push(node.id, fun);
cur_fun = fun;
auto funBB = BasicBlock::create(module.get(), "entry", fun); auto funBB = BasicBlock::create(module.get(), "entry", fun);
builder->set_insert_point(funBB); builder->set_insert_point(funBB);
scope.enter(); scope.enter();
pre_enter_scope = true;
std::vector<Value *> args; std::vector<Value *> args;
for (auto arg = fun->arg_begin(); arg != fun->arg_end(); arg++) { for (auto arg = fun->arg_begin(); arg != fun->arg_end(); arg++) {
args.push_back(*arg); args.push_back(*arg);
} }
for (int i = 0; i < node.params.size(); ++i) { for (int i = 0; i < node.params.size(); ++i) {
//!TODO: You need to deal with params if (node.params[i]->isarray) {
// and store them in the scope. Value *array_alloc;
cur_value = args[i]; if (node.params[i]->type == TYPE_INT)
node.params[i]->accept(*this); array_alloc = builder->create_alloca(INT32PTR_T);
else
array_alloc = builder->create_alloca(FLOATPTR_T);
builder->create_store(args[i], array_alloc);
scope.push(node.params[i]->id, array_alloc);
} else {
Value *alloc;
if (node.params[i]->type == TYPE_INT)
alloc = builder->create_alloca(INT32_T);
else
alloc = builder->create_alloca(FLOAT_T);
builder->create_store(args[i], alloc);
scope.push(node.params[i]->id, alloc);
}
} }
node.compound_stmt->accept(*this); node.compound_stmt->accept(*this);
// default return value // can't deal with return in both blocks
if (builder->get_insert_block()->get_terminator() == nullptr) { if (builder->get_insert_block()->get_terminator() == nullptr) {
if (cur_fun->get_return_type()->is_void_type()) if (cur_fun->get_return_type()->is_void_type())
builder->create_void_ret(); builder->create_void_ret();
...@@ -280,40 +178,17 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) { ...@@ -280,40 +178,17 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
scope.exit(); scope.exit();
} }
// Done void
void CminusfBuilder::visit(ASTParam &node) { CminusfBuilder::visit(ASTParam &node) {}
//!TODO: This function is empty now.
// If the parameter is int|float, copy and store them
auto param_value = cur_value;
switch (node.type) {
case TYPE_INT: {
if (node.isarray)
cur_value = builder->create_alloca(INT32PTR_T);
else
cur_value = builder->create_alloca(INT32_T);
break;
}
case TYPE_FLOAT: {
if (node.isarray)
cur_value = builder->create_alloca(FLOATPTR_T);
else
cur_value = builder->create_alloca(FLOAT_T);
break;
}
case TYPE_VOID:
return;
}
scope.push(node.id, cur_value);
builder->create_store(param_value, cur_value);
}
// Done? void
void CminusfBuilder::visit(ASTCompoundStmt &node) { CminusfBuilder::visit(ASTCompoundStmt &node) {
//!TODO: This function is not complete. bool need_exit_scope = !pre_enter_scope;
// You may need to add some code here if (pre_enter_scope) {
// to deal with complex statements. pre_enter_scope = false;
} else {
scope.enter(); scope.enter();
}
for (auto &decl : node.local_declarations) { for (auto &decl : node.local_declarations) {
decl->accept(*this); decl->accept(*this);
...@@ -325,413 +200,303 @@ void CminusfBuilder::visit(ASTCompoundStmt &node) { ...@@ -325,413 +200,303 @@ void CminusfBuilder::visit(ASTCompoundStmt &node) {
break; break;
} }
scope.exit(); if (need_exit_scope) {
scope.exit();
}
} }
// Done void
void CminusfBuilder::visit(ASTExpressionStmt &node) { CminusfBuilder::visit(ASTExpressionStmt &node) {
//!TODO: This function is empty now. if (node.expression != nullptr)
// Add some code here.
if (node.expression)
node.expression->accept(*this); node.expression->accept(*this);
} }
// Done void
void CminusfBuilder::visit(ASTSelectionStmt &node) { CminusfBuilder::visit(ASTSelectionStmt &node) {
//!TODO: This function is empty now.
// Add some code here.
scope.enter();
node.expression->accept(*this); node.expression->accept(*this);
auto cond = cur_value; auto ret_val = tmp_val;
cast_to_i1(cond); auto trueBB = BasicBlock::create(module.get(), "", cur_fun);
BasicBlock *falseBB{};
auto ifBB = BasicBlock::create(builder->get_module(), "", cur_fun); auto contBB = BasicBlock::create(module.get(), "", cur_fun);
auto endBB = BasicBlock::create(builder->get_module(), "", cur_fun); Value *cond_val;
if (node.else_statement) { if (ret_val->get_type()->is_integer_type())
auto elseBB = BasicBlock::create(builder->get_module(), "", cur_fun); cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
builder->create_cond_br(cond, ifBB, elseBB); else
cond_val = builder->create_fcmp_ne(ret_val, CONST_FP(0.));
builder->set_insert_point(ifBB);
node.if_statement->accept(*this);
builder->create_br(endBB);
builder->set_insert_point(elseBB);
node.else_statement->accept(*this);
builder->create_br(endBB);
builder->set_insert_point(endBB); if (node.else_statement == nullptr) {
builder->create_cond_br(cond_val, trueBB, contBB);
} else { } else {
builder->create_cond_br(cond, ifBB, endBB); falseBB = BasicBlock::create(module.get(), "", cur_fun);
builder->create_cond_br(cond_val, trueBB, falseBB);
}
builder->set_insert_point(trueBB);
node.if_statement->accept(*this);
builder->set_insert_point(ifBB); if (builder->get_insert_block()->get_terminator() == nullptr)
node.if_statement->accept(*this); builder->create_br(contBB);
builder->create_br(endBB);
builder->set_insert_point(endBB); if (node.else_statement == nullptr) {
// falseBB->erase_from_parent(); // did not clean up memory
} else {
builder->set_insert_point(falseBB);
node.else_statement->accept(*this);
if (builder->get_insert_block()->get_terminator() == nullptr)
builder->create_br(contBB);
} }
scope.exit();
}
// Done builder->set_insert_point(contBB);
void CminusfBuilder::visit(ASTIterationStmt &node) { }
//!TODO: This function is empty now.
// Add some code here.
scope.enter();
auto HEAD = BasicBlock::create(builder->get_module(), "", cur_fun);
auto BODY = BasicBlock::create(builder->get_module(), "", cur_fun);
auto END = BasicBlock::create(builder->get_module(), "", cur_fun);
builder->create_br(HEAD);
builder->set_insert_point(HEAD); void
CminusfBuilder::visit(ASTIterationStmt &node) {
auto exprBB = BasicBlock::create(module.get(), "", cur_fun);
if (builder->get_insert_block()->get_terminator() == nullptr)
builder->create_br(exprBB);
builder->set_insert_point(exprBB);
node.expression->accept(*this); node.expression->accept(*this);
auto cond = cur_value; auto ret_val = tmp_val;
cast_to_i1(cond); auto trueBB = BasicBlock::create(module.get(), "", cur_fun);
builder->create_cond_br(cond, BODY, END); auto contBB = BasicBlock::create(module.get(), "", cur_fun);
Value *cond_val;
if (ret_val->get_type()->is_integer_type())
cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
else
cond_val = builder->create_fcmp_ne(ret_val, CONST_FP(0.));
builder->set_insert_point(BODY); builder->create_cond_br(cond_val, trueBB, contBB);
builder->set_insert_point(trueBB);
node.statement->accept(*this); node.statement->accept(*this);
builder->create_br(HEAD); if (builder->get_insert_block()->get_terminator() == nullptr)
builder->create_br(exprBB);
builder->set_insert_point(END); builder->set_insert_point(contBB);
scope.exit();
} }
// Done void
void CminusfBuilder::visit(ASTReturnStmt &node) { CminusfBuilder::visit(ASTReturnStmt &node) {
if (node.expression == nullptr) { if (node.expression == nullptr) {
builder->create_void_ret(); builder->create_void_ret();
} else { } else {
//!TODO: The given code is incomplete. auto fun_ret_type = cur_fun->get_function_type()->get_return_type();
// You need to solve other return cases (e.g. return an integer).
//
node.expression->accept(*this); node.expression->accept(*this);
// type cast if (fun_ret_type != tmp_val->get_type()) {
// return type can only be int, float or void if (fun_ret_type->is_integer_type())
if (not Type::is_eq_type(cur_fun->get_return_type(), cur_value->get_type())) { tmp_val = builder->create_fptosi(tmp_val, INT32_T);
if (not cur_value->get_type()->is_integer_type() and not cur_value->get_type()->is_float_type())
error_exit("unsupported return type");
if (cur_value->get_type()->is_float_type())
cur_value = builder->create_fptosi(cur_value, INT32_T);
else if (cur_fun->get_return_type()->is_float_type())
cur_value = builder->create_sitofp(cur_value, FLOAT_T);
else else
cur_value = builder->create_zext(cur_value, INT32_T); tmp_val = builder->create_sitofp(tmp_val, FLOAT_T);
} }
builder->create_ret(cur_value); builder->create_ret(tmp_val);
LOG_DEBUG << "create ret:\n" << builder->get_module()->print();
} }
} }
// Done void
// if LV is marked, return memory addr CminusfBuilder::visit(ASTVar &node) {
// else return value stored inside auto var = scope.find(node.id);
void CminusfBuilder::visit(ASTVar &node) { assert(var != nullptr);
//!TODO: This function is empty now. auto is_int =
// Add some code here. var->get_type()->get_pointer_element_type()->is_integer_type();
// auto is_float =
// First it's pointer type, the pointed elements have 3 cases: var->get_type()->get_pointer_element_type()->is_float_type();
// 1. int or float auto is_ptr =
// 2. [i32 x n] or [float x n] var->get_type()->get_pointer_element_type()->is_pointer_type();
// 3. int* bool should_return_lvalue = require_lvalue;
auto memory = scope.find(node.id); require_lvalue = false;
Value *addr; if (node.expression == nullptr) {
if (memory == nullptr) if (should_return_lvalue) {
error_exit("variable " + node.id + " not declared"); tmp_val = var;
LOG_DEBUG << "find entry: " << node.id << " " << memory; require_lvalue = false;
assert(memory->get_type()->is_pointer_type());
auto element_type = memory->get_type()->get_pointer_element_type();
if (node.expression) { // e.g. int a[10]; // mem is [i32 x 10]*
bool old_LV = LV;
LV = false;
node.expression->accept(*this);
LV = old_LV;
// subscription type cast
if (not Type::is_eq_type(cur_value->get_type(), INT32_T)) {
if (Type::is_eq_type(cur_value->get_type(), FLOAT_T))
cur_value = builder->create_fptosi(cur_value, INT32_T);
else
error_exit("bad type for subscription");
}
auto cond = builder->create_icmp_lt(cur_value, CONST_INT(0));
auto except_func = scope.find("neg_idx_except");
auto TBB = BasicBlock::create(builder->get_module(), "", cur_fun);
auto passBB = BasicBlock::create(builder->get_module(), "", cur_fun);
builder->create_cond_br(cond, TBB, passBB);
builder->set_insert_point(TBB);
builder->create_call(except_func, {});
builder->create_br(passBB);
builder->set_insert_point(passBB);
// Now the subscription is in cur_value, which is good value
// We should focus on the var type:
//
// assert it's pointer type
if (element_type->is_float_type() or element_type->is_integer_type()) {
// 1. int or float
error_exit("invalid types for array subscript");
} else if (element_type->is_array_type()) {
// 2. [i32 x n] or [float x n]
//
// addr is actually &memory[0][cur_value]
addr = builder->create_gep(memory, {CONST_INT(0), cur_value});
} else if (element_type->is_pointer_type()) {
// 3. int*
// what to do:
// - int**->int*
// - seek addr
//
// addr is actually &(*memory)[cur_value]
addr = builder->create_load(memory);
addr = builder->create_gep(addr, {cur_value});
}
// The logic for this part is the same as `int|float` without subscription.
// Cause we have subscription to find the particular element(int or float),
// we make `addr` its memory address.
} else { // e.g. int a; // a is i32*
if (element_type->is_float_type() or element_type->is_integer_type()) {
// 1. int or float
// addr is the element's addr
addr = memory;
} else { } else {
if (LV) if (is_int || is_float || is_ptr) {
error_exit("error: pointer or array type is not assignable"); tmp_val = builder->create_load(var);
// For array* or pointer* type, the right-value behaviour is quite special, } else {
// so treat them apart. tmp_val =
if (element_type->is_array_type()) { builder->create_gep(var, {CONST_INT(0), CONST_INT(0)});
// 2. [i32 x n] or [float x n]
// addr is the first element's address in the array
cur_value = builder->create_gep(memory, {CONST_INT(0), CONST_INT(0)});
} else if (element_type->is_pointer_type()) {
// 3. int*
// addr is the content in the memory, which is actually pointer type
cur_value = builder->create_load(memory);
} }
return;
} }
}
if (LV) {
LOG_INFO << "directly return addr" << node.id;
cur_value = addr;
} else { } else {
LOG_INFO << "create load for var: " << node.id; node.expression->accept(*this);
cur_value = builder->create_load(addr); auto val = tmp_val;
Value *is_neg;
auto exceptBB = BasicBlock::create(module.get(), "", cur_fun);
auto contBB = BasicBlock::create(module.get(), "", cur_fun);
if (val->get_type()->is_float_type())
val = builder->create_fptosi(val, INT32_T);
is_neg = builder->create_icmp_lt(val, CONST_INT(0));
builder->create_cond_br(is_neg, exceptBB, contBB);
builder->set_insert_point(exceptBB);
auto neg_idx_except_fun = scope.find("neg_idx_except");
builder->create_call(static_cast<Function *>(neg_idx_except_fun), {});
if (cur_fun->get_return_type()->is_void_type())
builder->create_void_ret();
else if (cur_fun->get_return_type()->is_float_type())
builder->create_ret(CONST_FP(0.));
else
builder->create_ret(CONST_INT(0));
builder->set_insert_point(contBB);
Value *tmp_ptr;
if (is_int || is_float)
tmp_ptr = builder->create_gep(var, {val});
else if (is_ptr) {
auto array_load = builder->create_load(var);
tmp_ptr = builder->create_gep(array_load, {val});
} else
tmp_ptr = builder->create_gep(var, {CONST_INT(0), val});
if (should_return_lvalue) {
tmp_val = tmp_ptr;
require_lvalue = false;
} else {
tmp_val = builder->create_load(tmp_ptr);
}
} }
} }
// Done void
void CminusfBuilder::visit(ASTAssignExpression &node) { CminusfBuilder::visit(ASTAssignExpression &node) {
//!TODO: This function is empty now.
// Add some code here.
//
LV = true;
node.var->accept(*this);
LV = false;
auto addr = cur_value;
node.expression->accept(*this); node.expression->accept(*this);
auto expr_result = tmp_val;
assert(addr->get_type()->get_pointer_element_type() != nullptr); require_lvalue = true;
// type cast: left is a pointer type, pointed to i32 or float node.var->accept(*this);
if (not Type::is_eq_type(addr->get_type()->get_pointer_element_type(), cur_value->get_type())) { auto var_addr = tmp_val;
if (cur_value->get_type()->is_float_type()) if (var_addr->get_type()->get_pointer_element_type() !=
cur_value = builder->create_fptosi(cur_value, INT32_T); expr_result->get_type()) {
else if (addr->get_type()->get_pointer_element_type()->is_float_type()) if (expr_result->get_type() == INT32_T)
cur_value = builder->create_sitofp(cur_value, FLOAT_T); expr_result = builder->create_sitofp(expr_result, FLOAT_T);
else if (Type::is_eq_type(cur_value->get_type(), INT1_T))
cur_value = builder->create_zext(cur_value, INT32_T);
else else
error_exit("bad type for assignment"); expr_result = builder->create_fptosi(expr_result, INT32_T);
} }
// gen code builder->create_store(expr_result, var_addr);
builder->create_store(cur_value, addr); tmp_val = expr_result;
} }
// Done void
void CminusfBuilder::visit(ASTSimpleExpression &node) { CminusfBuilder::visit(ASTSimpleExpression &node) {
//!TODO: This function is empty now. if (node.additive_expression_r == nullptr) {
// Add some code here.
//
if (node.additive_expression_r) {
node.additive_expression_l->accept(*this); node.additive_expression_l->accept(*this);
auto lvalue = cur_value; } else {
node.additive_expression_l->accept(*this);
auto l_val = tmp_val;
node.additive_expression_r->accept(*this); node.additive_expression_r->accept(*this);
auto rvalue = cur_value; auto r_val = tmp_val;
// check type bool is_int = promote(&*builder, &l_val, &r_val);
biop_type_check(lvalue, rvalue, "cmp"); Value *cmp;
bool float_cmp = lvalue->get_type()->is_float_type();
switch (node.op) { switch (node.op) {
case OP_LE: { case OP_LT:
if (float_cmp) if (is_int)
cur_value = builder->create_fcmp_le(lvalue, rvalue); cmp = builder->create_icmp_lt(l_val, r_val);
else else
cur_value = builder->create_icmp_le(lvalue, rvalue); cmp = builder->create_fcmp_lt(l_val, r_val);
break; break;
} case OP_LE:
case OP_LT: { if (is_int)
if (float_cmp) cmp = builder->create_icmp_le(l_val, r_val);
cur_value = builder->create_fcmp_lt(lvalue, rvalue);
else else
cur_value = builder->create_icmp_lt(lvalue, rvalue); cmp = builder->create_fcmp_le(l_val, r_val);
break; break;
} case OP_GE:
case OP_GT: { if (is_int)
if (float_cmp) cmp = builder->create_icmp_ge(l_val, r_val);
cur_value = builder->create_fcmp_gt(lvalue, rvalue);
else else
cur_value = builder->create_icmp_gt(lvalue, rvalue); cmp = builder->create_fcmp_ge(l_val, r_val);
break; break;
} case OP_GT:
case OP_GE: { if (is_int)
if (float_cmp) cmp = builder->create_icmp_gt(l_val, r_val);
cur_value = builder->create_fcmp_ge(lvalue, rvalue);
else else
cur_value = builder->create_icmp_ge(lvalue, rvalue); cmp = builder->create_fcmp_gt(l_val, r_val);
break; break;
} case OP_EQ:
case OP_EQ: { if (is_int)
if (float_cmp) cmp = builder->create_icmp_eq(l_val, r_val);
cur_value = builder->create_fcmp_eq(lvalue, rvalue);
else else
cur_value = builder->create_icmp_eq(lvalue, rvalue); cmp = builder->create_fcmp_eq(l_val, r_val);
break; break;
} case OP_NEQ:
case OP_NEQ: { if (is_int)
if (float_cmp) cmp = builder->create_icmp_ne(l_val, r_val);
cur_value = builder->create_fcmp_ne(lvalue, rvalue);
else else
cur_value = builder->create_icmp_ne(lvalue, rvalue); cmp = builder->create_fcmp_ne(l_val, r_val);
break; break;
}
} }
} else
node.additive_expression_l->accept(*this); tmp_val = builder->create_zext(cmp, INT32_T);
}
} }
// Done void
void CminusfBuilder::visit(ASTAdditiveExpression &node) { CminusfBuilder::visit(ASTAdditiveExpression &node) {
//!TODO: This function is empty now. if (node.additive_expression == nullptr) {
// Add some code here. node.term->accept(*this);
// } else {
if (node.additive_expression) {
node.additive_expression->accept(*this); node.additive_expression->accept(*this);
auto lvalue = cur_value; auto l_val = tmp_val;
node.term->accept(*this); node.term->accept(*this);
auto rvalue = cur_value; auto r_val = tmp_val;
// check type bool is_int = promote(&*builder, &l_val, &r_val);
biop_type_check(lvalue, rvalue, "addop");
bool float_type = lvalue->get_type()->is_float_type();
// now left and right is the same type
switch (node.op) { switch (node.op) {
case OP_PLUS: { case OP_PLUS:
if (float_type) if (is_int)
cur_value = builder->create_fadd(lvalue, rvalue); tmp_val = builder->create_iadd(l_val, r_val);
else else
cur_value = builder->create_iadd(lvalue, rvalue); tmp_val = builder->create_fadd(l_val, r_val);
break; break;
} case OP_MINUS:
case OP_MINUS: { if (is_int)
if (float_type) tmp_val = builder->create_isub(l_val, r_val);
cur_value = builder->create_fsub(lvalue, rvalue);
else else
cur_value = builder->create_isub(lvalue, rvalue); tmp_val = builder->create_fsub(l_val, r_val);
break; break;
}
} }
} else }
node.term->accept(*this);
} }
// Done void
void CminusfBuilder::visit(ASTTerm &node) { CminusfBuilder::visit(ASTTerm &node) {
//!TODO: This function is empty now. if (node.term == nullptr) {
// Add some code here. node.factor->accept(*this);
if (node.term) { } else {
node.term->accept(*this); node.term->accept(*this);
auto lvalue = cur_value; auto l_val = tmp_val;
node.factor->accept(*this); node.factor->accept(*this);
auto rvalue = cur_value; auto r_val = tmp_val;
// check type bool is_int = promote(&*builder, &l_val, &r_val);
biop_type_check(lvalue, rvalue, "mul");
bool float_type = lvalue->get_type()->is_float_type();
// now left and right is the same type
switch (node.op) { switch (node.op) {
case OP_MUL: { case OP_MUL:
if (float_type) if (is_int)
cur_value = builder->create_fmul(lvalue, rvalue); tmp_val = builder->create_imul(l_val, r_val);
else else
cur_value = builder->create_imul(lvalue, rvalue); tmp_val = builder->create_fmul(l_val, r_val);
break; break;
} case OP_DIV:
case OP_DIV: { if (is_int)
if (float_type) tmp_val = builder->create_isdiv(l_val, r_val);
cur_value = builder->create_fdiv(lvalue, rvalue);
else else
cur_value = builder->create_isdiv(lvalue, rvalue); tmp_val = builder->create_fdiv(l_val, r_val);
break; break;
}
} }
} else }
node.factor->accept(*this);
} }
// Done void
void CminusfBuilder::visit(ASTCall &node) { CminusfBuilder::visit(ASTCall &node) {
//!TODO: This function is empty now. auto fun = static_cast<Function *>(scope.find(node.id));
// Add some code here.
Function *func = static_cast<Function *>(scope.find(node.id));
std::vector<Value *> args; std::vector<Value *> args;
if (func == nullptr) auto param_type = fun->get_function_type()->param_begin();
error_exit("function " + node.id + " not declared"); for (auto &arg : node.args) {
if (node.args.size() != func->get_num_of_args()) arg->accept(*this);
error_exit("expect " + std::to_string(func->get_num_of_args()) + " params, but " + if (!tmp_val->get_type()->is_pointer_type() &&
std::to_string(node.args.size()) + " is given"); *param_type != tmp_val->get_type()) {
// check every argument if (tmp_val->get_type()->is_integer_type())
for (int i = 0; i != node.args.size(); ++i) { tmp_val = builder->create_sitofp(tmp_val, FLOAT_T);
// ith parameter's type else
Type *param_type = func->get_function_type()->get_param_type(i); tmp_val = builder->create_fptosi(tmp_val, INT32_T);
node.args[i]->accept(*this);
// type cast
if (not Type::is_eq_type(param_type, cur_value->get_type())) {
if (param_type->is_pointer_type()) {
// shouldn't need type cast for pointer, logically
if (param_type->get_pointer_element_type()->is_integer_type() or
param_type->get_pointer_element_type()->is_float_type())
error_exit("BUG HERE: ASTVar return value is not int* or float*");
else
error_exit("BUG HERE: function param needs weird pointer type");
} else if (param_type->is_integer_type() or param_type->is_float_type()) {
// need type cast between int and float
if (not cur_value->get_type()->is_integer_type() and not cur_value->get_type()->is_float_type())
error_exit("unexpected type cast!");
if (param_type->is_float_type())
cur_value = builder->create_sitofp(cur_value, FLOAT_T);
else if (param_type->is_integer_type())
if (cur_value->get_type()->is_integer_type())
cur_value = builder->create_zext(cur_value, INT32_T);
else
cur_value = builder->create_fptosi(cur_value, INT32_T);
else
error_exit("unexpected type cast!");
} else
error_exit("unexpected case when casting arguments for function call " + node.id);
} }
args.push_back(tmp_val);
// now cur_value fits the param type param_type++;
args.push_back(cur_value);
} }
cur_value = builder->create_call(func, args);
tmp_val = builder->create_call(static_cast<Function *>(fun), args);
} }
...@@ -10,7 +10,10 @@ ...@@ -10,7 +10,10 @@
#include <cassert> #include <cassert>
#include <vector> #include <vector>
Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent) Instruction::Instruction(Type *ty,
OpID id,
unsigned num_ops,
BasicBlock *parent)
: User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(parent) { : User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(parent) {
parent_->add_instruction(this); parent_->add_instruction(this);
} }
...@@ -18,56 +21,75 @@ Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent ...@@ -18,56 +21,75 @@ Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent
Instruction::Instruction(Type *ty, OpID id, unsigned num_ops) Instruction::Instruction(Type *ty, OpID id, unsigned num_ops)
: User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(nullptr) {} : User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(nullptr) {}
Function *Instruction::get_function() { return parent_->get_parent(); } Function *
Instruction::get_function() {
return parent_->get_parent();
}
Module *Instruction::get_module() { return parent_->get_module(); } Module *
Instruction::get_module() {
return parent_->get_module();
}
BinaryInst::BinaryInst(Type *ty, OpID id, Value *v1, Value *v2, BasicBlock *bb) : BaseInst<BinaryInst>(ty, id, 2, bb) { BinaryInst::BinaryInst(Type *ty, OpID id, Value *v1, Value *v2, BasicBlock *bb)
: BaseInst<BinaryInst>(ty, id, 2, bb) {
set_operand(0, v1); set_operand(0, v1);
set_operand(1, v2); set_operand(1, v2);
// assertValid(); // assertValid();
} }
void BinaryInst::assertValid() { void
BinaryInst::assertValid() {
assert(get_operand(0)->get_type()->is_integer_type()); assert(get_operand(0)->get_type()->is_integer_type());
assert(get_operand(1)->get_type()->is_integer_type()); assert(get_operand(1)->get_type()->is_integer_type());
assert(static_cast<IntegerType *>(get_operand(0)->get_type())->get_num_bits() == assert(
static_cast<IntegerType *>(get_operand(1)->get_type())->get_num_bits()); static_cast<IntegerType *>(get_operand(0)->get_type())
->get_num_bits() ==
static_cast<IntegerType *>(get_operand(1)->get_type())->get_num_bits());
} }
BinaryInst *BinaryInst::create_add(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_add(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::add, v1, v2, bb); return create(Type::get_int32_type(m), Instruction::add, v1, v2, bb);
} }
BinaryInst *BinaryInst::create_sub(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_sub(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::sub, v1, v2, bb); return create(Type::get_int32_type(m), Instruction::sub, v1, v2, bb);
} }
BinaryInst *BinaryInst::create_mul(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_mul(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::mul, v1, v2, bb); return create(Type::get_int32_type(m), Instruction::mul, v1, v2, bb);
} }
BinaryInst *BinaryInst::create_sdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_sdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::sdiv, v1, v2, bb); return create(Type::get_int32_type(m), Instruction::sdiv, v1, v2, bb);
} }
BinaryInst *BinaryInst::create_fadd(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_fadd(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fadd, v1, v2, bb); return create(Type::get_float_type(m), Instruction::fadd, v1, v2, bb);
} }
BinaryInst *BinaryInst::create_fsub(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_fsub(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fsub, v1, v2, bb); return create(Type::get_float_type(m), Instruction::fsub, v1, v2, bb);
} }
BinaryInst *BinaryInst::create_fmul(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_fmul(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fmul, v1, v2, bb); return create(Type::get_float_type(m), Instruction::fmul, v1, v2, bb);
} }
BinaryInst *BinaryInst::create_fdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) { BinaryInst *
BinaryInst::create_fdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fdiv, v1, v2, bb); return create(Type::get_float_type(m), Instruction::fdiv, v1, v2, bb);
} }
std::string BinaryInst::print() { std::string
BinaryInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -78,7 +100,8 @@ std::string BinaryInst::print() { ...@@ -78,7 +100,8 @@ std::string BinaryInst::print() {
instr_ir += " "; instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), false); instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += ", "; instr_ir += ", ";
if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) { if (Type::is_eq_type(this->get_operand(0)->get_type(),
this->get_operand(1)->get_type())) {
instr_ir += print_as_op(this->get_operand(1), false); instr_ir += print_as_op(this->get_operand(1), false);
} else { } else {
instr_ir += print_as_op(this->get_operand(1), true); instr_ir += print_as_op(this->get_operand(1), true);
...@@ -93,18 +116,27 @@ CmpInst::CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb) ...@@ -93,18 +116,27 @@ CmpInst::CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
// assertValid(); // assertValid();
} }
void CmpInst::assertValid() { void
CmpInst::assertValid() {
assert(get_operand(0)->get_type()->is_integer_type()); assert(get_operand(0)->get_type()->is_integer_type());
assert(get_operand(1)->get_type()->is_integer_type()); assert(get_operand(1)->get_type()->is_integer_type());
assert(static_cast<IntegerType *>(get_operand(0)->get_type())->get_num_bits() == assert(
static_cast<IntegerType *>(get_operand(1)->get_type())->get_num_bits()); static_cast<IntegerType *>(get_operand(0)->get_type())
} ->get_num_bits() ==
static_cast<IntegerType *>(get_operand(1)->get_type())->get_num_bits());
CmpInst *CmpInst::create_cmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, Module *m) { }
CmpInst *
CmpInst::create_cmp(CmpOp op,
Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m) {
return create(m->get_int1_type(), op, lhs, rhs, bb); return create(m->get_int1_type(), op, lhs, rhs, bb);
} }
std::string CmpInst::print() { std::string
CmpInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -117,7 +149,8 @@ std::string CmpInst::print() { ...@@ -117,7 +149,8 @@ std::string CmpInst::print() {
instr_ir += " "; instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), false); instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += ", "; instr_ir += ", ";
if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) { if (Type::is_eq_type(this->get_operand(0)->get_type(),
this->get_operand(1)->get_type())) {
instr_ir += print_as_op(this->get_operand(1), false); instr_ir += print_as_op(this->get_operand(1), false);
} else { } else {
instr_ir += print_as_op(this->get_operand(1), true); instr_ir += print_as_op(this->get_operand(1), true);
...@@ -132,16 +165,23 @@ FCmpInst::FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb) ...@@ -132,16 +165,23 @@ FCmpInst::FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
// assertValid(); // assertValid();
} }
void FCmpInst::assert_valid() { void
FCmpInst::assert_valid() {
assert(get_operand(0)->get_type()->is_float_type()); assert(get_operand(0)->get_type()->is_float_type());
assert(get_operand(1)->get_type()->is_float_type()); assert(get_operand(1)->get_type()->is_float_type());
} }
FCmpInst *FCmpInst::create_fcmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, Module *m) { FCmpInst *
FCmpInst::create_fcmp(CmpOp op,
Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m) {
return create(m->get_int1_type(), op, lhs, rhs, bb); return create(m->get_int1_type(), op, lhs, rhs, bb);
} }
std::string FCmpInst::print() { std::string
FCmpInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -154,7 +194,8 @@ std::string FCmpInst::print() { ...@@ -154,7 +194,8 @@ std::string FCmpInst::print() {
instr_ir += " "; instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), false); instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += ","; instr_ir += ",";
if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) { if (Type::is_eq_type(this->get_operand(0)->get_type(),
this->get_operand(1)->get_type())) {
instr_ir += print_as_op(this->get_operand(1), false); instr_ir += print_as_op(this->get_operand(1), false);
} else { } else {
instr_ir += print_as_op(this->get_operand(1), true); instr_ir += print_as_op(this->get_operand(1), true);
...@@ -163,7 +204,10 @@ std::string FCmpInst::print() { ...@@ -163,7 +204,10 @@ std::string FCmpInst::print() {
} }
CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb) CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
: BaseInst<CallInst>(func->get_return_type(), Instruction::call, args.size() + 1, bb) { : BaseInst<CallInst>(func->get_return_type(),
Instruction::call,
args.size() + 1,
bb) {
assert(func->get_num_of_args() == args.size()); assert(func->get_num_of_args() == args.size());
int num_ops = args.size() + 1; int num_ops = args.size() + 1;
set_operand(0, func); set_operand(0, func);
...@@ -172,13 +216,18 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb) ...@@ -172,13 +216,18 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
} }
} }
CallInst *CallInst::create(Function *func, std::vector<Value *> args, BasicBlock *bb) { CallInst *
CallInst::create(Function *func, std::vector<Value *> args, BasicBlock *bb) {
return BaseInst<CallInst>::create(func, args, bb); return BaseInst<CallInst>::create(func, args, bb);
} }
FunctionType *CallInst::get_function_type() const { return static_cast<FunctionType *>(get_operand(0)->get_type()); } FunctionType *
CallInst::get_function_type() const {
return static_cast<FunctionType *>(get_operand(0)->get_type());
}
std::string CallInst::print() { std::string
CallInst::print() {
std::string instr_ir; std::string instr_ir;
if (!this->is_void()) { if (!this->is_void()) {
instr_ir += "%"; instr_ir += "%";
...@@ -190,7 +239,8 @@ std::string CallInst::print() { ...@@ -190,7 +239,8 @@ std::string CallInst::print() {
instr_ir += this->get_function_type()->get_return_type()->print(); instr_ir += this->get_function_type()->get_return_type()->print();
instr_ir += " "; instr_ir += " ";
assert(dynamic_cast<Function *>(this->get_operand(0)) && "Wrong call operand function"); assert(dynamic_cast<Function *>(this->get_operand(0)) &&
"Wrong call operand function");
instr_ir += print_as_op(this->get_operand(0), false); instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += "("; instr_ir += "(";
for (int i = 1; i < this->get_num_operand(); i++) { for (int i = 1; i < this->get_num_operand(); i++) {
...@@ -204,19 +254,32 @@ std::string CallInst::print() { ...@@ -204,19 +254,32 @@ std::string CallInst::print() {
return instr_ir; return instr_ir;
} }
BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BasicBlock *bb) BranchInst::BranchInst(Value *cond,
: BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()), Instruction::br, 3, bb) { BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb)
: BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()),
Instruction::br,
3,
bb) {
set_operand(0, cond); set_operand(0, cond);
set_operand(1, if_true); set_operand(1, if_true);
set_operand(2, if_false); set_operand(2, if_false);
} }
BranchInst::BranchInst(BasicBlock *if_true, BasicBlock *bb) BranchInst::BranchInst(BasicBlock *if_true, BasicBlock *bb)
: BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()), Instruction::br, 1, bb) { : BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()),
Instruction::br,
1,
bb) {
set_operand(0, if_true); set_operand(0, if_true);
} }
BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BasicBlock *bb) { BranchInst *
BranchInst::create_cond_br(Value *cond,
BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb) {
if_true->add_pre_basic_block(bb); if_true->add_pre_basic_block(bb);
if_false->add_pre_basic_block(bb); if_false->add_pre_basic_block(bb);
bb->add_succ_basic_block(if_false); bb->add_succ_basic_block(if_false);
...@@ -225,16 +288,21 @@ BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBl ...@@ -225,16 +288,21 @@ BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBl
return create(cond, if_true, if_false, bb); return create(cond, if_true, if_false, bb);
} }
BranchInst *BranchInst::create_br(BasicBlock *if_true, BasicBlock *bb) { BranchInst *
BranchInst::create_br(BasicBlock *if_true, BasicBlock *bb) {
if_true->add_pre_basic_block(bb); if_true->add_pre_basic_block(bb);
bb->add_succ_basic_block(if_true); bb->add_succ_basic_block(if_true);
return create(if_true, bb); return create(if_true, bb);
} }
bool BranchInst::is_cond_br() const { return get_num_operand() == 3; } bool
BranchInst::is_cond_br() const {
return get_num_operand() == 3;
}
std::string BranchInst::print() { std::string
BranchInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " "; instr_ir += " ";
...@@ -250,20 +318,36 @@ std::string BranchInst::print() { ...@@ -250,20 +318,36 @@ std::string BranchInst::print() {
} }
ReturnInst::ReturnInst(Value *val, BasicBlock *bb) ReturnInst::ReturnInst(Value *val, BasicBlock *bb)
: BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()), Instruction::ret, 1, bb) { : BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()),
Instruction::ret,
1,
bb) {
set_operand(0, val); set_operand(0, val);
} }
ReturnInst::ReturnInst(BasicBlock *bb) ReturnInst::ReturnInst(BasicBlock *bb)
: BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()), Instruction::ret, 0, bb) {} : BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()),
Instruction::ret,
0,
bb) {}
ReturnInst *ReturnInst::create_ret(Value *val, BasicBlock *bb) { return create(val, bb); } ReturnInst *
ReturnInst::create_ret(Value *val, BasicBlock *bb) {
return create(val, bb);
}
ReturnInst *ReturnInst::create_void_ret(BasicBlock *bb) { return create(bb); } ReturnInst *
ReturnInst::create_void_ret(BasicBlock *bb) {
return create(bb);
}
bool ReturnInst::is_void_ret() const { return get_num_operand() == 0; } bool
ReturnInst::is_void_ret() const {
return get_num_operand() == 0;
}
std::string ReturnInst::print() { std::string
ReturnInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " "; instr_ir += " ";
...@@ -278,7 +362,9 @@ std::string ReturnInst::print() { ...@@ -278,7 +362,9 @@ std::string ReturnInst::print() {
return instr_ir; return instr_ir;
} }
GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, BasicBlock *bb) GetElementPtrInst::GetElementPtrInst(Value *ptr,
std::vector<Value *> idxs,
BasicBlock *bb)
: BaseInst<GetElementPtrInst>(PointerType::get(get_element_type(ptr, idxs)), : BaseInst<GetElementPtrInst>(PointerType::get(get_element_type(ptr, idxs)),
Instruction::getelementptr, Instruction::getelementptr,
1 + idxs.size(), 1 + idxs.size(),
...@@ -290,10 +376,12 @@ GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, Basi ...@@ -290,10 +376,12 @@ GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, Basi
element_ty_ = get_element_type(ptr, idxs); element_ty_ = get_element_type(ptr, idxs);
} }
Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs) { Type *
GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs) {
Type *ty = ptr->get_type()->get_pointer_element_type(); Type *ty = ptr->get_type()->get_pointer_element_type();
assert("GetElementPtrInst ptr is wrong type" && assert(
(ty->is_array_type() || ty->is_integer_type() || ty->is_float_type())); "GetElementPtrInst ptr is wrong type" &&
(ty->is_array_type() || ty->is_integer_type() || ty->is_float_type()));
if (ty->is_array_type()) { if (ty->is_array_type()) {
ArrayType *arr_ty = static_cast<ArrayType *>(ty); ArrayType *arr_ty = static_cast<ArrayType *>(ty);
for (int i = 1; i < idxs.size(); i++) { for (int i = 1; i < idxs.size(); i++) {
...@@ -309,13 +397,20 @@ Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs) ...@@ -309,13 +397,20 @@ Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs)
return ty; return ty;
} }
Type *GetElementPtrInst::get_element_type() const { return element_ty_; } Type *
GetElementPtrInst::get_element_type() const {
return element_ty_;
}
GetElementPtrInst *GetElementPtrInst::create_gep(Value *ptr, std::vector<Value *> idxs, BasicBlock *bb) { GetElementPtrInst *
GetElementPtrInst::create_gep(Value *ptr,
std::vector<Value *> idxs,
BasicBlock *bb) {
return create(ptr, idxs, bb); return create(ptr, idxs, bb);
} }
std::string GetElementPtrInst::print() { std::string
GetElementPtrInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -323,7 +418,8 @@ std::string GetElementPtrInst::print() { ...@@ -323,7 +418,8 @@ std::string GetElementPtrInst::print() {
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " "; instr_ir += " ";
assert(this->get_operand(0)->get_type()->is_pointer_type()); assert(this->get_operand(0)->get_type()->is_pointer_type());
instr_ir += this->get_operand(0)->get_type()->get_pointer_element_type()->print(); instr_ir +=
this->get_operand(0)->get_type()->get_pointer_element_type()->print();
instr_ir += ", "; instr_ir += ", ";
for (int i = 0; i < this->get_num_operand(); i++) { for (int i = 0; i < this->get_num_operand(); i++) {
if (i > 0) if (i > 0)
...@@ -336,14 +432,21 @@ std::string GetElementPtrInst::print() { ...@@ -336,14 +432,21 @@ std::string GetElementPtrInst::print() {
} }
StoreInst::StoreInst(Value *val, Value *ptr, BasicBlock *bb) StoreInst::StoreInst(Value *val, Value *ptr, BasicBlock *bb)
: BaseInst<StoreInst>(Type::get_void_type(bb->get_module()), Instruction::store, 2, bb) { : BaseInst<StoreInst>(Type::get_void_type(bb->get_module()),
Instruction::store,
2,
bb) {
set_operand(0, val); set_operand(0, val);
set_operand(1, ptr); set_operand(1, ptr);
} }
StoreInst *StoreInst::create_store(Value *val, Value *ptr, BasicBlock *bb) { return create(val, ptr, bb); } StoreInst *
StoreInst::create_store(Value *val, Value *ptr, BasicBlock *bb) {
return create(val, ptr, bb);
}
std::string StoreInst::print() { std::string
StoreInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " "; instr_ir += " ";
...@@ -355,19 +458,27 @@ std::string StoreInst::print() { ...@@ -355,19 +458,27 @@ std::string StoreInst::print() {
return instr_ir; return instr_ir;
} }
LoadInst::LoadInst(Type *ty, Value *ptr, BasicBlock *bb) : BaseInst<LoadInst>(ty, Instruction::load, 1, bb) { LoadInst::LoadInst(Type *ty, Value *ptr, BasicBlock *bb)
: BaseInst<LoadInst>(ty, Instruction::load, 1, bb) {
assert(ptr->get_type()->is_pointer_type()); assert(ptr->get_type()->is_pointer_type());
assert(ty == static_cast<PointerType *>(ptr->get_type())->get_element_type()); assert(ty ==
static_cast<PointerType *>(ptr->get_type())->get_element_type());
set_operand(0, ptr); set_operand(0, ptr);
} }
LoadInst *LoadInst::create_load(Type *ty, Value *ptr, BasicBlock *bb) { return create(ty, ptr, bb); } LoadInst *
LoadInst::create_load(Type *ty, Value *ptr, BasicBlock *bb) {
return create(ty, ptr, bb);
}
Type *LoadInst::get_load_type() const { Type *
return static_cast<PointerType *>(get_operand(0)->get_type())->get_element_type(); LoadInst::get_load_type() const {
return static_cast<PointerType *>(get_operand(0)->get_type())
->get_element_type();
} }
std::string LoadInst::print() { std::string
LoadInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -375,7 +486,8 @@ std::string LoadInst::print() { ...@@ -375,7 +486,8 @@ std::string LoadInst::print() {
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " "; instr_ir += " ";
assert(this->get_operand(0)->get_type()->is_pointer_type()); assert(this->get_operand(0)->get_type()->is_pointer_type());
instr_ir += this->get_operand(0)->get_type()->get_pointer_element_type()->print(); instr_ir +=
this->get_operand(0)->get_type()->get_pointer_element_type()->print();
instr_ir += ","; instr_ir += ",";
instr_ir += " "; instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), true); instr_ir += print_as_op(this->get_operand(0), true);
...@@ -383,13 +495,21 @@ std::string LoadInst::print() { ...@@ -383,13 +495,21 @@ std::string LoadInst::print() {
} }
AllocaInst::AllocaInst(Type *ty, BasicBlock *bb) AllocaInst::AllocaInst(Type *ty, BasicBlock *bb)
: BaseInst<AllocaInst>(PointerType::get(ty), Instruction::alloca, 0, bb), alloca_ty_(ty) {} : BaseInst<AllocaInst>(PointerType::get(ty), Instruction::alloca, 0, bb)
, alloca_ty_(ty) {}
AllocaInst *AllocaInst::create_alloca(Type *ty, BasicBlock *bb) { return create(ty, bb); } AllocaInst *
AllocaInst::create_alloca(Type *ty, BasicBlock *bb) {
return create(ty, bb);
}
Type *AllocaInst::get_alloca_type() const { return alloca_ty_; } Type *
AllocaInst::get_alloca_type() const {
return alloca_ty_;
}
std::string AllocaInst::print() { std::string
AllocaInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -400,15 +520,23 @@ std::string AllocaInst::print() { ...@@ -400,15 +520,23 @@ std::string AllocaInst::print() {
return instr_ir; return instr_ir;
} }
ZextInst::ZextInst(OpID op, Value *val, Type *ty, BasicBlock *bb) : BaseInst<ZextInst>(ty, op, 1, bb), dest_ty_(ty) { ZextInst::ZextInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
: BaseInst<ZextInst>(ty, op, 1, bb), dest_ty_(ty) {
set_operand(0, val); set_operand(0, val);
} }
ZextInst *ZextInst::create_zext(Value *val, Type *ty, BasicBlock *bb) { return create(Instruction::zext, val, ty, bb); } ZextInst *
ZextInst::create_zext(Value *val, Type *ty, BasicBlock *bb) {
return create(Instruction::zext, val, ty, bb);
}
Type *ZextInst::get_dest_type() const { return dest_ty_; } Type *
ZextInst::get_dest_type() const {
return dest_ty_;
}
std::string ZextInst::print() { std::string
ZextInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -428,13 +556,18 @@ FpToSiInst::FpToSiInst(OpID op, Value *val, Type *ty, BasicBlock *bb) ...@@ -428,13 +556,18 @@ FpToSiInst::FpToSiInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
set_operand(0, val); set_operand(0, val);
} }
FpToSiInst *FpToSiInst::create_fptosi(Value *val, Type *ty, BasicBlock *bb) { FpToSiInst *
FpToSiInst::create_fptosi(Value *val, Type *ty, BasicBlock *bb) {
return create(Instruction::fptosi, val, ty, bb); return create(Instruction::fptosi, val, ty, bb);
} }
Type *FpToSiInst::get_dest_type() const { return dest_ty_; } Type *
FpToSiInst::get_dest_type() const {
return dest_ty_;
}
std::string FpToSiInst::print() { std::string
FpToSiInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -454,13 +587,18 @@ SiToFpInst::SiToFpInst(OpID op, Value *val, Type *ty, BasicBlock *bb) ...@@ -454,13 +587,18 @@ SiToFpInst::SiToFpInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
set_operand(0, val); set_operand(0, val);
} }
SiToFpInst *SiToFpInst::create_sitofp(Value *val, Type *ty, BasicBlock *bb) { SiToFpInst *
SiToFpInst::create_sitofp(Value *val, Type *ty, BasicBlock *bb) {
return create(Instruction::sitofp, val, ty, bb); return create(Instruction::sitofp, val, ty, bb);
} }
Type *SiToFpInst::get_dest_type() const { return dest_ty_; } Type *
SiToFpInst::get_dest_type() const {
return dest_ty_;
}
std::string SiToFpInst::print() { std::string
SiToFpInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -475,7 +613,11 @@ std::string SiToFpInst::print() { ...@@ -475,7 +613,11 @@ std::string SiToFpInst::print() {
return instr_ir; return instr_ir;
} }
PhiInst::PhiInst(OpID op, std::vector<Value *> vals, std::vector<BasicBlock *> val_bbs, Type *ty, BasicBlock *bb) PhiInst::PhiInst(OpID op,
std::vector<Value *> vals,
std::vector<BasicBlock *> val_bbs,
Type *ty,
BasicBlock *bb)
: BaseInst<PhiInst>(ty, op, 2 * vals.size()) { : BaseInst<PhiInst>(ty, op, 2 * vals.size()) {
for (int i = 0; i < vals.size(); i++) { for (int i = 0; i < vals.size(); i++) {
set_operand(2 * i, vals[i]); set_operand(2 * i, vals[i]);
...@@ -484,13 +626,15 @@ PhiInst::PhiInst(OpID op, std::vector<Value *> vals, std::vector<BasicBlock *> v ...@@ -484,13 +626,15 @@ PhiInst::PhiInst(OpID op, std::vector<Value *> vals, std::vector<BasicBlock *> v
this->set_parent(bb); this->set_parent(bb);
} }
PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb) { PhiInst *
PhiInst::create_phi(Type *ty, BasicBlock *bb) {
std::vector<Value *> vals; std::vector<Value *> vals;
std::vector<BasicBlock *> val_bbs; std::vector<BasicBlock *> val_bbs;
return create(Instruction::phi, vals, val_bbs, ty, bb); return create(Instruction::phi, vals, val_bbs, ty, bb);
} }
std::string PhiInst::print() { std::string
PhiInst::print() {
std::string instr_ir; std::string instr_ir;
instr_ir += "%"; instr_ir += "%";
instr_ir += this->get_name(); instr_ir += this->get_name();
...@@ -508,9 +652,12 @@ std::string PhiInst::print() { ...@@ -508,9 +652,12 @@ std::string PhiInst::print() {
instr_ir += print_as_op(this->get_operand(2 * i + 1), false); instr_ir += print_as_op(this->get_operand(2 * i + 1), false);
instr_ir += " ]"; instr_ir += " ]";
} }
if (this->get_num_operand() / 2 < this->get_parent()->get_pre_basic_blocks().size()) { if (this->get_num_operand() / 2 <
this->get_parent()->get_pre_basic_blocks().size()) {
for (auto pre_bb : this->get_parent()->get_pre_basic_blocks()) { for (auto pre_bb : this->get_parent()->get_pre_basic_blocks()) {
if (std::find(this->get_operands().begin(), this->get_operands().end(), static_cast<Value *>(pre_bb)) == if (std::find(this->get_operands().begin(),
this->get_operands().end(),
static_cast<Value *>(pre_bb)) ==
this->get_operands().end()) { this->get_operands().end()) {
// find a pre_bb is not in phi // find a pre_bb is not in phi
instr_ir += ", [ undef, " + print_as_op(pre_bb, false) + " ]"; instr_ir += ", [ undef, " + print_as_op(pre_bb, false) + " ]";
......
...@@ -209,6 +209,17 @@ print_partitions(const GVN::partitions &p) { ...@@ -209,6 +209,17 @@ print_partitions(const GVN::partitions &p) {
} }
} // namespace utils } // namespace utils
GVN::partitions
GVN::join_helper(BasicBlock *pre1, BasicBlock *pre2) {
assert(not _TOP[pre1] or not _TOP[pre2] && "should flow here, not jump");
if (_TOP[pre1])
return pout_[pre2];
else if (_TOP[pre2])
return pout_[pre1];
return join(pout_[pre1], pout_[pre2]);
}
GVN::partitions GVN::partitions
GVN::join(const partitions &P1, const partitions &P2) { GVN::join(const partitions &P1, const partitions &P2) {
// TODO: do intersection pair-wise // TODO: do intersection pair-wise
...@@ -228,7 +239,6 @@ std::shared_ptr<CongruenceClass> ...@@ -228,7 +239,6 @@ std::shared_ptr<CongruenceClass>
GVN::intersect(std::shared_ptr<CongruenceClass> ci, GVN::intersect(std::shared_ptr<CongruenceClass> ci,
std::shared_ptr<CongruenceClass> cj) { std::shared_ptr<CongruenceClass> cj) {
// TODO // TODO
// If no common members, return null
auto c = createCongruenceClass(); auto c = createCongruenceClass();
std::set<Value *> intersection; std::set<Value *> intersection;
...@@ -240,10 +250,21 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci, ...@@ -240,10 +250,21 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
c->members_ = intersection; c->members_ = intersection;
if (ci->index_ == cj->index_) if (ci->index_ == cj->index_)
c->index_ = ci->index_; c->index_ = ci->index_;
if (ci->leader_ == cj->leader_) if (ci->value_expr_ == cj->value_expr_)
c->leader_ = cj->leader_; c->value_expr_ = ci->value_expr_;
/* if (*ci == *cj) if (ci->value_phi_ and cj->value_phi_ and
* return ci; */ *ci->value_phi_ == *cj->value_phi_)
c->value_phi_ = ci->value_phi_;
// if (c->members_.size() or c->value_expr_ or c->value_phi_) // not empty
// ??
// What if the ve is nullptr?
if (c->members_.size()) // not empty
if (c->index_ == 0) {
c->index_ = new_number();
c->value_phi_ =
PhiExpression::create(ci->value_expr_, cj->value_expr_);
}
return c; return c;
} }
...@@ -251,30 +272,31 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci, ...@@ -251,30 +272,31 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
void void
GVN::detectEquivalences() { GVN::detectEquivalences() {
bool changed; bool changed;
std::cout << "all the instruction address:" << std::endl;
for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions())
std::cout << &instr << "\t" << instr.print() << std::endl;
}
// initialize pout with top // initialize pout with top
for (auto &bb : func_->get_basic_blocks()) { for (auto &bb : func_->get_basic_blocks()) {
// pin_[&bb].clear(); _TOP[&bb] = true;
// pout_[&bb].clear();
for (auto &instr : bb.get_instructions())
_TOP[&instr] = true;
} }
// modify entry block
auto Entry = func_->get_entry_block(); auto Entry = func_->get_entry_block();
_TOP[&*Entry->get_instructions().begin()] = false; _TOP[Entry] = false;
pin_[Entry].clear(); pin_[Entry].clear();
pout_[Entry].clear(); // pout_[Entry] = transferFunction(Entry); pout_[Entry] = transferFunction(Entry);
// iterate until converge // iterate until converge
do { do {
changed = false; changed = false;
// see the pseudo code in documentation for (auto &_bb : func_->get_basic_blocks()) {
for (auto &_bb :
func_->get_basic_blocks()) { // you might need to visit the
// blocks in depth-first order
auto bb = &_bb; auto bb = &_bb;
// get PIN of bb by predecessor(s) if (bb == Entry)
continue;
// get PIN of bb from predecessor(s)
auto pre_bbs_ = bb->get_pre_basic_blocks(); auto pre_bbs_ = bb->get_pre_basic_blocks();
if (bb != Entry) { if (bb != Entry) {
// only update PIN for blocks that are not Entry // only update PIN for blocks that are not Entry
...@@ -283,12 +305,12 @@ GVN::detectEquivalences() { ...@@ -283,12 +305,12 @@ GVN::detectEquivalences() {
case 2: { case 2: {
auto pre_1 = *pre_bbs_.begin(); auto pre_1 = *pre_bbs_.begin();
auto pre_2 = *(++pre_bbs_.begin()); auto pre_2 = *(++pre_bbs_.begin());
pin_[bb] = join(pin_[pre_1], pin_[pre_2]); pin_[bb] = join_helper(pre_1, pre_2);
break; break;
} }
case 1: { case 1: {
auto pre = *(pre_bbs_.begin()); auto pre = *(pre_bbs_.begin());
pin_[bb] = clone(pin_[pre]); pin_[bb] = pout_[pre];
break; break;
} }
default: default:
...@@ -297,82 +319,246 @@ GVN::detectEquivalences() { ...@@ -297,82 +319,246 @@ GVN::detectEquivalences() {
abort(); abort();
} }
} }
auto part = pin_[bb]; auto part = transferFunction(bb);
// iterate through all instructions in the block
for (auto &instr : bb->get_instructions()) {
// ??
if (not instr.is_phi())
part = transferFunction(&instr, instr.get_operand(1), part);
}
// and the phi instruction in all the successors
for (auto succ : bb->get_succ_basic_blocks()) {
for (auto &instr : succ->get_instructions())
if (instr.is_phi()) {
Instruction *pretend;
// ??
part = transferFunction(
pretend, instr.get_operand(1), part);
}
}
// check changes in pout // check changes in pout
changed |= not(part == pout_[bb]); changed |= not(part == pout_[bb]);
pout_[bb] = part; pout_[bb] = part;
_TOP[bb] = false;
} }
} while (changed); } while (changed);
} }
shared_ptr<Expression> shared_ptr<Expression>
GVN::valueExpr(Instruction *instr) { GVN::valueExpr(Instruction *instr, partitions *part) {
// TODO // TODO
return {}; // ?? should use part?
std::string err{"Undefined"};
std::cout << instr->print() << std::endl;
if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) {
auto op1 = instr->get_operand(0);
auto op2 = instr->get_operand(1);
auto op1_const = dynamic_cast<Constant *>(op1);
auto op2_const = dynamic_cast<Constant *>(op2);
if (op1_const and op2_const) {
// both are constant number, so:
// constant fold!
return ConstantExpression::create(
folder_->compute(instr, op1_const, op2_const));
} else {
// both none constant
auto op1_instr = dynamic_cast<Instruction *>(op1);
auto op2_instr = dynamic_cast<Instruction *>(op2);
assert((op1_instr or op1_const) and (op2_instr or op2_const) &&
"must be this case");
return BinaryExpression::create(
instr->get_instr_type(),
(op1_const ? ConstantExpression::create(op1_const)
: valueExpr(op1_instr)),
(op2_const ? ConstantExpression::create(op2_const)
: valueExpr(op2_instr)));
}
} else if (instr->is_phi()) {
err = "phi";
} else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
auto op = instr->get_operand(0);
auto op_const = dynamic_cast<Constant *>(op);
auto op_instr = dynamic_cast<Instruction *>(op);
assert(op_instr or op_const);
// get dest type
auto instr_fp2si = dynamic_cast<FpToSiInst *>(instr);
auto instr_si2fp = dynamic_cast<SiToFpInst *>(instr);
auto instr_zext = dynamic_cast<ZextInst *>(instr);
Type *dest_type = nullptr;
if (instr_fp2si)
dest_type = instr_fp2si->get_dest_type();
else if (instr_si2fp)
dest_type = instr_si2fp->get_dest_type();
else if (instr_zext)
dest_type = instr_zext->get_dest_type();
else
err = "cast";
if (dest_type) {
if (op_const)
return ConstantExpression::create(
folder_->compute(instr, op_const));
else
return CastExpression::create(
instr->get_instr_type(), valueExpr(op_instr), dest_type);
}
} else if (instr->is_gep()) {
auto operands = instr->get_operands();
auto ptr = operands[0];
std::vector<std::shared_ptr<Expression>> idxs;
// check for base address
assert(not dynamic_cast<Constant *>(ptr) and
dynamic_cast<Instruction *>(ptr) &&
"base address should only be from instruction");
// set idxes
for (int i = 1; i < operands.size(); i++) {
if (dynamic_cast<Constant *>(operands[i]))
idxs.push_back(ConstantExpression::create(
dynamic_cast<Constant *>(operands[i])));
else {
assert(dynamic_cast<Instruction *>(operands[i]));
idxs.push_back(
valueExpr(dynamic_cast<Instruction *>(operands[i])));
}
}
return GEPExpression::create(
valueExpr(dynamic_cast<Instruction *>(ptr)), idxs);
} else if (instr->is_load() or instr->is_alloca() or instr->is_call()) {
return UniqueExpression::create(instr);
}
std::cerr << "Undefined case: " << err << std::endl;
abort();
} }
// instruction of the form `x = e`, mostly x is just e (SSA), // instruction of the form `x = e`, mostly x is just e (SSA),
// but for copy stmt x is a phi instruction in the successor. // but for copy stmt x is a phi instruction in the successor.
// Phi values (not copy stmt) should be handled in detectEquiv // Phi values (not copy stmt) should be handled in detectEquiv
//
// assert the x is an instruction that can generate a new value
//
/// \param bb basic block in which the transfer function is called /// \param bb basic block in which the transfer function is called
GVN::partitions GVN::partitions
GVN::transferFunction(Instruction *x, Value *e, partitions pin) { GVN::transferFunction(Instruction *x, Value *e, partitions pin) {
partitions pout = clone(pin); partitions pout = pin;
// TODO: deal with copy-stmt case
// ?? deal with copy statement
auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) &&
"A value must be from an instruction or constant");
// erase the old record for x
std::set<Value *>::iterator it;
for (auto c : pin)
if ((it = std::find(c->members_.begin(), c->members_.end(), x)) !=
c->members_.end()) {
c->members_.erase(it);
}
// TODO: get different ValueExpr by Instruction::OpID, modify pout // TODO: get different ValueExpr by Instruction::OpID, modify pout
std::set<Value *>::iterator iter; // ??
// get ve and vpf
shared_ptr<Expression> ve;
if (e) {
if (e_const)
ve = ConstantExpression::create(e_const);
else
ve = valueExpr(e_instr, &pin);
} else
ve = valueExpr(x, &pin);
auto vpf = valuePhiFunc(ve, curr_bb);
for (auto c : pout) { for (auto c : pout) {
if ((iter = std::find(c->members_.begin(), c->members_.end(), x)) != if (ve == c->value_expr_ or (vpf and vpf == c->value_phi_)) {
c->members_.end()) { c->value_expr_ = ve;
// static_cast<Value *>(x))) != c->members_.end()) { c->members_.insert(x);
c->members_.erase(iter); } else {
auto c = createCongruenceClass(new_number());
c->members_.insert(x);
c->value_expr_ = ve;
c->value_phi_ = vpf;
pout.insert(c);
} }
} }
auto ve = valueExpr(x);
auto vpf = valuePhiFunc(ve, pin);
/* pout.insert({}); /* // first version: ignore ve and vpf
* auto c = CongruenceClass(new_number()); * // and only update index, leader and members
* c.leader_ = e; */ * auto c = createCongruenceClass(new_number());
* c->leader_ = x;
* c->members_.insert(x);
* pout.insert(c); */
return pout; return pout;
} }
/*
* read the pin for the block and then execute transferFunction() for all
* instructions inside.
*/
GVN::partitions GVN::partitions
GVN::transferFunction(BasicBlock *bb) { GVN::transferFunction(BasicBlock *bb) {
partitions pout = clone(pin_[bb]); curr_bb = bb;
// ?? int res;
return pout; auto part = pin_[bb];
/* LOG_INFO << "transferFunction(bb=" << bb->get_name() << ")\n";
* LOG_INFO << "pin:\n";
* utils::print_partitions(pin_[bb]);
* LOG_INFO << "pout before:\n";
* utils::print_partitions(pout_[bb]); */
// iterate through all instructions in the block
for (auto &instr : bb->get_instructions()) {
// ?? what about orther instructions? Are they all ok?
if (not instr.is_phi() and not instr.is_void())
part = transferFunction(&instr, nullptr, part);
}
// and the phi instruction in all the successors
for (auto succ : bb->get_succ_basic_blocks()) {
for (auto &instr : succ->get_instructions()) {
if (instr.is_phi()) {
if ((res = pretend_copy_stmt(&instr, bb) == -1))
continue;
part = transferFunction(&instr, instr.get_operand(res), part);
}
}
}
/* LOG_INFO << "pout after:\n";
* utils::print_partitions(part);
* std::cout << std::endl; */
return part;
} }
shared_ptr<PhiExpression> shared_ptr<PhiExpression>
GVN::valuePhiFunc(shared_ptr<Expression> ve, const partitions &P) { GVN::valuePhiFunc(shared_ptr<Expression> ve, BasicBlock *bb) {
// TODO // TODO
return {}; if (ve->get_expr_type() != Expression::e_bin)
return nullptr;
auto ve_bin = static_cast<BinaryExpression *>(ve.get());
if (ve_bin->lhs_->get_expr_type() != Expression::e_phi or
ve_bin->rhs_->get_expr_type() != Expression::e_phi)
return nullptr;
// get 2 phi expressions
auto lhs = static_cast<PhiExpression *>(ve_bin->lhs_.get());
auto rhs = static_cast<PhiExpression *>(ve_bin->rhs_.get());
// get 2 predecessors
auto pre_bbs_ = bb->get_pre_basic_blocks();
auto pre_1 = *pre_bbs_.begin();
auto pre_2 = *(++pre_bbs_.begin());
// try to get the merged value expression
auto vl_merge = BinaryExpression::create(ve_bin->op_, lhs->lhs_, rhs->lhs_);
auto vr_merge = BinaryExpression::create(ve_bin->op_, lhs->rhs_, rhs->rhs_);
auto vi = getVN(pout_[pre_1], vl_merge);
auto vj = getVN(pout_[pre_2], vr_merge);
if (vi == nullptr)
vi = valuePhiFunc(vl_merge, pre_1);
if (vj == nullptr)
vj = valuePhiFunc(vr_merge, pre_2);
if (vi and vj)
return PhiExpression::create(vi, vj);
else
return nullptr;
} }
shared_ptr<Expression> shared_ptr<Expression>
GVN::getVN(const partitions &pout, shared_ptr<Expression> ve) { GVN::getVN(const partitions &pout, shared_ptr<Expression> ve) {
// TODO: return what? // TODO: return what?
/* for (auto c : pout) {
* if (c->value_expr_ == ve)
* return ve;
* } */
for (auto it = pout.begin(); it != pout.end(); it++) for (auto it = pout.begin(); it != pout.end(); it++)
if ((*it)->value_expr_ and *(*it)->value_expr_ == *ve) if ((*it)->value_expr_ and *(*it)->value_expr_ == *ve)
return {}; return ve;
return nullptr; return nullptr;
} }
...@@ -490,6 +676,12 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) { ...@@ -490,6 +676,12 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
return equiv_as<BinaryExpression>(lhs, rhs); return equiv_as<BinaryExpression>(lhs, rhs);
case Expression::e_phi: case Expression::e_phi:
return equiv_as<PhiExpression>(lhs, rhs); return equiv_as<PhiExpression>(lhs, rhs);
case Expression::e_cast:
return equiv_as<CastExpression>(lhs, rhs);
case Expression::e_gep:
return equiv_as<GEPExpression>(lhs, rhs);
case Expression::e_unique:
return equiv_as<UniqueExpression>(lhs, rhs);
} }
} }
...@@ -517,12 +709,35 @@ operator==(const GVN::partitions &p1, const GVN::partitions &p2) { ...@@ -517,12 +709,35 @@ operator==(const GVN::partitions &p1, const GVN::partitions &p2) {
// cannot direct compare??? // cannot direct compare???
if (p1.size() != p2.size()) if (p1.size() != p2.size())
return false; return false;
return std::equal(p1.begin(), p1.end(), p2.begin(), p2.end()); auto it1 = p1.begin();
auto it2 = p2.begin();
for (; it1 != p1.end(); ++it1, ++it2)
if (not(**it1 == **it2))
return false;
return true;
} }
// only compare index // only compare members
bool bool
CongruenceClass::operator==(const CongruenceClass &other) const { CongruenceClass::operator==(const CongruenceClass &other) const {
// TODO: which fields need to be compared? // TODO: which fields need to be compared?
return index_ == other.index_; if (members_.size() != other.members_.size())
return false;
return members_ == other.members_;
}
int
GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) {
auto phi = static_cast<PhiInst *>(instr);
// res = phi [op1, name1], [op2, name2]
// ^0 ^1 ^2 ^3
if (phi->get_operand(1)->get_name() == bb->get_name()) {
// pretend copy statement:
// `res = op1`
return 0;
} else if (phi->get_operand(3)->get_name() == bb->get_name()) {
// `res = op2`
return 2;
}
return -1;
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment