Commit f26d91aa authored by 李晓奇's avatar 李晓奇

finish a lot... bugs

parent efbf4233
......@@ -8,12 +8,265 @@
## 实验难点
实验中遇到哪些挑战
### 对于函数`detectEquivalences(G)`
```cpp
detectEquivalences(G)
PIN1 = {} // “1” is the first statement in the program
POUT1 = transferFunction(PIN1)
for each statement s other than the first statement in the program
POUTs = Top
while changes to any POUT occur // i.e. changes in equivalences
for each statement s other than the first statement in the program
if s appears in block b that has two predecessors
then
PINs = Join(POUTs1, POUTs2) // s1 and s2 are last statements in respective predecessors
else
PINs = POUTs3 // s3 the statement just before s
POUTs = transferFunction(PINs) // apply transferFunction on each statement in the block
```
- POUT、PIN是对于语句的的map,但是实现中是对于基本块的map,怎么处理区别?
注意到基本块的pin、pout对应着首条语句的pin和末尾语句的pout,所以没有本质上的不同,主要是怎么补全中间语句的pin、pout。
中间语句的pin、pout应该不会被显式保存,只在对基本块执行转移函数时,逐条计算。所以应该没啥问题。
- 怎么设计top?
尽管算法伪代码中,top是对每一条语句的赋值的,但是任意一条语句的都有所属基本块,只要对基本块的pout标注top、就能保证基本块的pin合乎逻辑,逐条执行转移函数后,每条语句的pin都合法。``
- `while changes to any POUT occur`
这句话如此轻松,但是在实现中出现了问题:这涉及到怎么判断分区相等,做不好会出现死循环。
如果单纯比较index恐怕不行:因为我目前的`transferFunction`会生成一个新的等价类,赋予一个新的编号。
所以我使用对等价类中的`members`做相等的判断,如果两个分区中的等价类都能找到相等的,那么就判断分区相等。
> 但是这样依赖于set的排序,
>
> - 对于分区,这个排序依赖于等价类中index的值,侥幸来说每条指令按顺序执行,即使index不一致(每次转移函数分配最新的index),大小顺序也是一致的?
>
> - 对于等价类,里边是`Value*`指针,这个指针的等价就是语义上的等价,没问题。
改过来后发现还是死循环,DEBUG很久意识到我的分区相等写的有问题:`return members_ == other.members_`,这样会挨个儿比较`members_`指针的值,而不是我预想的比较指针指向的等价类。所以要手动遍历一下。
这个遍历也不是很直观的,因为要做两层解引用:
```cpp
for (; it1 != p1.end(); ++it1, ++it2)
if (not(**it1 == **it2))
return false;
```
`*it`是等价类的指针,`**it`才是等价类。
除此之外,我还忘了`++it2`,淦。
改完这些,终于能跑通第一版本了……
### 等价类设计
> ```cpp
> shared_ptr<Expression>
> GVN::valueExpr(Instruction *instr);
> ```
等价类涉及到针对指令设计、常量折叠、递归计算等,很复杂,需要拎出来单独讨论。
首先注意到传给传进去的参数是指令类型(指针)。需要`transferFunction`处理的指令,一次赋值的右式自然可以由指令类型得到,所以没有毛病。
指令有多种类型,也即lightIR中有多种赋值方式,而非单纯的伪代码中的二元运算。但是即使是已经在算法中讨论过的二元运算,写起来也不容易。
#### 生成等价类的逻辑
首先讨论生成等价类的逻辑:
- 二元运算`y op z`
如果y、z都是常量,那么可以进行常量折叠,得到**一个**`ConstantExpression`类型。
否则返回的是一个`BinaryExpression`类型,这个类型需要左右操作数,都是`Expression`类型,如果操作数是常量,使用`ConstantExpression`创建新的值表达式,否则递归调用`valueExpr`。
> 这里为什么递归调用不会走到未定义?
- phi函数
我们逻辑上把其改变成了copy,`transferFunction`应该维护这一点,所以认为没有phi函数传入。
具体是`transferFunction`函数中的这几句话:
```cpp
GVN::partitions
GVN::transferFunction(Instruction *x, Value *e, partitions pin)
{
auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) &&
"A value must be from an instruction or constant");
// ……
if (e) {
if (e_const)
ve = ConstantExpression::create(e_const);
else
ve = valueExpr(e_instr, &pin);
} else
ve = valueExpr(x, &pin);
}
```
- 比较函数:和二元运算一致。
- 类型转换:添加新的表达式类型`CastExpression`,处理和二元操作类似。
- 其他产生赋值的指令,也会由`transferFunction`传递给`ValueExpr`,这些指令是:`load`、`alloca`和返回值非空的`call`。
这里设计一种值表达式类型:`UniqueExpression`,这个表达式和任何其他都不等价,表现为`equiv`直接返回false。
#### 若干需要梳理的问题
除此之外,关于等价类有几个命题一定要说明,这些是理解GVN的关键。
- 等价类中的ve一定存在吗?
根据等价类的设计,ve一定是存在的,且等价类中`members_`中一定存在元素。
- ve一定唯一吗?
从逻辑上,等价类集合中的元素享有共同的ve产生的结果,ve是唯一的
从实现上,添加
- ve或者vpf能够作为代表吗?
这个问题源于`transferFunction`的一句话:
> if ve or vpf is in a class Ci in POUTs
- 代表元如何选取?什么时候选取?
- 怎么判断等价类相等
根据论坛上这个[例子](http://cscourse.ustc.edu.cn/forum/thread.jsp?forum=91&thread=68),我发现值编号会有非必要的增长,即最终收敛时用到的编号是1、5、6,其中2、3、4都是迭代中用到的但是最后都没有体现。
从这个例子中可以发现,迭代收敛时,是pout中对应分区的`members_`不变,值编号还是会变化的,所以判断等价类相等,应该比较`members_`集合。
- 什么时候给新的值编号?
1. 执行转移函数发现没有可以归属的等价类时
2. 求交巴拉巴拉还没搞清楚那里
- 等价类为空的判定,可以用`members_`做判断吗?
我认为是可以的,存在情况:两个基本块汇合时,交出一个集合,它的members为空,但是ve不空:
```
BB0: y =
     z =
; pout: [{v1,y}, {v2,z}]
BB1: x1 = y + z ; preds = BB0
; pout: [{v1,y}, {v2,z}, {v3, x1, v1+v2}]
BB2: x2 = y + z ; preds = BB0
; pout: [{v1,y}, {v2,z}, {v4, x2, v1+v2}]
BB3: ... ; preds = BB1, BB2
```
进入BB3时,对BB1和BB2的pout中等价类分别取交,执行`Intersect(Ci={v3, x1, v1+v2}, Cj={v4, x2, v1+v2}`,得到的是`{v1+v2}`,`members_`中没有成员,但是ve存在,我认为这种情况也应该判定结果为空集。
因为最终反映到IR上,我们直接关心的都是`members_`中的成员,所以用`members_`的空代表等价类的空我觉得是合适的。
| 类型 | 处理 |
| ---------------- | ------------------------- |
| 二元运算 | 依伪代码 |
| 没有返回值的 | 不处理 |
| phi函数 | 本块中的不管,后继块的做考虑 |
| 函数调用 | 有返回值的考虑等价类,**注意后边纯函数的处理** |
| 产生指针(GEP、alloca) | 先为返回值单独新建等价类 |
| 类型转换及0扩展 | 单独新建等价类 |
| 比较 | 根据op和操作数递归判断,应该和二元运算类似 |
| load | 单独新建等价类 |
### 对于函数`Intersect(Ci, Cj )`
伪代码中,一个等价类是一个集合,求交后,v~k~是自然而然的在或不在新集合中,但是实现中并不这么简单,如何对这样的两个结构体进行逻辑上的取交:
```cpp
struct CongruenceClass {
size_t index_;
Value *leader_;
std::shared_ptr<GVNExpression::Expression> value_expr_;
std::shared_ptr<GVNExpression::PhiExpression> value_phi_;
std::set<Value *> members_;
}
```
- 首先是Intersect的伪代码描述中,下边这句怎么在实现中体现?
> C~k~ does not have value number
解决:index不同的自然就交不出v~i~,相同的时候才有v~i~即value number.
隐患:transferFunction要对v~i~有继承性,即不能每次涉及一个等价类都新开一个value number
- 根据上一条讨论,得出一个观点:在伪代码中,一个等价类集合会出现的内容,在实现中的对应如下:
- 普通变量:`members_`
- 值表达式:`value_expr_`
- phi函数:`value_phi_`
- 值编号v~i~:`index`
因此,伪代码中对集合的求交,对应到实现上,就是对以上四个域求交。
### 对于函数`valuePhiFunc(ve, P)`
- 输入就一个分区,实际上要用到两个前驱的分区,怎么处理?
传入基本块做参数。
- 二元运算的顺序颠倒怎么办?
> 如phi(x1,y1)+phi(x2,y2)和phi(x1,y1)+phi(y2,x2)
>
> 对于第二种,单纯寻找x1+y2和y1+x2肯定不合逻辑
应该不会有这种情况:phi表达式,追根溯源是在join时产生的:
```cpp
GVN::join(const partitions &P1, const partitions &P2);
```
这里的分区是严格按照`Basicblock`的方法`get_pre_basic_blocks()`的返回顺序传入的,所以不会有担心的情况发生。
### 所见非所得
两方面:
- 论文中伪代码对于基本块的关注不是很多,但是我们的实现一定要围绕基本块进行,这个的解决对策说到了,其实和单条语句没有本质差异。
- 我们针对lightir优化,经常深入API中陷入到实现细节,还容易被C++的语言特性绊住,但是应该尝试从直观上理解:我们究竟要实现什么样的效果。这个就得去看ll文件找灵感。然后这还不够,回过头来看代码又呆住了,因为ll语句和那些类型对应不上,这个就是容易忽视的一点:看lightir类型的print函数,这个能帮我们直观地串联起实现细节和表象的ll文件。
## 实验设计
### join_helper
### _TOP
### transferFunction(Basicblock)
###
实现思路,相应代码,优化前后的IR对比(举一个例子)并辅以简单说明
### 思考题
1. 请简要分析你的算法复杂度
2. `std::shared_ptr`如果存在环形引用,则无法正确释放内存,你的 Expression 类是否存在 circular reference?
3. 尽管本次实验已经写了很多代码,但是在算法上和工程上仍然可以对 GVN 进行改进,请简述你的 GVN 实现可以改进的地方
......@@ -25,4 +278,3 @@
## 实验反馈(可选 不会评分)
对本次实验的建议
......@@ -9,8 +9,7 @@
class BasicBlock;
class Function;
class Instruction : public User, public llvm::ilist_node<Instruction>
{
class Instruction : public User, public llvm::ilist_node<Instruction> {
public:
enum OpID {
// Terminator Instructions
......@@ -86,8 +85,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
// clang-format on
std::string get_instr_op_name() { return get_instr_op_name(op_id_); }
bool is_void()
{
bool is_void() {
return ((op_id_ == ret) || (op_id_ == br) || (op_id_ == store) ||
(op_id_ == call && this->get_type()->is_void_type()));
}
......@@ -108,6 +106,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool is_fsub() { return op_id_ == fsub; }
bool is_fmul() { return op_id_ == fmul; }
bool is_fdiv() { return op_id_ == fdiv; }
bool is_fp2si() { return op_id_ == fptosi; }
bool is_si2fp() { return op_id_ == sitofp; }
......@@ -118,8 +117,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool is_gep() { return op_id_ == getelementptr; }
bool is_zext() { return op_id_ == zext; }
bool isBinary()
{
bool isBinary() {
return (is_add() || is_sub() || is_mul() || is_div() || is_fadd() ||
is_fsub() || is_fmul() || is_fdiv()) &&
(get_num_operand() == 2);
......@@ -133,34 +131,29 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
BasicBlock *parent_;
};
namespace detail
{
template <typename T>
struct tag
{
namespace detail {
template <typename T>
struct tag {
using type = T;
};
template <typename... Ts>
struct select_last
{
};
template <typename... Ts>
struct select_last {
// Use a fold-expression to fold the comma operator over the parameter
// pack.
using type = typename decltype((tag<Ts>{}, ...))::type;
};
template <typename... Ts>
using select_last_t = typename select_last<Ts...>::type;
};
template <typename... Ts>
using select_last_t = typename select_last<Ts...>::type;
}; // namespace detail
template <class>
inline constexpr bool always_false_v = false;
template <typename Inst>
class BaseInst : public Instruction
{
class BaseInst : public Instruction {
protected:
template <typename... Args>
static Inst *create(Args &&...args)
{
static Inst *create(Args &&...args) {
if constexpr (std::is_same_v<
std::decay_t<detail::select_last_t<Args...>>,
BasicBlock *>) {
......@@ -171,13 +164,10 @@ class BaseInst : public Instruction
}
template <typename... Args>
BaseInst(Args &&...args) : Instruction(std::forward<Args>(args)...)
{
}
BaseInst(Args &&...args) : Instruction(std::forward<Args>(args)...) {}
};
class BinaryInst : public BaseInst<BinaryInst>
{
class BinaryInst : public BaseInst<BinaryInst> {
friend BaseInst<BinaryInst>;
private:
......@@ -185,35 +175,51 @@ class BinaryInst : public BaseInst<BinaryInst>
public:
// create add instruction, auto insert to bb
static BinaryInst *create_add(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_add(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
// create sub instruction, auto insert to bb
static BinaryInst *create_sub(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_sub(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
// create mul instruction, auto insert to bb
static BinaryInst *create_mul(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_mul(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
// create Div instruction, auto insert to bb
static BinaryInst *create_sdiv(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_sdiv(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
// create fadd instruction, auto insert to bb
static BinaryInst *create_fadd(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_fadd(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
// create fsub instruction, auto insert to bb
static BinaryInst *create_fsub(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_fsub(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
// create fmul instruction, auto insert to bb
static BinaryInst *create_fmul(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_fmul(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
// create fDiv instruction, auto insert to bb
static BinaryInst *create_fdiv(Value *v1, Value *v2, BasicBlock *bb,
static BinaryInst *create_fdiv(Value *v1,
Value *v2,
BasicBlock *bb,
Module *m);
virtual std::string print() override;
......@@ -222,8 +228,7 @@ class BinaryInst : public BaseInst<BinaryInst>
void assertValid();
};
class CmpInst : public BaseInst<CmpInst>
{
class CmpInst : public BaseInst<CmpInst> {
friend BaseInst<CmpInst>;
public:
......@@ -240,7 +245,10 @@ class CmpInst : public BaseInst<CmpInst>
CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb);
public:
static CmpInst *create_cmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb,
static CmpInst *create_cmp(CmpOp op,
Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m);
CmpOp get_cmp_op() { return cmp_op_; }
......@@ -253,8 +261,7 @@ class CmpInst : public BaseInst<CmpInst>
void assertValid();
};
class FCmpInst : public BaseInst<FCmpInst>
{
class FCmpInst : public BaseInst<FCmpInst> {
friend BaseInst<FCmpInst>;
public:
......@@ -271,8 +278,11 @@ class FCmpInst : public BaseInst<FCmpInst>
FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb);
public:
static FCmpInst *create_fcmp(CmpOp op, Value *lhs, Value *rhs,
BasicBlock *bb, Module *m);
static FCmpInst *create_fcmp(CmpOp op,
Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m);
CmpOp get_cmp_op() { return cmp_op_; }
......@@ -284,33 +294,36 @@ class FCmpInst : public BaseInst<FCmpInst>
void assert_valid();
};
class CallInst : public BaseInst<CallInst>
{
class CallInst : public BaseInst<CallInst> {
friend BaseInst<CallInst>;
protected:
CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb);
public:
static CallInst *create(Function *func, std::vector<Value *> args,
static CallInst *create(Function *func,
std::vector<Value *> args,
BasicBlock *bb);
FunctionType *get_function_type() const;
virtual std::string print() override;
};
class BranchInst : public BaseInst<BranchInst>
{
class BranchInst : public BaseInst<BranchInst> {
friend BaseInst<BranchInst>;
private:
BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false,
BranchInst(Value *cond,
BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb);
BranchInst(BasicBlock *if_true, BasicBlock *bb);
public:
static BranchInst *create_cond_br(Value *cond, BasicBlock *if_true,
BasicBlock *if_false, BasicBlock *bb);
static BranchInst *create_cond_br(Value *cond,
BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb);
static BranchInst *create_br(BasicBlock *if_true, BasicBlock *bb);
bool is_cond_br() const;
......@@ -318,8 +331,7 @@ class BranchInst : public BaseInst<BranchInst>
virtual std::string print() override;
};
class ReturnInst : public BaseInst<ReturnInst>
{
class ReturnInst : public BaseInst<ReturnInst> {
friend BaseInst<ReturnInst>;
private:
......@@ -334,8 +346,7 @@ class ReturnInst : public BaseInst<ReturnInst>
virtual std::string print() override;
};
class GetElementPtrInst : public BaseInst<GetElementPtrInst>
{
class GetElementPtrInst : public BaseInst<GetElementPtrInst> {
friend BaseInst<GetElementPtrInst>;
private:
......@@ -343,7 +354,8 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
public:
static Type *get_element_type(Value *ptr, std::vector<Value *> idxs);
static GetElementPtrInst *create_gep(Value *ptr, std::vector<Value *> idxs,
static GetElementPtrInst *create_gep(Value *ptr,
std::vector<Value *> idxs,
BasicBlock *bb);
Type *get_element_type() const;
......@@ -353,8 +365,7 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
Type *element_ty_;
};
class StoreInst : public BaseInst<StoreInst>
{
class StoreInst : public BaseInst<StoreInst> {
friend BaseInst<StoreInst>;
private:
......@@ -369,8 +380,7 @@ class StoreInst : public BaseInst<StoreInst>
virtual std::string print() override;
};
class LoadInst : public BaseInst<LoadInst>
{
class LoadInst : public BaseInst<LoadInst> {
friend BaseInst<LoadInst>;
private:
......@@ -385,8 +395,7 @@ class LoadInst : public BaseInst<LoadInst>
virtual std::string print() override;
};
class AllocaInst : public BaseInst<AllocaInst>
{
class AllocaInst : public BaseInst<AllocaInst> {
friend BaseInst<AllocaInst>;
private:
......@@ -403,8 +412,7 @@ class AllocaInst : public BaseInst<AllocaInst>
Type *alloca_ty_;
};
class ZextInst : public BaseInst<ZextInst>
{
class ZextInst : public BaseInst<ZextInst> {
friend BaseInst<ZextInst>;
private:
......@@ -421,8 +429,7 @@ class ZextInst : public BaseInst<ZextInst>
Type *dest_ty_;
};
class FpToSiInst : public BaseInst<FpToSiInst>
{
class FpToSiInst : public BaseInst<FpToSiInst> {
friend BaseInst<FpToSiInst>;
private:
......@@ -439,8 +446,7 @@ class FpToSiInst : public BaseInst<FpToSiInst>
Type *dest_ty_;
};
class SiToFpInst : public BaseInst<SiToFpInst>
{
class SiToFpInst : public BaseInst<SiToFpInst> {
friend BaseInst<SiToFpInst>;
private:
......@@ -457,25 +463,24 @@ class SiToFpInst : public BaseInst<SiToFpInst>
Type *dest_ty_;
};
class PhiInst : public BaseInst<PhiInst>
{
class PhiInst : public BaseInst<PhiInst> {
friend BaseInst<PhiInst>;
private:
PhiInst(OpID op, std::vector<Value *> vals,
std::vector<BasicBlock *> val_bbs, Type *ty, BasicBlock *bb);
PhiInst(OpID op,
std::vector<Value *> vals,
std::vector<BasicBlock *> val_bbs,
Type *ty,
BasicBlock *bb);
PhiInst(Type *ty, OpID op, unsigned num_ops, BasicBlock *bb)
: BaseInst<PhiInst>(ty, op, num_ops, bb)
{
}
: BaseInst<PhiInst>(ty, op, num_ops, bb) {}
Value *l_val_;
public:
static PhiInst *create_phi(Type *ty, BasicBlock *bb);
Value *get_lval() { return l_val_; }
void set_lval(Value *l_val) { l_val_ = l_val; }
void add_phi_pair_operand(Value *val, Value *pre_bb)
{
void add_phi_pair_operand(Value *val, Value *pre_bb) {
this->add_operand(val);
this->add_operand(pre_bb);
}
......
......@@ -19,6 +19,7 @@
#include <utility>
#include <vector>
class GVN;
namespace GVNExpression {
// fold the constant value
......@@ -35,12 +36,13 @@ class ConstFolder {
/**
* for constructor of class derived from `Expression`, we make it public
* because `std::make_shared` needs the constructor to be publicly available,
* but you should call the static factory method `create` instead the constructor itself to get the desired data
* but you should call the static factory method `create` instead the
* constructor itself to get the desired data
*/
class Expression {
public:
// TODO: you need to extend expression types according to testcases
enum gvn_expr_t { e_constant, e_bin, e_phi };
enum gvn_expr_t { e_constant, e_bin, e_phi, e_cast, e_gep, e_unique };
Expression(gvn_expr_t t) : expr_type(t) {}
virtual ~Expression() = default;
virtual std::string print() = 0;
......@@ -50,15 +52,21 @@ class Expression {
gvn_expr_t expr_type;
};
bool operator==(const std::shared_ptr<Expression> &lhs, const std::shared_ptr<Expression> &rhs);
bool operator==(const GVNExpression::Expression &lhs, const GVNExpression::Expression &rhs);
bool operator==(const std::shared_ptr<Expression> &lhs,
const std::shared_ptr<Expression> &rhs);
bool operator==(const GVNExpression::Expression &lhs,
const GVNExpression::Expression &rhs);
class ConstantExpression : public Expression {
public:
static std::shared_ptr<ConstantExpression> create(Constant *c) { return std::make_shared<ConstantExpression>(c); }
static std::shared_ptr<ConstantExpression> create(Constant *c) {
return std::make_shared<ConstantExpression>(c);
}
virtual std::string print() { return c_->print(); }
// we leverage the fact that constants in lightIR have unique addresses
bool equiv(const ConstantExpression *other) const { return c_ == other->c_; }
bool equiv(const ConstantExpression *other) const {
return c_ == other->c_;
}
ConstantExpression(Constant *c) : Expression(e_constant), c_(c) {}
private:
......@@ -67,24 +75,31 @@ class ConstantExpression : public Expression {
// arithmetic expression
class BinaryExpression : public Expression {
friend class ::GVN;
public:
static std::shared_ptr<BinaryExpression> create(Instruction::OpID op,
static std::shared_ptr<BinaryExpression> create(
Instruction::OpID op,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs) {
return std::make_shared<BinaryExpression>(op, lhs, rhs);
}
virtual std::string print() {
return "(" + Instruction::get_instr_op_name(op_) + " " + lhs_->print() + " " + rhs_->print() + ")";
return "(" + Instruction::get_instr_op_name(op_) + " " + lhs_->print() +
" " + rhs_->print() + ")";
}
bool equiv(const BinaryExpression *other) const {
if (op_ == other->op_ and *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_)
if (op_ == other->op_ and *lhs_ == *other->lhs_ and
*rhs_ == *other->rhs_)
return true;
else
return false;
}
BinaryExpression(Instruction::OpID op, std::shared_ptr<Expression> lhs, std::shared_ptr<Expression> rhs)
BinaryExpression(Instruction::OpID op,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs)
: Expression(e_bin), op_(op), lhs_(lhs), rhs_(rhs) {}
private:
......@@ -93,33 +108,122 @@ class BinaryExpression : public Expression {
};
class PhiExpression : public Expression {
friend class ::GVN;
public:
static std::shared_ptr<PhiExpression> create(std::shared_ptr<Expression> lhs, std::shared_ptr<Expression> rhs) {
static std::shared_ptr<PhiExpression> create(
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs) {
return std::make_shared<PhiExpression>(lhs, rhs);
}
virtual std::string print() { return "(phi " + lhs_->print() + " " + rhs_->print() + ")"; }
virtual std::string print() {
return "(phi " + lhs_->print() + " " + rhs_->print() + ")";
}
bool equiv(const PhiExpression *other) const {
if (*lhs_ == *other->lhs_ and *rhs_ == *other->rhs_)
return true;
else
return false;
}
PhiExpression(std::shared_ptr<Expression> lhs, std::shared_ptr<Expression> rhs)
PhiExpression(std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs)
: Expression(e_phi), lhs_(lhs), rhs_(rhs) {}
private:
std::shared_ptr<Expression> lhs_, rhs_;
};
// type cast expression
class CastExpression : public Expression {
public:
static std::shared_ptr<CastExpression> create(
Instruction::OpID op,
std::shared_ptr<Expression> src,
Type *dest_type) {
return std::make_shared<CastExpression>(op, src, dest_type);
}
virtual std::string print() {
return "(" + dest_ty_->print() + " " +
Instruction::get_instr_op_name(op_) + " " + src_->print() + ")";
}
bool equiv(const CastExpression *other) const {
return op_ == other->op_ and src_ == other->src_ and
dest_ty_ == other->dest_ty_;
}
CastExpression(Instruction::OpID op,
std::shared_ptr<Expression> src,
Type *dest_type)
: Expression(e_cast), op_(op), src_(src), dest_ty_(dest_type) {}
private:
Instruction::OpID op_;
std::shared_ptr<Expression> src_;
Type *dest_ty_;
};
// type cast expression
class GEPExpression : public Expression {
public:
static std::shared_ptr<GEPExpression> create(
std::shared_ptr<Expression> ptr,
std::vector<std::shared_ptr<Expression>> &idxs) {
return std::make_shared<GEPExpression>(ptr, idxs);
}
virtual std::string print() {
std::string ret = "(GEP " + ptr_->print();
for (auto idx : idxs_)
ret += " " + idx->print();
return ret + ")";
}
bool equiv(const GEPExpression *other) const {
if (idxs_.size() != other->idxs_.size())
return false;
for (int i = 0; i != idxs_.size(); ++i)
if (not(idxs_[i] == other->idxs_[i]))
return false;
return ptr_ == other->ptr_;
}
GEPExpression(std::shared_ptr<Expression> ptr,
std::vector<std::shared_ptr<Expression>> &idxs)
: Expression(e_gep), ptr_(ptr), idxs_(idxs) {}
private:
std::shared_ptr<Expression> ptr_;
std::vector<std::shared_ptr<Expression>> idxs_;
};
// unique expression: not equal to any one else
class UniqueExpression : public Expression {
public:
static std::shared_ptr<UniqueExpression> create(Instruction *instr) {
return std::make_shared<UniqueExpression>(instr);
}
virtual std::string print() { return "(UNIQUE " + instr_->print() + ")"; }
bool equiv(const UniqueExpression *other) const { return false; }
UniqueExpression(Instruction *instr)
: Expression(e_unique), instr_(instr) {}
private:
std::shared_ptr<Instruction> instr_;
};
} // namespace GVNExpression
/**
* Congruence class in each partitions
* note: for constant propagation, you might need to add other fields
* and for load/store redundancy detection, you most certainly need to modify the class
* and for load/store redundancy detection, you most certainly need to modify
* the class
*/
struct CongruenceClass {
size_t index_;
// representative of the congruence class, used to replace all the members (except itself) when analysis is done
// representative of the congruence class, used to replace all the members
// (except itself) when analysis is done
Value *leader_;
// value expression in congruence class
std::shared_ptr<GVNExpression::Expression> value_expr_;
......@@ -128,17 +232,22 @@ struct CongruenceClass {
// equivalent variables in one congruence class
std::set<Value *> members_;
CongruenceClass(size_t index) : index_(index), leader_{}, value_expr_{}, value_phi_{}, members_{} {}
CongruenceClass(size_t index)
: index_(index), leader_{}, value_expr_{}, value_phi_{}, members_{} {}
bool operator<(const CongruenceClass &other) const { return this->index_ < other.index_; }
bool operator<(const CongruenceClass &other) const {
return this->index_ < other.index_;
}
bool operator==(const CongruenceClass &other) const;
};
namespace std {
template <>
// overload std::less for std::shared_ptr<CongruenceClass>, i.e. how to sort the congruence classes
// overload std::less for std::shared_ptr<CongruenceClass>, i.e. how to sort the
// congruence classes
struct less<std::shared_ptr<CongruenceClass>> {
bool operator()(const std::shared_ptr<CongruenceClass> &a, const std::shared_ptr<CongruenceClass> &b) const {
bool operator()(const std::shared_ptr<CongruenceClass> &a,
const std::shared_ptr<CongruenceClass> &b) const {
// nullptrs should never appear in partitions, so we just dereference it
return *a < *b;
}
......@@ -154,16 +263,23 @@ class GVN : public Pass {
// init for pass metadata;
void initPerFunction();
// fill the following functions according to Pseudocode, **you might need to add more arguments**
// fill the following functions according to Pseudocode, **you might need to
// add more arguments**
void detectEquivalences();
partitions join(const partitions &P1, const partitions &P2);
std::shared_ptr<CongruenceClass> intersect(std::shared_ptr<CongruenceClass>, std::shared_ptr<CongruenceClass>);
std::shared_ptr<CongruenceClass> intersect(
std::shared_ptr<CongruenceClass>,
std::shared_ptr<CongruenceClass>);
partitions transferFunction(Instruction *x, Value *e, partitions pin);
partitions transferFunction(BasicBlock *bb);
std::shared_ptr<GVNExpression::PhiExpression> valuePhiFunc(std::shared_ptr<GVNExpression::Expression>,
const partitions &);
std::shared_ptr<GVNExpression::Expression> valueExpr(Instruction *instr);
std::shared_ptr<GVNExpression::Expression> getVN(const partitions &pout,
std::shared_ptr<GVNExpression::PhiExpression> valuePhiFunc(
std::shared_ptr<GVNExpression::Expression>,
BasicBlock *bb);
std::shared_ptr<GVNExpression::Expression> valueExpr(
Instruction *instr,
partitions *part = nullptr);
std::shared_ptr<GVNExpression::Expression> getVN(
const partitions &pout,
std::shared_ptr<GVNExpression::Expression> ve);
// replace cc members with leader
......@@ -180,6 +296,7 @@ class GVN : public Pass {
// self add
//
std::uint64_t new_number() { return next_value_number_++; }
static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb);
private:
bool dump_json_;
......@@ -191,7 +308,9 @@ class GVN : public Pass {
std::unique_ptr<DeadCode> dce_;
// self add
std::map<Instruction*, bool> _TOP;
std::map<BasicBlock *, bool> _TOP;
partitions join_helper(BasicBlock *pre1, BasicBlock *pre2);
BasicBlock* curr_bb;
};
bool operator==(const GVN::partitions &p1, const GVN::partitions &p2);
cminusf_builder_stu.cpp
......@@ -5,22 +5,20 @@
#include "cminusf_builder.hpp"
#include "logging.hpp"
#define CONST_FP(num) ConstantFP::get((float)num, module.get())
#define CONST_INT(num) ConstantInt::get(num, module.get())
// TODO: Global Variable Declarations
// You can define global variables here
// to store state. You can expand these
// definitions if you need to.
// to store state
// the latest return value
Value *cur_value = nullptr;
// if var is assignment's left part, LV is true
bool LV = false;
// store temporary value
Value *tmp_val = nullptr;
// whether require lvalue
bool require_lvalue = false;
// function that is being built
Function *cur_fun = nullptr;
// detect scope pre-enter (for elegance only)
bool pre_enter_scope = false;
// types
Type *VOID_T;
......@@ -30,9 +28,22 @@ Type *INT32PTR_T;
Type *FLOAT_T;
Type *FLOATPTR_T;
// initializer
ConstantZero *I32Initializer;
ConstantZero *FloatInitializer;
bool
promote(IRBuilder *builder, Value **l_val_p, Value **r_val_p) {
bool is_int;
auto &l_val = *l_val_p;
auto &r_val = *r_val_p;
if (l_val->get_type() == r_val->get_type()) {
is_int = l_val->get_type()->is_integer_type();
} else {
is_int = false;
if (l_val->get_type()->is_integer_type())
l_val = builder->create_sitofp(l_val, FLOAT_T);
else
r_val = builder->create_sitofp(r_val, FLOAT_T);
}
return is_int;
}
/*
* use CMinusfBuilder::Scope to construct scopes
......@@ -42,182 +53,61 @@ ConstantZero *FloatInitializer;
* scope.find: find and return the value bound to the name
*/
void error_exit(std::string s) {
LOG_ERROR << s;
std::abort();
}
// This function makes sure that
// 1. 2 values have same type
// 2. type is either i32 or float
void CminusfBuilder::biop_type_check(Value *&lvalue, Value *&rvalue, std::string util) {
if (Type::is_eq_type(lvalue->get_type(), rvalue->get_type())) {
if (lvalue->get_type()->is_integer_type() or lvalue->get_type()->is_float_type()) {
// check for i1
if (Type::is_eq_type(lvalue->get_type(), INT1_T)) {
lvalue = builder->create_zext(lvalue, INT32_T);
rvalue = builder->create_zext(rvalue, INT32_T);
}
} else
error_exit("not supported type cast for " + util);
return;
}
// only support cast between int and float: i32, i1, float
//
// case that integer and float is mixed, directly cast integer to float
if (lvalue->get_type()->is_integer_type() and rvalue->get_type()->is_float_type())
lvalue = builder->create_sitofp(lvalue, FLOAT_T);
else if (lvalue->get_type()->is_float_type() and rvalue->get_type()->is_integer_type())
rvalue = builder->create_sitofp(rvalue, FLOAT_T);
else if (lvalue->get_type()->is_integer_type() and rvalue->get_type()->is_integer_type()) {
// case that I32 and I1 mixed
if (Type::is_eq_type(lvalue->get_type(), INT1_T))
lvalue = builder->create_zext(lvalue, INT32_T);
else
rvalue = builder->create_zext(rvalue, INT32_T);
} else { // we only support computing among i1, i32 and float
error_exit("not supported type cast for " + util);
}
}
// this function makes sure value is a bool type
void CminusfBuilder::cast_to_i1(Value *&value) {
assert(value->get_type()->is_integer_type() or value->get_type()->is_float_type());
if (value->get_type()->is_float_type())
// value = builder->create_fptosi(value, INT1_T);
value = builder->create_fcmp_ne(value, CONST_FP(0));
else if (Type::is_eq_type(value->get_type(), INT32_T))
value = builder->create_icmp_ne(value, CONST_INT(0));
}
void CminusfBuilder::visit(ASTProgram &node) {
void
CminusfBuilder::visit(ASTProgram &node) {
VOID_T = Type::get_void_type(module.get());
INT1_T = Type::get_int1_type(module.get());
INT32_T = Type::get_int32_type(module.get());
INT32PTR_T = Type::get_int32_ptr_type(module.get());
FLOAT_T = Type::get_float_type(module.get());
FLOATPTR_T = Type::get_float_ptr_type(module.get());
I32Initializer = ConstantZero::get(INT32_T, builder->get_module());
FloatInitializer = ConstantZero::get(FLOAT_T, builder->get_module());
for (auto decl : node.declarations) {
decl->accept(*this);
}
}
// Done
void CminusfBuilder::visit(ASTNum &node) {
//!TODO: This function is empty now.
// Add some code here.
//
switch (node.type) {
case TYPE_INT:
cur_value = CONST_INT(node.i_val);
return;
case TYPE_FLOAT:
cur_value = CONST_FP(node.f_val);
return;
default:
error_exit("ASTNum is not int or float");
}
}
// Done
void CminusfBuilder::visit(ASTVarDeclaration &node) {
//!TODO: This function is empty now.
// Add some code here.
bool global = (builder->get_insert_block() == nullptr);
if (node.num) {
// declares an array
//
// get array size
node.num->accept(*this);
//
// !no type cast here!
if (not(node.num->type == TYPE_INT))
error_exit("size of array has non-integer type");
int size = node.num->i_val;
if (size <= 0)
error_exit("array size[" + std::to_string(size) + "] <= 0");
switch (node.type) {
case TYPE_INT: {
auto I32Array_T = Type::get_array_type(INT32_T, size);
if (global)
cur_value =
GlobalVariable::create(node.id, builder->get_module(), I32Array_T, false, I32Initializer);
void
CminusfBuilder::visit(ASTNum &node) {
if (node.type == TYPE_INT)
tmp_val = CONST_INT(node.i_val);
else
cur_value = builder->create_alloca(I32Array_T);
break;
}
tmp_val = CONST_FP(node.f_val);
}
case TYPE_FLOAT: {
auto FloatArray_T = Type::get_array_type(FLOAT_T, size);
if (global)
cur_value =
GlobalVariable::create(node.id, builder->get_module(), FloatArray_T, false, FloatInitializer);
void
CminusfBuilder::visit(ASTVarDeclaration &node) {
Type *var_type;
if (node.type == TYPE_INT)
var_type = Type::get_int32_type(module.get());
else
cur_value = builder->create_alloca(FloatArray_T);
break;
}
default:
error_exit("Variable type(not array) is not int or float");
}
assert(cur_value->get_type()->is_pointer_type() && "IF SEE THIS: API ERROR");
var_type = Type::get_float_type(module.get());
if (node.num == nullptr) {
if (scope.in_global()) {
auto initializer = ConstantZero::get(var_type, module.get());
auto var = GlobalVariable::create(
node.id, module.get(), var_type, false, initializer);
scope.push(node.id, var);
} else {
// flat int or float type
switch (node.type) {
case TYPE_INT:
if (global)
cur_value = GlobalVariable::create(node.id, builder->get_module(), INT32_T, false, I32Initializer);
else
cur_value = builder->create_alloca(INT32_T);
break;
case TYPE_FLOAT:
if (global)
cur_value =
GlobalVariable::create(node.id, builder->get_module(), FLOAT_T, false, FloatInitializer);
else {
/* Beautiful is better than ugly.
* Explicit is better than implicit.
* Simple is better than complex.
* Complex is better than complicated.
* Flat is better than nested.
* Sparse is better than dense.
* Readability counts.
* Special cases aren't special enough to break the rules.
* Although practicality beats purity.
* Errors should never pass silently.
* Unless explicitly silenced.
* In the face of ambiguity, refuse the temptation to guess.
* There should be one-- and preferably only one --obvious way to do it.
* Although that way may not be obvious at first unless you're Dutch.
* Now is better than never.
* Although never is often better than *right* now.
* If the implementation is hard to explain, it's a bad idea.
* If the implementation is easy to explain, it may be a good idea.
* Namespaces are one honking great idea -- let's do more of those! */
// cur_value = builder->create_alloca(INT32_T);
cur_value = builder->create_alloca(FLOAT_T);
auto var = builder->create_alloca(var_type);
scope.push(node.id, var);
}
break;
default:
error_exit("Variable type(not array) is not int or float");
} else {
auto *array_type = ArrayType::get(var_type, node.num->i_val);
if (scope.in_global()) {
auto initializer = ConstantZero::get(array_type, module.get());
auto var = GlobalVariable::create(
node.id, module.get(), array_type, false, initializer);
scope.push(node.id, var);
} else {
auto var = builder->create_alloca(array_type);
scope.push(node.id, var);
}
}
if (not scope.push(node.id, cur_value))
error_exit("variable redefined: " + node.id);
LOG_DEBUG << "add entry: " << node.id << " " << cur_value;
}
// Done
void CminusfBuilder::visit(ASTFunDeclaration &node) {
void
CminusfBuilder::visit(ASTFunDeclaration &node) {
FunctionType *fun_type;
Type *ret_type;
std::vector<Type *> param_types;
......@@ -229,46 +119,54 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
ret_type = VOID_T;
for (auto &param : node.params) {
//!TODO: Please accomplish param_types.
//
// First make function BB, which needs this param type,
// then set_insert_point, we can call accept to gen code,
switch (param->type) {
case TYPE_INT:
param_types.push_back(param->isarray ? INT32PTR_T : INT32_T);
break;
case TYPE_FLOAT:
param_types.push_back(param->isarray ? FLOATPTR_T : FLOAT_T);
break;
case TYPE_VOID:
if (not param_types.empty())
error_exit("function parameters weird");
break;
if (param->type == TYPE_INT) {
if (param->isarray) {
param_types.push_back(INT32PTR_T);
} else {
param_types.push_back(INT32_T);
}
} else {
if (param->isarray) {
param_types.push_back(FLOATPTR_T);
} else {
param_types.push_back(FLOAT_T);
}
}
}
fun_type = FunctionType::get(ret_type, param_types);
auto fun = Function::create(fun_type, node.id, module.get());
cur_fun = fun;
scope.push(node.id, fun);
cur_fun = fun;
auto funBB = BasicBlock::create(module.get(), "entry", fun);
builder->set_insert_point(funBB);
scope.enter();
pre_enter_scope = true;
std::vector<Value *> args;
for (auto arg = fun->arg_begin(); arg != fun->arg_end(); arg++) {
args.push_back(*arg);
}
for (int i = 0; i < node.params.size(); ++i) {
//!TODO: You need to deal with params
// and store them in the scope.
cur_value = args[i];
node.params[i]->accept(*this);
if (node.params[i]->isarray) {
Value *array_alloc;
if (node.params[i]->type == TYPE_INT)
array_alloc = builder->create_alloca(INT32PTR_T);
else
array_alloc = builder->create_alloca(FLOATPTR_T);
builder->create_store(args[i], array_alloc);
scope.push(node.params[i]->id, array_alloc);
} else {
Value *alloc;
if (node.params[i]->type == TYPE_INT)
alloc = builder->create_alloca(INT32_T);
else
alloc = builder->create_alloca(FLOAT_T);
builder->create_store(args[i], alloc);
scope.push(node.params[i]->id, alloc);
}
}
node.compound_stmt->accept(*this);
// default return value
// can't deal with return in both blocks
if (builder->get_insert_block()->get_terminator() == nullptr) {
if (cur_fun->get_return_type()->is_void_type())
builder->create_void_ret();
......@@ -280,40 +178,17 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
scope.exit();
}
// Done
void CminusfBuilder::visit(ASTParam &node) {
//!TODO: This function is empty now.
// If the parameter is int|float, copy and store them
auto param_value = cur_value;
switch (node.type) {
case TYPE_INT: {
if (node.isarray)
cur_value = builder->create_alloca(INT32PTR_T);
else
cur_value = builder->create_alloca(INT32_T);
break;
}
case TYPE_FLOAT: {
if (node.isarray)
cur_value = builder->create_alloca(FLOATPTR_T);
else
cur_value = builder->create_alloca(FLOAT_T);
break;
}
case TYPE_VOID:
return;
}
scope.push(node.id, cur_value);
builder->create_store(param_value, cur_value);
}
// Done?
void CminusfBuilder::visit(ASTCompoundStmt &node) {
//!TODO: This function is not complete.
// You may need to add some code here
// to deal with complex statements.
void
CminusfBuilder::visit(ASTParam &node) {}
void
CminusfBuilder::visit(ASTCompoundStmt &node) {
bool need_exit_scope = !pre_enter_scope;
if (pre_enter_scope) {
pre_enter_scope = false;
} else {
scope.enter();
}
for (auto &decl : node.local_declarations) {
decl->accept(*this);
......@@ -325,413 +200,303 @@ void CminusfBuilder::visit(ASTCompoundStmt &node) {
break;
}
if (need_exit_scope) {
scope.exit();
}
}
// Done
void CminusfBuilder::visit(ASTExpressionStmt &node) {
//!TODO: This function is empty now.
// Add some code here.
if (node.expression)
void
CminusfBuilder::visit(ASTExpressionStmt &node) {
if (node.expression != nullptr)
node.expression->accept(*this);
}
// Done
void CminusfBuilder::visit(ASTSelectionStmt &node) {
//!TODO: This function is empty now.
// Add some code here.
scope.enter();
void
CminusfBuilder::visit(ASTSelectionStmt &node) {
node.expression->accept(*this);
auto cond = cur_value;
cast_to_i1(cond);
auto ifBB = BasicBlock::create(builder->get_module(), "", cur_fun);
auto endBB = BasicBlock::create(builder->get_module(), "", cur_fun);
if (node.else_statement) {
auto elseBB = BasicBlock::create(builder->get_module(), "", cur_fun);
builder->create_cond_br(cond, ifBB, elseBB);
auto ret_val = tmp_val;
auto trueBB = BasicBlock::create(module.get(), "", cur_fun);
BasicBlock *falseBB{};
auto contBB = BasicBlock::create(module.get(), "", cur_fun);
Value *cond_val;
if (ret_val->get_type()->is_integer_type())
cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
else
cond_val = builder->create_fcmp_ne(ret_val, CONST_FP(0.));
builder->set_insert_point(ifBB);
if (node.else_statement == nullptr) {
builder->create_cond_br(cond_val, trueBB, contBB);
} else {
falseBB = BasicBlock::create(module.get(), "", cur_fun);
builder->create_cond_br(cond_val, trueBB, falseBB);
}
builder->set_insert_point(trueBB);
node.if_statement->accept(*this);
builder->create_br(endBB);
builder->set_insert_point(elseBB);
node.else_statement->accept(*this);
builder->create_br(endBB);
if (builder->get_insert_block()->get_terminator() == nullptr)
builder->create_br(contBB);
builder->set_insert_point(endBB);
if (node.else_statement == nullptr) {
// falseBB->erase_from_parent(); // did not clean up memory
} else {
builder->create_cond_br(cond, ifBB, endBB);
builder->set_insert_point(ifBB);
node.if_statement->accept(*this);
builder->create_br(endBB);
builder->set_insert_point(endBB);
builder->set_insert_point(falseBB);
node.else_statement->accept(*this);
if (builder->get_insert_block()->get_terminator() == nullptr)
builder->create_br(contBB);
}
scope.exit();
}
// Done
void CminusfBuilder::visit(ASTIterationStmt &node) {
//!TODO: This function is empty now.
// Add some code here.
scope.enter();
auto HEAD = BasicBlock::create(builder->get_module(), "", cur_fun);
auto BODY = BasicBlock::create(builder->get_module(), "", cur_fun);
auto END = BasicBlock::create(builder->get_module(), "", cur_fun);
builder->create_br(HEAD);
builder->set_insert_point(contBB);
}
builder->set_insert_point(HEAD);
void
CminusfBuilder::visit(ASTIterationStmt &node) {
auto exprBB = BasicBlock::create(module.get(), "", cur_fun);
if (builder->get_insert_block()->get_terminator() == nullptr)
builder->create_br(exprBB);
builder->set_insert_point(exprBB);
node.expression->accept(*this);
auto cond = cur_value;
cast_to_i1(cond);
builder->create_cond_br(cond, BODY, END);
auto ret_val = tmp_val;
auto trueBB = BasicBlock::create(module.get(), "", cur_fun);
auto contBB = BasicBlock::create(module.get(), "", cur_fun);
Value *cond_val;
if (ret_val->get_type()->is_integer_type())
cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0));
else
cond_val = builder->create_fcmp_ne(ret_val, CONST_FP(0.));
builder->set_insert_point(BODY);
builder->create_cond_br(cond_val, trueBB, contBB);
builder->set_insert_point(trueBB);
node.statement->accept(*this);
builder->create_br(HEAD);
builder->set_insert_point(END);
scope.exit();
if (builder->get_insert_block()->get_terminator() == nullptr)
builder->create_br(exprBB);
builder->set_insert_point(contBB);
}
// Done
void CminusfBuilder::visit(ASTReturnStmt &node) {
void
CminusfBuilder::visit(ASTReturnStmt &node) {
if (node.expression == nullptr) {
builder->create_void_ret();
} else {
//!TODO: The given code is incomplete.
// You need to solve other return cases (e.g. return an integer).
//
auto fun_ret_type = cur_fun->get_function_type()->get_return_type();
node.expression->accept(*this);
// type cast
// return type can only be int, float or void
if (not Type::is_eq_type(cur_fun->get_return_type(), cur_value->get_type())) {
if (not cur_value->get_type()->is_integer_type() and not cur_value->get_type()->is_float_type())
error_exit("unsupported return type");
if (cur_value->get_type()->is_float_type())
cur_value = builder->create_fptosi(cur_value, INT32_T);
else if (cur_fun->get_return_type()->is_float_type())
cur_value = builder->create_sitofp(cur_value, FLOAT_T);
if (fun_ret_type != tmp_val->get_type()) {
if (fun_ret_type->is_integer_type())
tmp_val = builder->create_fptosi(tmp_val, INT32_T);
else
cur_value = builder->create_zext(cur_value, INT32_T);
tmp_val = builder->create_sitofp(tmp_val, FLOAT_T);
}
builder->create_ret(cur_value);
LOG_DEBUG << "create ret:\n" << builder->get_module()->print();
builder->create_ret(tmp_val);
}
}
// Done
// if LV is marked, return memory addr
// else return value stored inside
void CminusfBuilder::visit(ASTVar &node) {
//!TODO: This function is empty now.
// Add some code here.
//
// First it's pointer type, the pointed elements have 3 cases:
// 1. int or float
// 2. [i32 x n] or [float x n]
// 3. int*
auto memory = scope.find(node.id);
Value *addr;
if (memory == nullptr)
error_exit("variable " + node.id + " not declared");
LOG_DEBUG << "find entry: " << node.id << " " << memory;
assert(memory->get_type()->is_pointer_type());
auto element_type = memory->get_type()->get_pointer_element_type();
if (node.expression) { // e.g. int a[10]; // mem is [i32 x 10]*
bool old_LV = LV;
LV = false;
void
CminusfBuilder::visit(ASTVar &node) {
auto var = scope.find(node.id);
assert(var != nullptr);
auto is_int =
var->get_type()->get_pointer_element_type()->is_integer_type();
auto is_float =
var->get_type()->get_pointer_element_type()->is_float_type();
auto is_ptr =
var->get_type()->get_pointer_element_type()->is_pointer_type();
bool should_return_lvalue = require_lvalue;
require_lvalue = false;
if (node.expression == nullptr) {
if (should_return_lvalue) {
tmp_val = var;
require_lvalue = false;
} else {
if (is_int || is_float || is_ptr) {
tmp_val = builder->create_load(var);
} else {
tmp_val =
builder->create_gep(var, {CONST_INT(0), CONST_INT(0)});
}
}
} else {
node.expression->accept(*this);
LV = old_LV;
// subscription type cast
if (not Type::is_eq_type(cur_value->get_type(), INT32_T)) {
if (Type::is_eq_type(cur_value->get_type(), FLOAT_T))
cur_value = builder->create_fptosi(cur_value, INT32_T);
auto val = tmp_val;
Value *is_neg;
auto exceptBB = BasicBlock::create(module.get(), "", cur_fun);
auto contBB = BasicBlock::create(module.get(), "", cur_fun);
if (val->get_type()->is_float_type())
val = builder->create_fptosi(val, INT32_T);
is_neg = builder->create_icmp_lt(val, CONST_INT(0));
builder->create_cond_br(is_neg, exceptBB, contBB);
builder->set_insert_point(exceptBB);
auto neg_idx_except_fun = scope.find("neg_idx_except");
builder->create_call(static_cast<Function *>(neg_idx_except_fun), {});
if (cur_fun->get_return_type()->is_void_type())
builder->create_void_ret();
else if (cur_fun->get_return_type()->is_float_type())
builder->create_ret(CONST_FP(0.));
else
error_exit("bad type for subscription");
}
auto cond = builder->create_icmp_lt(cur_value, CONST_INT(0));
auto except_func = scope.find("neg_idx_except");
auto TBB = BasicBlock::create(builder->get_module(), "", cur_fun);
auto passBB = BasicBlock::create(builder->get_module(), "", cur_fun);
builder->create_cond_br(cond, TBB, passBB);
builder->set_insert_point(TBB);
builder->create_call(except_func, {});
builder->create_br(passBB);
builder->set_insert_point(passBB);
// Now the subscription is in cur_value, which is good value
// We should focus on the var type:
//
// assert it's pointer type
if (element_type->is_float_type() or element_type->is_integer_type()) {
// 1. int or float
error_exit("invalid types for array subscript");
} else if (element_type->is_array_type()) {
// 2. [i32 x n] or [float x n]
//
// addr is actually &memory[0][cur_value]
addr = builder->create_gep(memory, {CONST_INT(0), cur_value});
} else if (element_type->is_pointer_type()) {
// 3. int*
// what to do:
// - int**->int*
// - seek addr
//
// addr is actually &(*memory)[cur_value]
addr = builder->create_load(memory);
addr = builder->create_gep(addr, {cur_value});
}
// The logic for this part is the same as `int|float` without subscription.
// Cause we have subscription to find the particular element(int or float),
// we make `addr` its memory address.
} else { // e.g. int a; // a is i32*
if (element_type->is_float_type() or element_type->is_integer_type()) {
// 1. int or float
// addr is the element's addr
addr = memory;
} else {
if (LV)
error_exit("error: pointer or array type is not assignable");
// For array* or pointer* type, the right-value behaviour is quite special,
// so treat them apart.
if (element_type->is_array_type()) {
// 2. [i32 x n] or [float x n]
// addr is the first element's address in the array
cur_value = builder->create_gep(memory, {CONST_INT(0), CONST_INT(0)});
} else if (element_type->is_pointer_type()) {
// 3. int*
// addr is the content in the memory, which is actually pointer type
cur_value = builder->create_load(memory);
}
return;
}
}
if (LV) {
LOG_INFO << "directly return addr" << node.id;
cur_value = addr;
builder->create_ret(CONST_INT(0));
builder->set_insert_point(contBB);
Value *tmp_ptr;
if (is_int || is_float)
tmp_ptr = builder->create_gep(var, {val});
else if (is_ptr) {
auto array_load = builder->create_load(var);
tmp_ptr = builder->create_gep(array_load, {val});
} else
tmp_ptr = builder->create_gep(var, {CONST_INT(0), val});
if (should_return_lvalue) {
tmp_val = tmp_ptr;
require_lvalue = false;
} else {
LOG_INFO << "create load for var: " << node.id;
cur_value = builder->create_load(addr);
tmp_val = builder->create_load(tmp_ptr);
}
}
}
// Done
void CminusfBuilder::visit(ASTAssignExpression &node) {
//!TODO: This function is empty now.
// Add some code here.
//
LV = true;
node.var->accept(*this);
LV = false;
auto addr = cur_value;
void
CminusfBuilder::visit(ASTAssignExpression &node) {
node.expression->accept(*this);
assert(addr->get_type()->get_pointer_element_type() != nullptr);
// type cast: left is a pointer type, pointed to i32 or float
if (not Type::is_eq_type(addr->get_type()->get_pointer_element_type(), cur_value->get_type())) {
if (cur_value->get_type()->is_float_type())
cur_value = builder->create_fptosi(cur_value, INT32_T);
else if (addr->get_type()->get_pointer_element_type()->is_float_type())
cur_value = builder->create_sitofp(cur_value, FLOAT_T);
else if (Type::is_eq_type(cur_value->get_type(), INT1_T))
cur_value = builder->create_zext(cur_value, INT32_T);
auto expr_result = tmp_val;
require_lvalue = true;
node.var->accept(*this);
auto var_addr = tmp_val;
if (var_addr->get_type()->get_pointer_element_type() !=
expr_result->get_type()) {
if (expr_result->get_type() == INT32_T)
expr_result = builder->create_sitofp(expr_result, FLOAT_T);
else
error_exit("bad type for assignment");
expr_result = builder->create_fptosi(expr_result, INT32_T);
}
// gen code
builder->create_store(cur_value, addr);
builder->create_store(expr_result, var_addr);
tmp_val = expr_result;
}
// Done
void CminusfBuilder::visit(ASTSimpleExpression &node) {
//!TODO: This function is empty now.
// Add some code here.
//
if (node.additive_expression_r) {
void
CminusfBuilder::visit(ASTSimpleExpression &node) {
if (node.additive_expression_r == nullptr) {
node.additive_expression_l->accept(*this);
} else {
node.additive_expression_l->accept(*this);
auto lvalue = cur_value;
auto l_val = tmp_val;
node.additive_expression_r->accept(*this);
auto rvalue = cur_value;
// check type
biop_type_check(lvalue, rvalue, "cmp");
bool float_cmp = lvalue->get_type()->is_float_type();
auto r_val = tmp_val;
bool is_int = promote(&*builder, &l_val, &r_val);
Value *cmp;
switch (node.op) {
case OP_LE: {
if (float_cmp)
cur_value = builder->create_fcmp_le(lvalue, rvalue);
case OP_LT:
if (is_int)
cmp = builder->create_icmp_lt(l_val, r_val);
else
cur_value = builder->create_icmp_le(lvalue, rvalue);
cmp = builder->create_fcmp_lt(l_val, r_val);
break;
}
case OP_LT: {
if (float_cmp)
cur_value = builder->create_fcmp_lt(lvalue, rvalue);
case OP_LE:
if (is_int)
cmp = builder->create_icmp_le(l_val, r_val);
else
cur_value = builder->create_icmp_lt(lvalue, rvalue);
cmp = builder->create_fcmp_le(l_val, r_val);
break;
}
case OP_GT: {
if (float_cmp)
cur_value = builder->create_fcmp_gt(lvalue, rvalue);
case OP_GE:
if (is_int)
cmp = builder->create_icmp_ge(l_val, r_val);
else
cur_value = builder->create_icmp_gt(lvalue, rvalue);
cmp = builder->create_fcmp_ge(l_val, r_val);
break;
}
case OP_GE: {
if (float_cmp)
cur_value = builder->create_fcmp_ge(lvalue, rvalue);
case OP_GT:
if (is_int)
cmp = builder->create_icmp_gt(l_val, r_val);
else
cur_value = builder->create_icmp_ge(lvalue, rvalue);
cmp = builder->create_fcmp_gt(l_val, r_val);
break;
}
case OP_EQ: {
if (float_cmp)
cur_value = builder->create_fcmp_eq(lvalue, rvalue);
case OP_EQ:
if (is_int)
cmp = builder->create_icmp_eq(l_val, r_val);
else
cur_value = builder->create_icmp_eq(lvalue, rvalue);
cmp = builder->create_fcmp_eq(l_val, r_val);
break;
}
case OP_NEQ: {
if (float_cmp)
cur_value = builder->create_fcmp_ne(lvalue, rvalue);
case OP_NEQ:
if (is_int)
cmp = builder->create_icmp_ne(l_val, r_val);
else
cur_value = builder->create_icmp_ne(lvalue, rvalue);
cmp = builder->create_fcmp_ne(l_val, r_val);
break;
}
tmp_val = builder->create_zext(cmp, INT32_T);
}
} else
node.additive_expression_l->accept(*this);
}
// Done
void CminusfBuilder::visit(ASTAdditiveExpression &node) {
//!TODO: This function is empty now.
// Add some code here.
//
if (node.additive_expression) {
void
CminusfBuilder::visit(ASTAdditiveExpression &node) {
if (node.additive_expression == nullptr) {
node.term->accept(*this);
} else {
node.additive_expression->accept(*this);
auto lvalue = cur_value;
auto l_val = tmp_val;
node.term->accept(*this);
auto rvalue = cur_value;
// check type
biop_type_check(lvalue, rvalue, "addop");
bool float_type = lvalue->get_type()->is_float_type();
// now left and right is the same type
auto r_val = tmp_val;
bool is_int = promote(&*builder, &l_val, &r_val);
switch (node.op) {
case OP_PLUS: {
if (float_type)
cur_value = builder->create_fadd(lvalue, rvalue);
case OP_PLUS:
if (is_int)
tmp_val = builder->create_iadd(l_val, r_val);
else
cur_value = builder->create_iadd(lvalue, rvalue);
tmp_val = builder->create_fadd(l_val, r_val);
break;
}
case OP_MINUS: {
if (float_type)
cur_value = builder->create_fsub(lvalue, rvalue);
case OP_MINUS:
if (is_int)
tmp_val = builder->create_isub(l_val, r_val);
else
cur_value = builder->create_isub(lvalue, rvalue);
tmp_val = builder->create_fsub(l_val, r_val);
break;
}
}
} else
node.term->accept(*this);
}
// Done
void CminusfBuilder::visit(ASTTerm &node) {
//!TODO: This function is empty now.
// Add some code here.
if (node.term) {
void
CminusfBuilder::visit(ASTTerm &node) {
if (node.term == nullptr) {
node.factor->accept(*this);
} else {
node.term->accept(*this);
auto lvalue = cur_value;
auto l_val = tmp_val;
node.factor->accept(*this);
auto rvalue = cur_value;
// check type
biop_type_check(lvalue, rvalue, "mul");
bool float_type = lvalue->get_type()->is_float_type();
// now left and right is the same type
auto r_val = tmp_val;
bool is_int = promote(&*builder, &l_val, &r_val);
switch (node.op) {
case OP_MUL: {
if (float_type)
cur_value = builder->create_fmul(lvalue, rvalue);
case OP_MUL:
if (is_int)
tmp_val = builder->create_imul(l_val, r_val);
else
cur_value = builder->create_imul(lvalue, rvalue);
tmp_val = builder->create_fmul(l_val, r_val);
break;
}
case OP_DIV: {
if (float_type)
cur_value = builder->create_fdiv(lvalue, rvalue);
case OP_DIV:
if (is_int)
tmp_val = builder->create_isdiv(l_val, r_val);
else
cur_value = builder->create_isdiv(lvalue, rvalue);
tmp_val = builder->create_fdiv(l_val, r_val);
break;
}
}
} else
node.factor->accept(*this);
}
// Done
void CminusfBuilder::visit(ASTCall &node) {
//!TODO: This function is empty now.
// Add some code here.
Function *func = static_cast<Function *>(scope.find(node.id));
void
CminusfBuilder::visit(ASTCall &node) {
auto fun = static_cast<Function *>(scope.find(node.id));
std::vector<Value *> args;
if (func == nullptr)
error_exit("function " + node.id + " not declared");
if (node.args.size() != func->get_num_of_args())
error_exit("expect " + std::to_string(func->get_num_of_args()) + " params, but " +
std::to_string(node.args.size()) + " is given");
// check every argument
for (int i = 0; i != node.args.size(); ++i) {
// ith parameter's type
Type *param_type = func->get_function_type()->get_param_type(i);
node.args[i]->accept(*this);
// type cast
if (not Type::is_eq_type(param_type, cur_value->get_type())) {
if (param_type->is_pointer_type()) {
// shouldn't need type cast for pointer, logically
if (param_type->get_pointer_element_type()->is_integer_type() or
param_type->get_pointer_element_type()->is_float_type())
error_exit("BUG HERE: ASTVar return value is not int* or float*");
else
error_exit("BUG HERE: function param needs weird pointer type");
} else if (param_type->is_integer_type() or param_type->is_float_type()) {
// need type cast between int and float
if (not cur_value->get_type()->is_integer_type() and not cur_value->get_type()->is_float_type())
error_exit("unexpected type cast!");
if (param_type->is_float_type())
cur_value = builder->create_sitofp(cur_value, FLOAT_T);
else if (param_type->is_integer_type())
if (cur_value->get_type()->is_integer_type())
cur_value = builder->create_zext(cur_value, INT32_T);
else
cur_value = builder->create_fptosi(cur_value, INT32_T);
auto param_type = fun->get_function_type()->param_begin();
for (auto &arg : node.args) {
arg->accept(*this);
if (!tmp_val->get_type()->is_pointer_type() &&
*param_type != tmp_val->get_type()) {
if (tmp_val->get_type()->is_integer_type())
tmp_val = builder->create_sitofp(tmp_val, FLOAT_T);
else
error_exit("unexpected type cast!");
} else
error_exit("unexpected case when casting arguments for function call " + node.id);
tmp_val = builder->create_fptosi(tmp_val, INT32_T);
}
// now cur_value fits the param type
args.push_back(cur_value);
args.push_back(tmp_val);
param_type++;
}
cur_value = builder->create_call(func, args);
tmp_val = builder->create_call(static_cast<Function *>(fun), args);
}
......@@ -10,7 +10,10 @@
#include <cassert>
#include <vector>
Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent)
Instruction::Instruction(Type *ty,
OpID id,
unsigned num_ops,
BasicBlock *parent)
: User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(parent) {
parent_->add_instruction(this);
}
......@@ -18,56 +21,75 @@ Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent
Instruction::Instruction(Type *ty, OpID id, unsigned num_ops)
: User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(nullptr) {}
Function *Instruction::get_function() { return parent_->get_parent(); }
Function *
Instruction::get_function() {
return parent_->get_parent();
}
Module *Instruction::get_module() { return parent_->get_module(); }
Module *
Instruction::get_module() {
return parent_->get_module();
}
BinaryInst::BinaryInst(Type *ty, OpID id, Value *v1, Value *v2, BasicBlock *bb) : BaseInst<BinaryInst>(ty, id, 2, bb) {
BinaryInst::BinaryInst(Type *ty, OpID id, Value *v1, Value *v2, BasicBlock *bb)
: BaseInst<BinaryInst>(ty, id, 2, bb) {
set_operand(0, v1);
set_operand(1, v2);
// assertValid();
}
void BinaryInst::assertValid() {
void
BinaryInst::assertValid() {
assert(get_operand(0)->get_type()->is_integer_type());
assert(get_operand(1)->get_type()->is_integer_type());
assert(static_cast<IntegerType *>(get_operand(0)->get_type())->get_num_bits() ==
assert(
static_cast<IntegerType *>(get_operand(0)->get_type())
->get_num_bits() ==
static_cast<IntegerType *>(get_operand(1)->get_type())->get_num_bits());
}
BinaryInst *BinaryInst::create_add(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_add(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::add, v1, v2, bb);
}
BinaryInst *BinaryInst::create_sub(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_sub(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::sub, v1, v2, bb);
}
BinaryInst *BinaryInst::create_mul(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_mul(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::mul, v1, v2, bb);
}
BinaryInst *BinaryInst::create_sdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_sdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_int32_type(m), Instruction::sdiv, v1, v2, bb);
}
BinaryInst *BinaryInst::create_fadd(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_fadd(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fadd, v1, v2, bb);
}
BinaryInst *BinaryInst::create_fsub(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_fsub(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fsub, v1, v2, bb);
}
BinaryInst *BinaryInst::create_fmul(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_fmul(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fmul, v1, v2, bb);
}
BinaryInst *BinaryInst::create_fdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
BinaryInst *
BinaryInst::create_fdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) {
return create(Type::get_float_type(m), Instruction::fdiv, v1, v2, bb);
}
std::string BinaryInst::print() {
std::string
BinaryInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -78,7 +100,8 @@ std::string BinaryInst::print() {
instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += ", ";
if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) {
if (Type::is_eq_type(this->get_operand(0)->get_type(),
this->get_operand(1)->get_type())) {
instr_ir += print_as_op(this->get_operand(1), false);
} else {
instr_ir += print_as_op(this->get_operand(1), true);
......@@ -93,18 +116,27 @@ CmpInst::CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
// assertValid();
}
void CmpInst::assertValid() {
void
CmpInst::assertValid() {
assert(get_operand(0)->get_type()->is_integer_type());
assert(get_operand(1)->get_type()->is_integer_type());
assert(static_cast<IntegerType *>(get_operand(0)->get_type())->get_num_bits() ==
assert(
static_cast<IntegerType *>(get_operand(0)->get_type())
->get_num_bits() ==
static_cast<IntegerType *>(get_operand(1)->get_type())->get_num_bits());
}
CmpInst *CmpInst::create_cmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, Module *m) {
CmpInst *
CmpInst::create_cmp(CmpOp op,
Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m) {
return create(m->get_int1_type(), op, lhs, rhs, bb);
}
std::string CmpInst::print() {
std::string
CmpInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -117,7 +149,8 @@ std::string CmpInst::print() {
instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += ", ";
if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) {
if (Type::is_eq_type(this->get_operand(0)->get_type(),
this->get_operand(1)->get_type())) {
instr_ir += print_as_op(this->get_operand(1), false);
} else {
instr_ir += print_as_op(this->get_operand(1), true);
......@@ -132,16 +165,23 @@ FCmpInst::FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
// assertValid();
}
void FCmpInst::assert_valid() {
void
FCmpInst::assert_valid() {
assert(get_operand(0)->get_type()->is_float_type());
assert(get_operand(1)->get_type()->is_float_type());
}
FCmpInst *FCmpInst::create_fcmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, Module *m) {
FCmpInst *
FCmpInst::create_fcmp(CmpOp op,
Value *lhs,
Value *rhs,
BasicBlock *bb,
Module *m) {
return create(m->get_int1_type(), op, lhs, rhs, bb);
}
std::string FCmpInst::print() {
std::string
FCmpInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -154,7 +194,8 @@ std::string FCmpInst::print() {
instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += ",";
if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) {
if (Type::is_eq_type(this->get_operand(0)->get_type(),
this->get_operand(1)->get_type())) {
instr_ir += print_as_op(this->get_operand(1), false);
} else {
instr_ir += print_as_op(this->get_operand(1), true);
......@@ -163,7 +204,10 @@ std::string FCmpInst::print() {
}
CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
: BaseInst<CallInst>(func->get_return_type(), Instruction::call, args.size() + 1, bb) {
: BaseInst<CallInst>(func->get_return_type(),
Instruction::call,
args.size() + 1,
bb) {
assert(func->get_num_of_args() == args.size());
int num_ops = args.size() + 1;
set_operand(0, func);
......@@ -172,13 +216,18 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
}
}
CallInst *CallInst::create(Function *func, std::vector<Value *> args, BasicBlock *bb) {
CallInst *
CallInst::create(Function *func, std::vector<Value *> args, BasicBlock *bb) {
return BaseInst<CallInst>::create(func, args, bb);
}
FunctionType *CallInst::get_function_type() const { return static_cast<FunctionType *>(get_operand(0)->get_type()); }
FunctionType *
CallInst::get_function_type() const {
return static_cast<FunctionType *>(get_operand(0)->get_type());
}
std::string CallInst::print() {
std::string
CallInst::print() {
std::string instr_ir;
if (!this->is_void()) {
instr_ir += "%";
......@@ -190,7 +239,8 @@ std::string CallInst::print() {
instr_ir += this->get_function_type()->get_return_type()->print();
instr_ir += " ";
assert(dynamic_cast<Function *>(this->get_operand(0)) && "Wrong call operand function");
assert(dynamic_cast<Function *>(this->get_operand(0)) &&
"Wrong call operand function");
instr_ir += print_as_op(this->get_operand(0), false);
instr_ir += "(";
for (int i = 1; i < this->get_num_operand(); i++) {
......@@ -204,19 +254,32 @@ std::string CallInst::print() {
return instr_ir;
}
BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BasicBlock *bb)
: BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()), Instruction::br, 3, bb) {
BranchInst::BranchInst(Value *cond,
BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb)
: BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()),
Instruction::br,
3,
bb) {
set_operand(0, cond);
set_operand(1, if_true);
set_operand(2, if_false);
}
BranchInst::BranchInst(BasicBlock *if_true, BasicBlock *bb)
: BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()), Instruction::br, 1, bb) {
: BaseInst<BranchInst>(Type::get_void_type(if_true->get_module()),
Instruction::br,
1,
bb) {
set_operand(0, if_true);
}
BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BasicBlock *bb) {
BranchInst *
BranchInst::create_cond_br(Value *cond,
BasicBlock *if_true,
BasicBlock *if_false,
BasicBlock *bb) {
if_true->add_pre_basic_block(bb);
if_false->add_pre_basic_block(bb);
bb->add_succ_basic_block(if_false);
......@@ -225,16 +288,21 @@ BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBl
return create(cond, if_true, if_false, bb);
}
BranchInst *BranchInst::create_br(BasicBlock *if_true, BasicBlock *bb) {
BranchInst *
BranchInst::create_br(BasicBlock *if_true, BasicBlock *bb) {
if_true->add_pre_basic_block(bb);
bb->add_succ_basic_block(if_true);
return create(if_true, bb);
}
bool BranchInst::is_cond_br() const { return get_num_operand() == 3; }
bool
BranchInst::is_cond_br() const {
return get_num_operand() == 3;
}
std::string BranchInst::print() {
std::string
BranchInst::print() {
std::string instr_ir;
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " ";
......@@ -250,20 +318,36 @@ std::string BranchInst::print() {
}
ReturnInst::ReturnInst(Value *val, BasicBlock *bb)
: BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()), Instruction::ret, 1, bb) {
: BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()),
Instruction::ret,
1,
bb) {
set_operand(0, val);
}
ReturnInst::ReturnInst(BasicBlock *bb)
: BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()), Instruction::ret, 0, bb) {}
: BaseInst<ReturnInst>(Type::get_void_type(bb->get_module()),
Instruction::ret,
0,
bb) {}
ReturnInst *ReturnInst::create_ret(Value *val, BasicBlock *bb) { return create(val, bb); }
ReturnInst *
ReturnInst::create_ret(Value *val, BasicBlock *bb) {
return create(val, bb);
}
ReturnInst *ReturnInst::create_void_ret(BasicBlock *bb) { return create(bb); }
ReturnInst *
ReturnInst::create_void_ret(BasicBlock *bb) {
return create(bb);
}
bool ReturnInst::is_void_ret() const { return get_num_operand() == 0; }
bool
ReturnInst::is_void_ret() const {
return get_num_operand() == 0;
}
std::string ReturnInst::print() {
std::string
ReturnInst::print() {
std::string instr_ir;
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " ";
......@@ -278,7 +362,9 @@ std::string ReturnInst::print() {
return instr_ir;
}
GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, BasicBlock *bb)
GetElementPtrInst::GetElementPtrInst(Value *ptr,
std::vector<Value *> idxs,
BasicBlock *bb)
: BaseInst<GetElementPtrInst>(PointerType::get(get_element_type(ptr, idxs)),
Instruction::getelementptr,
1 + idxs.size(),
......@@ -290,9 +376,11 @@ GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, Basi
element_ty_ = get_element_type(ptr, idxs);
}
Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs) {
Type *
GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs) {
Type *ty = ptr->get_type()->get_pointer_element_type();
assert("GetElementPtrInst ptr is wrong type" &&
assert(
"GetElementPtrInst ptr is wrong type" &&
(ty->is_array_type() || ty->is_integer_type() || ty->is_float_type()));
if (ty->is_array_type()) {
ArrayType *arr_ty = static_cast<ArrayType *>(ty);
......@@ -309,13 +397,20 @@ Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs)
return ty;
}
Type *GetElementPtrInst::get_element_type() const { return element_ty_; }
Type *
GetElementPtrInst::get_element_type() const {
return element_ty_;
}
GetElementPtrInst *GetElementPtrInst::create_gep(Value *ptr, std::vector<Value *> idxs, BasicBlock *bb) {
GetElementPtrInst *
GetElementPtrInst::create_gep(Value *ptr,
std::vector<Value *> idxs,
BasicBlock *bb) {
return create(ptr, idxs, bb);
}
std::string GetElementPtrInst::print() {
std::string
GetElementPtrInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -323,7 +418,8 @@ std::string GetElementPtrInst::print() {
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " ";
assert(this->get_operand(0)->get_type()->is_pointer_type());
instr_ir += this->get_operand(0)->get_type()->get_pointer_element_type()->print();
instr_ir +=
this->get_operand(0)->get_type()->get_pointer_element_type()->print();
instr_ir += ", ";
for (int i = 0; i < this->get_num_operand(); i++) {
if (i > 0)
......@@ -336,14 +432,21 @@ std::string GetElementPtrInst::print() {
}
StoreInst::StoreInst(Value *val, Value *ptr, BasicBlock *bb)
: BaseInst<StoreInst>(Type::get_void_type(bb->get_module()), Instruction::store, 2, bb) {
: BaseInst<StoreInst>(Type::get_void_type(bb->get_module()),
Instruction::store,
2,
bb) {
set_operand(0, val);
set_operand(1, ptr);
}
StoreInst *StoreInst::create_store(Value *val, Value *ptr, BasicBlock *bb) { return create(val, ptr, bb); }
StoreInst *
StoreInst::create_store(Value *val, Value *ptr, BasicBlock *bb) {
return create(val, ptr, bb);
}
std::string StoreInst::print() {
std::string
StoreInst::print() {
std::string instr_ir;
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " ";
......@@ -355,19 +458,27 @@ std::string StoreInst::print() {
return instr_ir;
}
LoadInst::LoadInst(Type *ty, Value *ptr, BasicBlock *bb) : BaseInst<LoadInst>(ty, Instruction::load, 1, bb) {
LoadInst::LoadInst(Type *ty, Value *ptr, BasicBlock *bb)
: BaseInst<LoadInst>(ty, Instruction::load, 1, bb) {
assert(ptr->get_type()->is_pointer_type());
assert(ty == static_cast<PointerType *>(ptr->get_type())->get_element_type());
assert(ty ==
static_cast<PointerType *>(ptr->get_type())->get_element_type());
set_operand(0, ptr);
}
LoadInst *LoadInst::create_load(Type *ty, Value *ptr, BasicBlock *bb) { return create(ty, ptr, bb); }
LoadInst *
LoadInst::create_load(Type *ty, Value *ptr, BasicBlock *bb) {
return create(ty, ptr, bb);
}
Type *LoadInst::get_load_type() const {
return static_cast<PointerType *>(get_operand(0)->get_type())->get_element_type();
Type *
LoadInst::get_load_type() const {
return static_cast<PointerType *>(get_operand(0)->get_type())
->get_element_type();
}
std::string LoadInst::print() {
std::string
LoadInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -375,7 +486,8 @@ std::string LoadInst::print() {
instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type());
instr_ir += " ";
assert(this->get_operand(0)->get_type()->is_pointer_type());
instr_ir += this->get_operand(0)->get_type()->get_pointer_element_type()->print();
instr_ir +=
this->get_operand(0)->get_type()->get_pointer_element_type()->print();
instr_ir += ",";
instr_ir += " ";
instr_ir += print_as_op(this->get_operand(0), true);
......@@ -383,13 +495,21 @@ std::string LoadInst::print() {
}
AllocaInst::AllocaInst(Type *ty, BasicBlock *bb)
: BaseInst<AllocaInst>(PointerType::get(ty), Instruction::alloca, 0, bb), alloca_ty_(ty) {}
: BaseInst<AllocaInst>(PointerType::get(ty), Instruction::alloca, 0, bb)
, alloca_ty_(ty) {}
AllocaInst *AllocaInst::create_alloca(Type *ty, BasicBlock *bb) { return create(ty, bb); }
AllocaInst *
AllocaInst::create_alloca(Type *ty, BasicBlock *bb) {
return create(ty, bb);
}
Type *AllocaInst::get_alloca_type() const { return alloca_ty_; }
Type *
AllocaInst::get_alloca_type() const {
return alloca_ty_;
}
std::string AllocaInst::print() {
std::string
AllocaInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -400,15 +520,23 @@ std::string AllocaInst::print() {
return instr_ir;
}
ZextInst::ZextInst(OpID op, Value *val, Type *ty, BasicBlock *bb) : BaseInst<ZextInst>(ty, op, 1, bb), dest_ty_(ty) {
ZextInst::ZextInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
: BaseInst<ZextInst>(ty, op, 1, bb), dest_ty_(ty) {
set_operand(0, val);
}
ZextInst *ZextInst::create_zext(Value *val, Type *ty, BasicBlock *bb) { return create(Instruction::zext, val, ty, bb); }
ZextInst *
ZextInst::create_zext(Value *val, Type *ty, BasicBlock *bb) {
return create(Instruction::zext, val, ty, bb);
}
Type *ZextInst::get_dest_type() const { return dest_ty_; }
Type *
ZextInst::get_dest_type() const {
return dest_ty_;
}
std::string ZextInst::print() {
std::string
ZextInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -428,13 +556,18 @@ FpToSiInst::FpToSiInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
set_operand(0, val);
}
FpToSiInst *FpToSiInst::create_fptosi(Value *val, Type *ty, BasicBlock *bb) {
FpToSiInst *
FpToSiInst::create_fptosi(Value *val, Type *ty, BasicBlock *bb) {
return create(Instruction::fptosi, val, ty, bb);
}
Type *FpToSiInst::get_dest_type() const { return dest_ty_; }
Type *
FpToSiInst::get_dest_type() const {
return dest_ty_;
}
std::string FpToSiInst::print() {
std::string
FpToSiInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -454,13 +587,18 @@ SiToFpInst::SiToFpInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
set_operand(0, val);
}
SiToFpInst *SiToFpInst::create_sitofp(Value *val, Type *ty, BasicBlock *bb) {
SiToFpInst *
SiToFpInst::create_sitofp(Value *val, Type *ty, BasicBlock *bb) {
return create(Instruction::sitofp, val, ty, bb);
}
Type *SiToFpInst::get_dest_type() const { return dest_ty_; }
Type *
SiToFpInst::get_dest_type() const {
return dest_ty_;
}
std::string SiToFpInst::print() {
std::string
SiToFpInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -475,7 +613,11 @@ std::string SiToFpInst::print() {
return instr_ir;
}
PhiInst::PhiInst(OpID op, std::vector<Value *> vals, std::vector<BasicBlock *> val_bbs, Type *ty, BasicBlock *bb)
PhiInst::PhiInst(OpID op,
std::vector<Value *> vals,
std::vector<BasicBlock *> val_bbs,
Type *ty,
BasicBlock *bb)
: BaseInst<PhiInst>(ty, op, 2 * vals.size()) {
for (int i = 0; i < vals.size(); i++) {
set_operand(2 * i, vals[i]);
......@@ -484,13 +626,15 @@ PhiInst::PhiInst(OpID op, std::vector<Value *> vals, std::vector<BasicBlock *> v
this->set_parent(bb);
}
PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb) {
PhiInst *
PhiInst::create_phi(Type *ty, BasicBlock *bb) {
std::vector<Value *> vals;
std::vector<BasicBlock *> val_bbs;
return create(Instruction::phi, vals, val_bbs, ty, bb);
}
std::string PhiInst::print() {
std::string
PhiInst::print() {
std::string instr_ir;
instr_ir += "%";
instr_ir += this->get_name();
......@@ -508,9 +652,12 @@ std::string PhiInst::print() {
instr_ir += print_as_op(this->get_operand(2 * i + 1), false);
instr_ir += " ]";
}
if (this->get_num_operand() / 2 < this->get_parent()->get_pre_basic_blocks().size()) {
if (this->get_num_operand() / 2 <
this->get_parent()->get_pre_basic_blocks().size()) {
for (auto pre_bb : this->get_parent()->get_pre_basic_blocks()) {
if (std::find(this->get_operands().begin(), this->get_operands().end(), static_cast<Value *>(pre_bb)) ==
if (std::find(this->get_operands().begin(),
this->get_operands().end(),
static_cast<Value *>(pre_bb)) ==
this->get_operands().end()) {
// find a pre_bb is not in phi
instr_ir += ", [ undef, " + print_as_op(pre_bb, false) + " ]";
......
......@@ -209,6 +209,17 @@ print_partitions(const GVN::partitions &p) {
}
} // namespace utils
GVN::partitions
GVN::join_helper(BasicBlock *pre1, BasicBlock *pre2) {
assert(not _TOP[pre1] or not _TOP[pre2] && "should flow here, not jump");
if (_TOP[pre1])
return pout_[pre2];
else if (_TOP[pre2])
return pout_[pre1];
return join(pout_[pre1], pout_[pre2]);
}
GVN::partitions
GVN::join(const partitions &P1, const partitions &P2) {
// TODO: do intersection pair-wise
......@@ -228,7 +239,6 @@ std::shared_ptr<CongruenceClass>
GVN::intersect(std::shared_ptr<CongruenceClass> ci,
std::shared_ptr<CongruenceClass> cj) {
// TODO
// If no common members, return null
auto c = createCongruenceClass();
std::set<Value *> intersection;
......@@ -240,10 +250,21 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
c->members_ = intersection;
if (ci->index_ == cj->index_)
c->index_ = ci->index_;
if (ci->leader_ == cj->leader_)
c->leader_ = cj->leader_;
/* if (*ci == *cj)
* return ci; */
if (ci->value_expr_ == cj->value_expr_)
c->value_expr_ = ci->value_expr_;
if (ci->value_phi_ and cj->value_phi_ and
*ci->value_phi_ == *cj->value_phi_)
c->value_phi_ = ci->value_phi_;
// if (c->members_.size() or c->value_expr_ or c->value_phi_) // not empty
// ??
// What if the ve is nullptr?
if (c->members_.size()) // not empty
if (c->index_ == 0) {
c->index_ = new_number();
c->value_phi_ =
PhiExpression::create(ci->value_expr_, cj->value_expr_);
}
return c;
}
......@@ -251,30 +272,31 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
void
GVN::detectEquivalences() {
bool changed;
std::cout << "all the instruction address:" << std::endl;
for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions())
std::cout << &instr << "\t" << instr.print() << std::endl;
}
// initialize pout with top
for (auto &bb : func_->get_basic_blocks()) {
// pin_[&bb].clear();
// pout_[&bb].clear();
for (auto &instr : bb.get_instructions())
_TOP[&instr] = true;
_TOP[&bb] = true;
}
// modify entry block
auto Entry = func_->get_entry_block();
_TOP[&*Entry->get_instructions().begin()] = false;
_TOP[Entry] = false;
pin_[Entry].clear();
pout_[Entry].clear(); // pout_[Entry] = transferFunction(Entry);
pout_[Entry] = transferFunction(Entry);
// iterate until converge
do {
changed = false;
// see the pseudo code in documentation
for (auto &_bb :
func_->get_basic_blocks()) { // you might need to visit the
// blocks in depth-first order
for (auto &_bb : func_->get_basic_blocks()) {
auto bb = &_bb;
// get PIN of bb by predecessor(s)
if (bb == Entry)
continue;
// get PIN of bb from predecessor(s)
auto pre_bbs_ = bb->get_pre_basic_blocks();
if (bb != Entry) {
// only update PIN for blocks that are not Entry
......@@ -283,12 +305,12 @@ GVN::detectEquivalences() {
case 2: {
auto pre_1 = *pre_bbs_.begin();
auto pre_2 = *(++pre_bbs_.begin());
pin_[bb] = join(pin_[pre_1], pin_[pre_2]);
pin_[bb] = join_helper(pre_1, pre_2);
break;
}
case 1: {
auto pre = *(pre_bbs_.begin());
pin_[bb] = clone(pin_[pre]);
pin_[bb] = pout_[pre];
break;
}
default:
......@@ -297,82 +319,246 @@ GVN::detectEquivalences() {
abort();
}
}
auto part = pin_[bb];
// iterate through all instructions in the block
for (auto &instr : bb->get_instructions()) {
// ??
if (not instr.is_phi())
part = transferFunction(&instr, instr.get_operand(1), part);
}
// and the phi instruction in all the successors
for (auto succ : bb->get_succ_basic_blocks()) {
for (auto &instr : succ->get_instructions())
if (instr.is_phi()) {
Instruction *pretend;
// ??
part = transferFunction(
pretend, instr.get_operand(1), part);
}
}
auto part = transferFunction(bb);
// check changes in pout
changed |= not(part == pout_[bb]);
pout_[bb] = part;
_TOP[bb] = false;
}
} while (changed);
}
shared_ptr<Expression>
GVN::valueExpr(Instruction *instr) {
GVN::valueExpr(Instruction *instr, partitions *part) {
// TODO
return {};
// ?? should use part?
std::string err{"Undefined"};
std::cout << instr->print() << std::endl;
if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) {
auto op1 = instr->get_operand(0);
auto op2 = instr->get_operand(1);
auto op1_const = dynamic_cast<Constant *>(op1);
auto op2_const = dynamic_cast<Constant *>(op2);
if (op1_const and op2_const) {
// both are constant number, so:
// constant fold!
return ConstantExpression::create(
folder_->compute(instr, op1_const, op2_const));
} else {
// both none constant
auto op1_instr = dynamic_cast<Instruction *>(op1);
auto op2_instr = dynamic_cast<Instruction *>(op2);
assert((op1_instr or op1_const) and (op2_instr or op2_const) &&
"must be this case");
return BinaryExpression::create(
instr->get_instr_type(),
(op1_const ? ConstantExpression::create(op1_const)
: valueExpr(op1_instr)),
(op2_const ? ConstantExpression::create(op2_const)
: valueExpr(op2_instr)));
}
} else if (instr->is_phi()) {
err = "phi";
} else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
auto op = instr->get_operand(0);
auto op_const = dynamic_cast<Constant *>(op);
auto op_instr = dynamic_cast<Instruction *>(op);
assert(op_instr or op_const);
// get dest type
auto instr_fp2si = dynamic_cast<FpToSiInst *>(instr);
auto instr_si2fp = dynamic_cast<SiToFpInst *>(instr);
auto instr_zext = dynamic_cast<ZextInst *>(instr);
Type *dest_type = nullptr;
if (instr_fp2si)
dest_type = instr_fp2si->get_dest_type();
else if (instr_si2fp)
dest_type = instr_si2fp->get_dest_type();
else if (instr_zext)
dest_type = instr_zext->get_dest_type();
else
err = "cast";
if (dest_type) {
if (op_const)
return ConstantExpression::create(
folder_->compute(instr, op_const));
else
return CastExpression::create(
instr->get_instr_type(), valueExpr(op_instr), dest_type);
}
} else if (instr->is_gep()) {
auto operands = instr->get_operands();
auto ptr = operands[0];
std::vector<std::shared_ptr<Expression>> idxs;
// check for base address
assert(not dynamic_cast<Constant *>(ptr) and
dynamic_cast<Instruction *>(ptr) &&
"base address should only be from instruction");
// set idxes
for (int i = 1; i < operands.size(); i++) {
if (dynamic_cast<Constant *>(operands[i]))
idxs.push_back(ConstantExpression::create(
dynamic_cast<Constant *>(operands[i])));
else {
assert(dynamic_cast<Instruction *>(operands[i]));
idxs.push_back(
valueExpr(dynamic_cast<Instruction *>(operands[i])));
}
}
return GEPExpression::create(
valueExpr(dynamic_cast<Instruction *>(ptr)), idxs);
} else if (instr->is_load() or instr->is_alloca() or instr->is_call()) {
return UniqueExpression::create(instr);
}
std::cerr << "Undefined case: " << err << std::endl;
abort();
}
// instruction of the form `x = e`, mostly x is just e (SSA),
// but for copy stmt x is a phi instruction in the successor.
// Phi values (not copy stmt) should be handled in detectEquiv
//
// assert the x is an instruction that can generate a new value
//
/// \param bb basic block in which the transfer function is called
GVN::partitions
GVN::transferFunction(Instruction *x, Value *e, partitions pin) {
partitions pout = clone(pin);
partitions pout = pin;
// TODO: deal with copy-stmt case
// ?? deal with copy statement
auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) &&
"A value must be from an instruction or constant");
// erase the old record for x
std::set<Value *>::iterator it;
for (auto c : pin)
if ((it = std::find(c->members_.begin(), c->members_.end(), x)) !=
c->members_.end()) {
c->members_.erase(it);
}
// TODO: get different ValueExpr by Instruction::OpID, modify pout
std::set<Value *>::iterator iter;
// ??
// get ve and vpf
shared_ptr<Expression> ve;
if (e) {
if (e_const)
ve = ConstantExpression::create(e_const);
else
ve = valueExpr(e_instr, &pin);
} else
ve = valueExpr(x, &pin);
auto vpf = valuePhiFunc(ve, curr_bb);
for (auto c : pout) {
if ((iter = std::find(c->members_.begin(), c->members_.end(), x)) !=
c->members_.end()) {
// static_cast<Value *>(x))) != c->members_.end()) {
c->members_.erase(iter);
if (ve == c->value_expr_ or (vpf and vpf == c->value_phi_)) {
c->value_expr_ = ve;
c->members_.insert(x);
} else {
auto c = createCongruenceClass(new_number());
c->members_.insert(x);
c->value_expr_ = ve;
c->value_phi_ = vpf;
pout.insert(c);
}
}
auto ve = valueExpr(x);
auto vpf = valuePhiFunc(ve, pin);
/* pout.insert({});
* auto c = CongruenceClass(new_number());
* c.leader_ = e; */
/* // first version: ignore ve and vpf
* // and only update index, leader and members
* auto c = createCongruenceClass(new_number());
* c->leader_ = x;
* c->members_.insert(x);
* pout.insert(c); */
return pout;
}
/*
* read the pin for the block and then execute transferFunction() for all
* instructions inside.
*/
GVN::partitions
GVN::transferFunction(BasicBlock *bb) {
partitions pout = clone(pin_[bb]);
// ??
return pout;
curr_bb = bb;
int res;
auto part = pin_[bb];
/* LOG_INFO << "transferFunction(bb=" << bb->get_name() << ")\n";
* LOG_INFO << "pin:\n";
* utils::print_partitions(pin_[bb]);
* LOG_INFO << "pout before:\n";
* utils::print_partitions(pout_[bb]); */
// iterate through all instructions in the block
for (auto &instr : bb->get_instructions()) {
// ?? what about orther instructions? Are they all ok?
if (not instr.is_phi() and not instr.is_void())
part = transferFunction(&instr, nullptr, part);
}
// and the phi instruction in all the successors
for (auto succ : bb->get_succ_basic_blocks()) {
for (auto &instr : succ->get_instructions()) {
if (instr.is_phi()) {
if ((res = pretend_copy_stmt(&instr, bb) == -1))
continue;
part = transferFunction(&instr, instr.get_operand(res), part);
}
}
}
/* LOG_INFO << "pout after:\n";
* utils::print_partitions(part);
* std::cout << std::endl; */
return part;
}
shared_ptr<PhiExpression>
GVN::valuePhiFunc(shared_ptr<Expression> ve, const partitions &P) {
GVN::valuePhiFunc(shared_ptr<Expression> ve, BasicBlock *bb) {
// TODO
return {};
if (ve->get_expr_type() != Expression::e_bin)
return nullptr;
auto ve_bin = static_cast<BinaryExpression *>(ve.get());
if (ve_bin->lhs_->get_expr_type() != Expression::e_phi or
ve_bin->rhs_->get_expr_type() != Expression::e_phi)
return nullptr;
// get 2 phi expressions
auto lhs = static_cast<PhiExpression *>(ve_bin->lhs_.get());
auto rhs = static_cast<PhiExpression *>(ve_bin->rhs_.get());
// get 2 predecessors
auto pre_bbs_ = bb->get_pre_basic_blocks();
auto pre_1 = *pre_bbs_.begin();
auto pre_2 = *(++pre_bbs_.begin());
// try to get the merged value expression
auto vl_merge = BinaryExpression::create(ve_bin->op_, lhs->lhs_, rhs->lhs_);
auto vr_merge = BinaryExpression::create(ve_bin->op_, lhs->rhs_, rhs->rhs_);
auto vi = getVN(pout_[pre_1], vl_merge);
auto vj = getVN(pout_[pre_2], vr_merge);
if (vi == nullptr)
vi = valuePhiFunc(vl_merge, pre_1);
if (vj == nullptr)
vj = valuePhiFunc(vr_merge, pre_2);
if (vi and vj)
return PhiExpression::create(vi, vj);
else
return nullptr;
}
shared_ptr<Expression>
GVN::getVN(const partitions &pout, shared_ptr<Expression> ve) {
// TODO: return what?
/* for (auto c : pout) {
* if (c->value_expr_ == ve)
* return ve;
* } */
for (auto it = pout.begin(); it != pout.end(); it++)
if ((*it)->value_expr_ and *(*it)->value_expr_ == *ve)
return {};
return ve;
return nullptr;
}
......@@ -490,6 +676,12 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
return equiv_as<BinaryExpression>(lhs, rhs);
case Expression::e_phi:
return equiv_as<PhiExpression>(lhs, rhs);
case Expression::e_cast:
return equiv_as<CastExpression>(lhs, rhs);
case Expression::e_gep:
return equiv_as<GEPExpression>(lhs, rhs);
case Expression::e_unique:
return equiv_as<UniqueExpression>(lhs, rhs);
}
}
......@@ -517,12 +709,35 @@ operator==(const GVN::partitions &p1, const GVN::partitions &p2) {
// cannot direct compare???
if (p1.size() != p2.size())
return false;
return std::equal(p1.begin(), p1.end(), p2.begin(), p2.end());
auto it1 = p1.begin();
auto it2 = p2.begin();
for (; it1 != p1.end(); ++it1, ++it2)
if (not(**it1 == **it2))
return false;
return true;
}
// only compare index
// only compare members
bool
CongruenceClass::operator==(const CongruenceClass &other) const {
// TODO: which fields need to be compared?
return index_ == other.index_;
if (members_.size() != other.members_.size())
return false;
return members_ == other.members_;
}
int
GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) {
auto phi = static_cast<PhiInst *>(instr);
// res = phi [op1, name1], [op2, name2]
// ^0 ^1 ^2 ^3
if (phi->get_operand(1)->get_name() == bb->get_name()) {
// pretend copy statement:
// `res = op1`
return 0;
} else if (phi->get_operand(3)->get_name() == bb->get_name()) {
// `res = op2`
return 2;
}
return -1;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment