diff --git a/Reports/4.2-gvn/report.md b/Reports/4.2-gvn/report.md index d290e359f9533ec66fe37e41d1cefaed08ed8b1a..971bfa99ad283da14a91d2dd0bda12bff45ec727 100644 --- a/Reports/4.2-gvn/report.md +++ b/Reports/4.2-gvn/report.md @@ -8,12 +8,265 @@ ## 实验难点 -实验中遇到哪些挑战 +### 对于函数`detectEquivalences(G)` + +```cpp +detectEquivalences(G) + PIN1 = {} // “1” is the first statement in the program + POUT1 = transferFunction(PIN1) + for each statement s other than the first statement in the program + POUTs = Top + while changes to any POUT occur // i.e. changes in equivalences + for each statement s other than the first statement in the program + if s appears in block b that has two predecessors + then + PINs = Join(POUTs1, POUTs2) // s1 and s2 are last statements in respective predecessors + else + PINs = POUTs3 // s3 the statement just before s + POUTs = transferFunction(PINs) // apply transferFunction on each statement in the block +``` + +- POUT、PIN是对于语句的的map,但是实现中是对于基本块的map,怎么处理区别? + + 注意到基本块的pin、pout对应着首条语句的pin和末尾语句的pout,所以没有本质上的不同,主要是怎么补全中间语句的pin、pout。 + + 中间语句的pin、pout应该不会被显式保存,只在对基本块执行转移函数时,逐条计算。所以应该没啥问题。 + +- 怎么设计top? + + 尽管算法伪代码中,top是对每一条语句的赋值的,但是任意一条语句的都有所属基本块,只要对基本块的pout标注top、就能保证基本块的pin合乎逻辑,逐条执行转移函数后,每条语句的pin都合法。`` + +- `while changes to any POUT occur` + + 这句话如此轻松,但是在实现中出现了问题:这涉及到怎么判断分区相等,做不好会出现死循环。 + + 如果单纯比较index恐怕不行:因为我目前的`transferFunction`会生成一个新的等价类,赋予一个新的编号。 + + 所以我使用对等价类中的`members`做相等的判断,如果两个分区中的等价类都能找到相等的,那么就判断分区相等。 + + > 但是这样依赖于set的排序, + > + > - 对于分区,这个排序依赖于等价类中index的值,侥幸来说每条指令按顺序执行,即使index不一致(每次转移函数分配最新的index),大小顺序也是一致的? + > + > - 对于等价类,里边是`Value*`指针,这个指针的等价就是语义上的等价,没问题。 + + 改过来后发现还是死循环,DEBUG很久意识到我的分区相等写的有问题:`return members_ == other.members_`,这样会挨个儿比较`members_`指针的值,而不是我预想的比较指针指向的等价类。所以要手动遍历一下。 + + 这个遍历也不是很直观的,因为要做两层解引用: + + ```cpp + for (; it1 != p1.end(); ++it1, ++it2) + if (not(**it1 == **it2)) + return false; + ``` + + `*it`是等价类的指针,`**it`才是等价类。 + + 除此之外,我还忘了`++it2`,淦。 + + 改完这些,终于能跑通第一版本了…… + +### 等价类设计 + +> ```cpp +> shared_ptr +> GVN::valueExpr(Instruction *instr); +> ``` + +等价类涉及到针对指令设计、常量折叠、递归计算等,很复杂,需要拎出来单独讨论。 + +首先注意到传给传进去的参数是指令类型(指针)。需要`transferFunction`处理的指令,一次赋值的右式自然可以由指令类型得到,所以没有毛病。 + +指令有多种类型,也即lightIR中有多种赋值方式,而非单纯的伪代码中的二元运算。但是即使是已经在算法中讨论过的二元运算,写起来也不容易。 + +#### 生成等价类的逻辑 + +首先讨论生成等价类的逻辑: + +- 二元运算`y op z` + + 如果y、z都是常量,那么可以进行常量折叠,得到**一个**`ConstantExpression`类型。 + + 否则返回的是一个`BinaryExpression`类型,这个类型需要左右操作数,都是`Expression`类型,如果操作数是常量,使用`ConstantExpression`创建新的值表达式,否则递归调用`valueExpr`。 + + > 这里为什么递归调用不会走到未定义? + +- phi函数 + + 我们逻辑上把其改变成了copy,`transferFunction`应该维护这一点,所以认为没有phi函数传入。 + + 具体是`transferFunction`函数中的这几句话: + + ```cpp + GVN::partitions + GVN::transferFunction(Instruction *x, Value *e, partitions pin) + { + auto e_instr = dynamic_cast(e); + auto e_const = dynamic_cast(e); + assert((not e or e_instr or e_const) && + "A value must be from an instruction or constant"); + // …… + if (e) { + if (e_const) + ve = ConstantExpression::create(e_const); + else + ve = valueExpr(e_instr, &pin); + } else + ve = valueExpr(x, &pin); + } + ``` + +- 比较函数:和二元运算一致。 + +- 类型转换:添加新的表达式类型`CastExpression`,处理和二元操作类似。 + +- 其他产生赋值的指令,也会由`transferFunction`传递给`ValueExpr`,这些指令是:`load`、`alloca`和返回值非空的`call`。 + + 这里设计一种值表达式类型:`UniqueExpression`,这个表达式和任何其他都不等价,表现为`equiv`直接返回false。 + +#### 若干需要梳理的问题 + +除此之外,关于等价类有几个命题一定要说明,这些是理解GVN的关键。 + +- 等价类中的ve一定存在吗? + + 根据等价类的设计,ve一定是存在的,且等价类中`members_`中一定存在元素。 + +- ve一定唯一吗? + + 从逻辑上,等价类集合中的元素享有共同的ve产生的结果,ve是唯一的 + + 从实现上,添加 + +- ve或者vpf能够作为代表吗? + + 这个问题源于`transferFunction`的一句话: + + > if ve or vpf is in a class Ci in POUTs + +- 代表元如何选取?什么时候选取? + +- 怎么判断等价类相等 + + 根据论坛上这个[例子](http://cscourse.ustc.edu.cn/forum/thread.jsp?forum=91&thread=68),我发现值编号会有非必要的增长,即最终收敛时用到的编号是1、5、6,其中2、3、4都是迭代中用到的但是最后都没有体现。 + + 从这个例子中可以发现,迭代收敛时,是pout中对应分区的`members_`不变,值编号还是会变化的,所以判断等价类相等,应该比较`members_`集合。 + +- 什么时候给新的值编号? + + 1. 执行转移函数发现没有可以归属的等价类时 + + 2. 求交巴拉巴拉还没搞清楚那里 + +- 等价类为空的判定,可以用`members_`做判断吗? + + 我认为是可以的,存在情况:两个基本块汇合时,交出一个集合,它的members为空,但是ve不空: + + ``` + BB0: y = +      z = + ; pout: [{v1,y}, {v2,z}] + + BB1: x1 = y + z ; preds = BB0 + ; pout: [{v1,y}, {v2,z}, {v3, x1, v1+v2}] + + BB2: x2 = y + z ; preds = BB0 + ; pout: [{v1,y}, {v2,z}, {v4, x2, v1+v2}] + + BB3: ... ; preds = BB1, BB2 + ``` + + 进入BB3时,对BB1和BB2的pout中等价类分别取交,执行`Intersect(Ci={v3, x1, v1+v2}, Cj={v4, x2, v1+v2}`,得到的是`{v1+v2}`,`members_`中没有成员,但是ve存在,我认为这种情况也应该判定结果为空集。 + + 因为最终反映到IR上,我们直接关心的都是`members_`中的成员,所以用`members_`的空代表等价类的空我觉得是合适的。 + +| 类型 | 处理 | +| ---------------- | ------------------------- | +| 二元运算 | 依伪代码 | +| 没有返回值的 | 不处理 | +| phi函数 | 本块中的不管,后继块的做考虑 | +| 函数调用 | 有返回值的考虑等价类,**注意后边纯函数的处理** | +| 产生指针(GEP、alloca) | 先为返回值单独新建等价类 | +| 类型转换及0扩展 | 单独新建等价类 | +| 比较 | 根据op和操作数递归判断,应该和二元运算类似 | +| load | 单独新建等价类 | + +### 对于函数`Intersect(Ci, Cj )` + +伪代码中,一个等价类是一个集合,求交后,v~k~是自然而然的在或不在新集合中,但是实现中并不这么简单,如何对这样的两个结构体进行逻辑上的取交: + +```cpp +struct CongruenceClass { + size_t index_; + Value *leader_; + std::shared_ptr value_expr_; + std::shared_ptr value_phi_; + std::set members_; +} +``` + +- 首先是Intersect的伪代码描述中,下边这句怎么在实现中体现? + + > C~k~ does not have value number + + 解决:index不同的自然就交不出v~i~,相同的时候才有v~i~即value number. + + 隐患:transferFunction要对v~i~有继承性,即不能每次涉及一个等价类都新开一个value number + +- 根据上一条讨论,得出一个观点:在伪代码中,一个等价类集合会出现的内容,在实现中的对应如下: + + - 普通变量:`members_` + + - 值表达式:`value_expr_` + + - phi函数:`value_phi_` + + - 值编号v~i~:`index` + + 因此,伪代码中对集合的求交,对应到实现上,就是对以上四个域求交。 + +### 对于函数`valuePhiFunc(ve, P)` + +- 输入就一个分区,实际上要用到两个前驱的分区,怎么处理? + + 传入基本块做参数。 + +- 二元运算的顺序颠倒怎么办? + + > 如phi(x1,y1)+phi(x2,y2)和phi(x1,y1)+phi(y2,x2) + > + > 对于第二种,单纯寻找x1+y2和y1+x2肯定不合逻辑 + + 应该不会有这种情况:phi表达式,追根溯源是在join时产生的: + + ```cpp + GVN::join(const partitions &P1, const partitions &P2); + ``` + + 这里的分区是严格按照`Basicblock`的方法`get_pre_basic_blocks()`的返回顺序传入的,所以不会有担心的情况发生。 + +### 所见非所得 + +两方面: + +- 论文中伪代码对于基本块的关注不是很多,但是我们的实现一定要围绕基本块进行,这个的解决对策说到了,其实和单条语句没有本质差异。 + +- 我们针对lightir优化,经常深入API中陷入到实现细节,还容易被C++的语言特性绊住,但是应该尝试从直观上理解:我们究竟要实现什么样的效果。这个就得去看ll文件找灵感。然后这还不够,回过头来看代码又呆住了,因为ll语句和那些类型对应不上,这个就是容易忽视的一点:看lightir类型的print函数,这个能帮我们直观地串联起实现细节和表象的ll文件。 ## 实验设计 + +### join_helper + +### _TOP + +### transferFunction(Basicblock) + +### + 实现思路,相应代码,优化前后的IR对比(举一个例子)并辅以简单说明 ### 思考题 + 1. 请简要分析你的算法复杂度 2. `std::shared_ptr`如果存在环形引用,则无法正确释放内存,你的 Expression 类是否存在 circular reference? 3. 尽管本次实验已经写了很多代码,但是在算法上和工程上仍然可以对 GVN 进行改进,请简述你的 GVN 实现可以改进的地方 @@ -25,4 +278,3 @@ ## 实验反馈(可选 不会评分) 对本次实验的建议 - diff --git a/include/lightir/Instruction.h b/include/lightir/Instruction.h index b8571e1548142f82b7423c88dfbcfa79b7665e1b..b63c880651282f26a1f3729486c736208198c9dc 100644 --- a/include/lightir/Instruction.h +++ b/include/lightir/Instruction.h @@ -9,8 +9,7 @@ class BasicBlock; class Function; -class Instruction : public User, public llvm::ilist_node -{ +class Instruction : public User, public llvm::ilist_node { public: enum OpID { // Terminator Instructions @@ -49,11 +48,11 @@ class Instruction : public User, public llvm::ilist_node Instruction(const Instruction &) = delete; virtual ~Instruction() = default; inline const BasicBlock *get_parent() const { return parent_; } - inline BasicBlock *get_parent() { return parent_; } + inline BasicBlock *get_parent() { return parent_; } void set_parent(BasicBlock *parent) { this->parent_ = parent; } // Return the function this instruction belongs to. Function *get_function(); - Module *get_module(); + Module *get_module(); OpID get_instr_type() const { return op_id_; } // clang-format off @@ -86,8 +85,7 @@ class Instruction : public User, public llvm::ilist_node // clang-format on std::string get_instr_op_name() { return get_instr_op_name(op_id_); } - bool is_void() - { + bool is_void() { return ((op_id_ == ret) || (op_id_ == br) || (op_id_ == store) || (op_id_ == call && this->get_type()->is_void_type())); } @@ -108,6 +106,7 @@ class Instruction : public User, public llvm::ilist_node bool is_fsub() { return op_id_ == fsub; } bool is_fmul() { return op_id_ == fmul; } bool is_fdiv() { return op_id_ == fdiv; } + bool is_fp2si() { return op_id_ == fptosi; } bool is_si2fp() { return op_id_ == sitofp; } @@ -118,8 +117,7 @@ class Instruction : public User, public llvm::ilist_node bool is_gep() { return op_id_ == getelementptr; } bool is_zext() { return op_id_ == zext; } - bool isBinary() - { + bool isBinary() { return (is_add() || is_sub() || is_mul() || is_div() || is_fadd() || is_fsub() || is_fmul() || is_fdiv()) && (get_num_operand() == 2); @@ -128,39 +126,34 @@ class Instruction : public User, public llvm::ilist_node bool isTerminator() { return is_br() || is_ret(); } private: - OpID op_id_; - unsigned num_ops_; + OpID op_id_; + unsigned num_ops_; BasicBlock *parent_; }; -namespace detail -{ - template - struct tag - { - using type = T; - }; - template - struct select_last - { - // Use a fold-expression to fold the comma operator over the parameter - // pack. - using type = typename decltype((tag{}, ...))::type; - }; - template - using select_last_t = typename select_last::type; +namespace detail { +template +struct tag { + using type = T; +}; +template +struct select_last { + // Use a fold-expression to fold the comma operator over the parameter + // pack. + using type = typename decltype((tag{}, ...))::type; +}; +template +using select_last_t = typename select_last::type; }; // namespace detail template inline constexpr bool always_false_v = false; template -class BaseInst : public Instruction -{ +class BaseInst : public Instruction { protected: template - static Inst *create(Args &&...args) - { + static Inst *create(Args &&...args) { if constexpr (std::is_same_v< std::decay_t>, BasicBlock *>) { @@ -171,13 +164,10 @@ class BaseInst : public Instruction } template - BaseInst(Args &&...args) : Instruction(std::forward(args)...) - { - } + BaseInst(Args &&...args) : Instruction(std::forward(args)...) {} }; -class BinaryInst : public BaseInst -{ +class BinaryInst : public BaseInst { friend BaseInst; private: @@ -185,35 +175,51 @@ class BinaryInst : public BaseInst public: // create add instruction, auto insert to bb - static BinaryInst *create_add(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_add(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); // create sub instruction, auto insert to bb - static BinaryInst *create_sub(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_sub(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); // create mul instruction, auto insert to bb - static BinaryInst *create_mul(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_mul(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); // create Div instruction, auto insert to bb - static BinaryInst *create_sdiv(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_sdiv(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); // create fadd instruction, auto insert to bb - static BinaryInst *create_fadd(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_fadd(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); // create fsub instruction, auto insert to bb - static BinaryInst *create_fsub(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_fsub(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); // create fmul instruction, auto insert to bb - static BinaryInst *create_fmul(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_fmul(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); // create fDiv instruction, auto insert to bb - static BinaryInst *create_fdiv(Value *v1, Value *v2, BasicBlock *bb, + static BinaryInst *create_fdiv(Value *v1, + Value *v2, + BasicBlock *bb, Module *m); virtual std::string print() override; @@ -222,8 +228,7 @@ class BinaryInst : public BaseInst void assertValid(); }; -class CmpInst : public BaseInst -{ +class CmpInst : public BaseInst { friend BaseInst; public: @@ -240,7 +245,10 @@ class CmpInst : public BaseInst CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb); public: - static CmpInst *create_cmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, + static CmpInst *create_cmp(CmpOp op, + Value *lhs, + Value *rhs, + BasicBlock *bb, Module *m); CmpOp get_cmp_op() { return cmp_op_; } @@ -253,8 +261,7 @@ class CmpInst : public BaseInst void assertValid(); }; -class FCmpInst : public BaseInst -{ +class FCmpInst : public BaseInst { friend BaseInst; public: @@ -271,8 +278,11 @@ class FCmpInst : public BaseInst FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb); public: - static FCmpInst *create_fcmp(CmpOp op, Value *lhs, Value *rhs, - BasicBlock *bb, Module *m); + static FCmpInst *create_fcmp(CmpOp op, + Value *lhs, + Value *rhs, + BasicBlock *bb, + Module *m); CmpOp get_cmp_op() { return cmp_op_; } @@ -284,33 +294,36 @@ class FCmpInst : public BaseInst void assert_valid(); }; -class CallInst : public BaseInst -{ +class CallInst : public BaseInst { friend BaseInst; protected: CallInst(Function *func, std::vector args, BasicBlock *bb); public: - static CallInst *create(Function *func, std::vector args, + static CallInst *create(Function *func, + std::vector args, BasicBlock *bb); - FunctionType *get_function_type() const; + FunctionType *get_function_type() const; virtual std::string print() override; }; -class BranchInst : public BaseInst -{ +class BranchInst : public BaseInst { friend BaseInst; private: - BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, + BranchInst(Value *cond, + BasicBlock *if_true, + BasicBlock *if_false, BasicBlock *bb); BranchInst(BasicBlock *if_true, BasicBlock *bb); public: - static BranchInst *create_cond_br(Value *cond, BasicBlock *if_true, - BasicBlock *if_false, BasicBlock *bb); + static BranchInst *create_cond_br(Value *cond, + BasicBlock *if_true, + BasicBlock *if_false, + BasicBlock *bb); static BranchInst *create_br(BasicBlock *if_true, BasicBlock *bb); bool is_cond_br() const; @@ -318,8 +331,7 @@ class BranchInst : public BaseInst virtual std::string print() override; }; -class ReturnInst : public BaseInst -{ +class ReturnInst : public BaseInst { friend BaseInst; private: @@ -329,13 +341,12 @@ class ReturnInst : public BaseInst public: static ReturnInst *create_ret(Value *val, BasicBlock *bb); static ReturnInst *create_void_ret(BasicBlock *bb); - bool is_void_ret() const; + bool is_void_ret() const; virtual std::string print() override; }; -class GetElementPtrInst : public BaseInst -{ +class GetElementPtrInst : public BaseInst { friend BaseInst; private: @@ -343,9 +354,10 @@ class GetElementPtrInst : public BaseInst public: static Type *get_element_type(Value *ptr, std::vector idxs); - static GetElementPtrInst *create_gep(Value *ptr, std::vector idxs, + static GetElementPtrInst *create_gep(Value *ptr, + std::vector idxs, BasicBlock *bb); - Type *get_element_type() const; + Type *get_element_type() const; virtual std::string print() override; @@ -353,8 +365,7 @@ class GetElementPtrInst : public BaseInst Type *element_ty_; }; -class StoreInst : public BaseInst -{ +class StoreInst : public BaseInst { friend BaseInst; private: @@ -369,8 +380,7 @@ class StoreInst : public BaseInst virtual std::string print() override; }; -class LoadInst : public BaseInst -{ +class LoadInst : public BaseInst { friend BaseInst; private: @@ -378,15 +388,14 @@ class LoadInst : public BaseInst public: static LoadInst *create_load(Type *ty, Value *ptr, BasicBlock *bb); - Value *get_lval() { return this->get_operand(0); } + Value *get_lval() { return this->get_operand(0); } Type *get_load_type() const; virtual std::string print() override; }; -class AllocaInst : public BaseInst -{ +class AllocaInst : public BaseInst { friend BaseInst; private: @@ -403,8 +412,7 @@ class AllocaInst : public BaseInst Type *alloca_ty_; }; -class ZextInst : public BaseInst -{ +class ZextInst : public BaseInst { friend BaseInst; private: @@ -421,8 +429,7 @@ class ZextInst : public BaseInst Type *dest_ty_; }; -class FpToSiInst : public BaseInst -{ +class FpToSiInst : public BaseInst { friend BaseInst; private: @@ -439,8 +446,7 @@ class FpToSiInst : public BaseInst Type *dest_ty_; }; -class SiToFpInst : public BaseInst -{ +class SiToFpInst : public BaseInst { friend BaseInst; private: @@ -457,25 +463,24 @@ class SiToFpInst : public BaseInst Type *dest_ty_; }; -class PhiInst : public BaseInst -{ +class PhiInst : public BaseInst { friend BaseInst; private: - PhiInst(OpID op, std::vector vals, - std::vector val_bbs, Type *ty, BasicBlock *bb); + PhiInst(OpID op, + std::vector vals, + std::vector val_bbs, + Type *ty, + BasicBlock *bb); PhiInst(Type *ty, OpID op, unsigned num_ops, BasicBlock *bb) - : BaseInst(ty, op, num_ops, bb) - { - } + : BaseInst(ty, op, num_ops, bb) {} Value *l_val_; public: static PhiInst *create_phi(Type *ty, BasicBlock *bb); - Value *get_lval() { return l_val_; } - void set_lval(Value *l_val) { l_val_ = l_val; } - void add_phi_pair_operand(Value *val, Value *pre_bb) - { + Value *get_lval() { return l_val_; } + void set_lval(Value *l_val) { l_val_ = l_val; } + void add_phi_pair_operand(Value *val, Value *pre_bb) { this->add_operand(val); this->add_operand(pre_bb); } diff --git a/include/optimization/GVN.h b/include/optimization/GVN.h index df408a550fe636e595c9b195aad4f3756426fd1b..144ddfce0fe1a7f7466721d33eef0b7630ba6d84 100644 --- a/include/optimization/GVN.h +++ b/include/optimization/GVN.h @@ -19,6 +19,7 @@ #include #include +class GVN; namespace GVNExpression { // fold the constant value @@ -35,12 +36,13 @@ class ConstFolder { /** * for constructor of class derived from `Expression`, we make it public * because `std::make_shared` needs the constructor to be publicly available, - * but you should call the static factory method `create` instead the constructor itself to get the desired data + * but you should call the static factory method `create` instead the + * constructor itself to get the desired data */ class Expression { public: // TODO: you need to extend expression types according to testcases - enum gvn_expr_t { e_constant, e_bin, e_phi }; + enum gvn_expr_t { e_constant, e_bin, e_phi, e_cast, e_gep, e_unique }; Expression(gvn_expr_t t) : expr_type(t) {} virtual ~Expression() = default; virtual std::string print() = 0; @@ -50,15 +52,21 @@ class Expression { gvn_expr_t expr_type; }; -bool operator==(const std::shared_ptr &lhs, const std::shared_ptr &rhs); -bool operator==(const GVNExpression::Expression &lhs, const GVNExpression::Expression &rhs); +bool operator==(const std::shared_ptr &lhs, + const std::shared_ptr &rhs); +bool operator==(const GVNExpression::Expression &lhs, + const GVNExpression::Expression &rhs); class ConstantExpression : public Expression { public: - static std::shared_ptr create(Constant *c) { return std::make_shared(c); } + static std::shared_ptr create(Constant *c) { + return std::make_shared(c); + } virtual std::string print() { return c_->print(); } // we leverage the fact that constants in lightIR have unique addresses - bool equiv(const ConstantExpression *other) const { return c_ == other->c_; } + bool equiv(const ConstantExpression *other) const { + return c_ == other->c_; + } ConstantExpression(Constant *c) : Expression(e_constant), c_(c) {} private: @@ -67,24 +75,31 @@ class ConstantExpression : public Expression { // arithmetic expression class BinaryExpression : public Expression { + friend class ::GVN; + public: - static std::shared_ptr create(Instruction::OpID op, - std::shared_ptr lhs, - std::shared_ptr rhs) { + static std::shared_ptr create( + Instruction::OpID op, + std::shared_ptr lhs, + std::shared_ptr rhs) { return std::make_shared(op, lhs, rhs); } virtual std::string print() { - return "(" + Instruction::get_instr_op_name(op_) + " " + lhs_->print() + " " + rhs_->print() + ")"; + return "(" + Instruction::get_instr_op_name(op_) + " " + lhs_->print() + + " " + rhs_->print() + ")"; } bool equiv(const BinaryExpression *other) const { - if (op_ == other->op_ and *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_) + if (op_ == other->op_ and *lhs_ == *other->lhs_ and + *rhs_ == *other->rhs_) return true; else return false; } - BinaryExpression(Instruction::OpID op, std::shared_ptr lhs, std::shared_ptr rhs) + BinaryExpression(Instruction::OpID op, + std::shared_ptr lhs, + std::shared_ptr rhs) : Expression(e_bin), op_(op), lhs_(lhs), rhs_(rhs) {} private: @@ -93,33 +108,122 @@ class BinaryExpression : public Expression { }; class PhiExpression : public Expression { + friend class ::GVN; + public: - static std::shared_ptr create(std::shared_ptr lhs, std::shared_ptr rhs) { + static std::shared_ptr create( + std::shared_ptr lhs, + std::shared_ptr rhs) { return std::make_shared(lhs, rhs); } - virtual std::string print() { return "(phi " + lhs_->print() + " " + rhs_->print() + ")"; } + virtual std::string print() { + return "(phi " + lhs_->print() + " " + rhs_->print() + ")"; + } bool equiv(const PhiExpression *other) const { if (*lhs_ == *other->lhs_ and *rhs_ == *other->rhs_) return true; else return false; } - PhiExpression(std::shared_ptr lhs, std::shared_ptr rhs) + PhiExpression(std::shared_ptr lhs, + std::shared_ptr rhs) : Expression(e_phi), lhs_(lhs), rhs_(rhs) {} private: std::shared_ptr lhs_, rhs_; }; + +// type cast expression +class CastExpression : public Expression { + public: + static std::shared_ptr create( + Instruction::OpID op, + std::shared_ptr src, + Type *dest_type) { + return std::make_shared(op, src, dest_type); + } + virtual std::string print() { + return "(" + dest_ty_->print() + " " + + Instruction::get_instr_op_name(op_) + " " + src_->print() + ")"; + } + + bool equiv(const CastExpression *other) const { + return op_ == other->op_ and src_ == other->src_ and + dest_ty_ == other->dest_ty_; + } + + CastExpression(Instruction::OpID op, + std::shared_ptr src, + Type *dest_type) + : Expression(e_cast), op_(op), src_(src), dest_ty_(dest_type) {} + + private: + Instruction::OpID op_; + std::shared_ptr src_; + Type *dest_ty_; +}; + +// type cast expression +class GEPExpression : public Expression { + public: + static std::shared_ptr create( + std::shared_ptr ptr, + std::vector> &idxs) { + return std::make_shared(ptr, idxs); + } + virtual std::string print() { + std::string ret = "(GEP " + ptr_->print(); + for (auto idx : idxs_) + ret += " " + idx->print(); + return ret + ")"; + } + + bool equiv(const GEPExpression *other) const { + if (idxs_.size() != other->idxs_.size()) + return false; + for (int i = 0; i != idxs_.size(); ++i) + if (not(idxs_[i] == other->idxs_[i])) + return false; + return ptr_ == other->ptr_; + } + + GEPExpression(std::shared_ptr ptr, + std::vector> &idxs) + : Expression(e_gep), ptr_(ptr), idxs_(idxs) {} + + private: + std::shared_ptr ptr_; + std::vector> idxs_; +}; + +// unique expression: not equal to any one else +class UniqueExpression : public Expression { + public: + static std::shared_ptr create(Instruction *instr) { + return std::make_shared(instr); + } + virtual std::string print() { return "(UNIQUE " + instr_->print() + ")"; } + + bool equiv(const UniqueExpression *other) const { return false; } + + UniqueExpression(Instruction *instr) + : Expression(e_unique), instr_(instr) {} + + private: + std::shared_ptr instr_; +}; } // namespace GVNExpression /** * Congruence class in each partitions * note: for constant propagation, you might need to add other fields - * and for load/store redundancy detection, you most certainly need to modify the class + * and for load/store redundancy detection, you most certainly need to modify + * the class */ struct CongruenceClass { size_t index_; - // representative of the congruence class, used to replace all the members (except itself) when analysis is done + // representative of the congruence class, used to replace all the members + // (except itself) when analysis is done Value *leader_; // value expression in congruence class std::shared_ptr value_expr_; @@ -128,17 +232,22 @@ struct CongruenceClass { // equivalent variables in one congruence class std::set members_; - CongruenceClass(size_t index) : index_(index), leader_{}, value_expr_{}, value_phi_{}, members_{} {} + CongruenceClass(size_t index) + : index_(index), leader_{}, value_expr_{}, value_phi_{}, members_{} {} - bool operator<(const CongruenceClass &other) const { return this->index_ < other.index_; } + bool operator<(const CongruenceClass &other) const { + return this->index_ < other.index_; + } bool operator==(const CongruenceClass &other) const; }; namespace std { template <> -// overload std::less for std::shared_ptr, i.e. how to sort the congruence classes +// overload std::less for std::shared_ptr, i.e. how to sort the +// congruence classes struct less> { - bool operator()(const std::shared_ptr &a, const std::shared_ptr &b) const { + bool operator()(const std::shared_ptr &a, + const std::shared_ptr &b) const { // nullptrs should never appear in partitions, so we just dereference it return *a < *b; } @@ -154,17 +263,24 @@ class GVN : public Pass { // init for pass metadata; void initPerFunction(); - // fill the following functions according to Pseudocode, **you might need to add more arguments** + // fill the following functions according to Pseudocode, **you might need to + // add more arguments** void detectEquivalences(); partitions join(const partitions &P1, const partitions &P2); - std::shared_ptr intersect(std::shared_ptr, std::shared_ptr); + std::shared_ptr intersect( + std::shared_ptr, + std::shared_ptr); partitions transferFunction(Instruction *x, Value *e, partitions pin); partitions transferFunction(BasicBlock *bb); - std::shared_ptr valuePhiFunc(std::shared_ptr, - const partitions &); - std::shared_ptr valueExpr(Instruction *instr); - std::shared_ptr getVN(const partitions &pout, - std::shared_ptr ve); + std::shared_ptr valuePhiFunc( + std::shared_ptr, + BasicBlock *bb); + std::shared_ptr valueExpr( + Instruction *instr, + partitions *part = nullptr); + std::shared_ptr getVN( + const partitions &pout, + std::shared_ptr ve); // replace cc members with leader void replace_cc_members(); @@ -180,6 +296,7 @@ class GVN : public Pass { // self add // std::uint64_t new_number() { return next_value_number_++; } + static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb); private: bool dump_json_; @@ -191,7 +308,9 @@ class GVN : public Pass { std::unique_ptr dce_; // self add - std::map _TOP; + std::map _TOP; + partitions join_helper(BasicBlock *pre1, BasicBlock *pre2); + BasicBlock* curr_bb; }; bool operator==(const GVN::partitions &p1, const GVN::partitions &p2); diff --git a/src/cminusfc/.gitignore b/src/cminusfc/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..80d0420b11d310cde3d4c75755db4fc8c0190bc0 --- /dev/null +++ b/src/cminusfc/.gitignore @@ -0,0 +1 @@ +cminusf_builder_stu.cpp diff --git a/src/cminusfc/cminusf_builder.cpp b/src/cminusfc/cminusf_builder.cpp index d33f33c342b5e35f3942fb7a5a8a55fc09bdcc2d..a11b279761fa95ca39a39bb75c37488072c04e3e 100644 --- a/src/cminusfc/cminusf_builder.cpp +++ b/src/cminusfc/cminusf_builder.cpp @@ -5,22 +5,20 @@ #include "cminusf_builder.hpp" -#include "logging.hpp" - #define CONST_FP(num) ConstantFP::get((float)num, module.get()) #define CONST_INT(num) ConstantInt::get(num, module.get()) -// TODO: Global Variable Declarations // You can define global variables here -// to store state. You can expand these -// definitions if you need to. +// to store state -// the latest return value -Value *cur_value = nullptr; -// if var is assignment's left part, LV is true -bool LV = false; +// store temporary value +Value *tmp_val = nullptr; +// whether require lvalue +bool require_lvalue = false; // function that is being built Function *cur_fun = nullptr; +// detect scope pre-enter (for elegance only) +bool pre_enter_scope = false; // types Type *VOID_T; @@ -30,9 +28,22 @@ Type *INT32PTR_T; Type *FLOAT_T; Type *FLOATPTR_T; -// initializer -ConstantZero *I32Initializer; -ConstantZero *FloatInitializer; +bool +promote(IRBuilder *builder, Value **l_val_p, Value **r_val_p) { + bool is_int; + auto &l_val = *l_val_p; + auto &r_val = *r_val_p; + if (l_val->get_type() == r_val->get_type()) { + is_int = l_val->get_type()->is_integer_type(); + } else { + is_int = false; + if (l_val->get_type()->is_integer_type()) + l_val = builder->create_sitofp(l_val, FLOAT_T); + else + r_val = builder->create_sitofp(r_val, FLOAT_T); + } + return is_int; +} /* * use CMinusfBuilder::Scope to construct scopes @@ -42,182 +53,61 @@ ConstantZero *FloatInitializer; * scope.find: find and return the value bound to the name */ -void error_exit(std::string s) { - LOG_ERROR << s; - std::abort(); -} - -// This function makes sure that -// 1. 2 values have same type -// 2. type is either i32 or float -void CminusfBuilder::biop_type_check(Value *&lvalue, Value *&rvalue, std::string util) { - if (Type::is_eq_type(lvalue->get_type(), rvalue->get_type())) { - if (lvalue->get_type()->is_integer_type() or lvalue->get_type()->is_float_type()) { - // check for i1 - if (Type::is_eq_type(lvalue->get_type(), INT1_T)) { - lvalue = builder->create_zext(lvalue, INT32_T); - rvalue = builder->create_zext(rvalue, INT32_T); - } - - } else - error_exit("not supported type cast for " + util); - return; - } - - // only support cast between int and float: i32, i1, float - // - // case that integer and float is mixed, directly cast integer to float - if (lvalue->get_type()->is_integer_type() and rvalue->get_type()->is_float_type()) - lvalue = builder->create_sitofp(lvalue, FLOAT_T); - else if (lvalue->get_type()->is_float_type() and rvalue->get_type()->is_integer_type()) - rvalue = builder->create_sitofp(rvalue, FLOAT_T); - else if (lvalue->get_type()->is_integer_type() and rvalue->get_type()->is_integer_type()) { - // case that I32 and I1 mixed - if (Type::is_eq_type(lvalue->get_type(), INT1_T)) - lvalue = builder->create_zext(lvalue, INT32_T); - else - rvalue = builder->create_zext(rvalue, INT32_T); - } else { // we only support computing among i1, i32 and float - error_exit("not supported type cast for " + util); - } -} - -// this function makes sure value is a bool type -void CminusfBuilder::cast_to_i1(Value *&value) { - assert(value->get_type()->is_integer_type() or value->get_type()->is_float_type()); - if (value->get_type()->is_float_type()) - // value = builder->create_fptosi(value, INT1_T); - value = builder->create_fcmp_ne(value, CONST_FP(0)); - else if (Type::is_eq_type(value->get_type(), INT32_T)) - value = builder->create_icmp_ne(value, CONST_INT(0)); -} - -void CminusfBuilder::visit(ASTProgram &node) { +void +CminusfBuilder::visit(ASTProgram &node) { VOID_T = Type::get_void_type(module.get()); INT1_T = Type::get_int1_type(module.get()); INT32_T = Type::get_int32_type(module.get()); INT32PTR_T = Type::get_int32_ptr_type(module.get()); FLOAT_T = Type::get_float_type(module.get()); FLOATPTR_T = Type::get_float_ptr_type(module.get()); - I32Initializer = ConstantZero::get(INT32_T, builder->get_module()); - FloatInitializer = ConstantZero::get(FLOAT_T, builder->get_module()); for (auto decl : node.declarations) { decl->accept(*this); } } -// Done -void CminusfBuilder::visit(ASTNum &node) { - //!TODO: This function is empty now. - // Add some code here. - // - switch (node.type) { - case TYPE_INT: - cur_value = CONST_INT(node.i_val); - return; - case TYPE_FLOAT: - cur_value = CONST_FP(node.f_val); - return; - default: - error_exit("ASTNum is not int or float"); - } +void +CminusfBuilder::visit(ASTNum &node) { + if (node.type == TYPE_INT) + tmp_val = CONST_INT(node.i_val); + else + tmp_val = CONST_FP(node.f_val); } -// Done -void CminusfBuilder::visit(ASTVarDeclaration &node) { - //!TODO: This function is empty now. - // Add some code here. - bool global = (builder->get_insert_block() == nullptr); - if (node.num) { - // declares an array - // - // get array size - node.num->accept(*this); - // - // !no type cast here! - if (not(node.num->type == TYPE_INT)) - error_exit("size of array has non-integer type"); - - int size = node.num->i_val; - if (size <= 0) - error_exit("array size[" + std::to_string(size) + "] <= 0"); - - switch (node.type) { - case TYPE_INT: { - auto I32Array_T = Type::get_array_type(INT32_T, size); - if (global) - cur_value = - GlobalVariable::create(node.id, builder->get_module(), I32Array_T, false, I32Initializer); - else - cur_value = builder->create_alloca(I32Array_T); - break; - } - - case TYPE_FLOAT: { - auto FloatArray_T = Type::get_array_type(FLOAT_T, size); - if (global) - cur_value = - GlobalVariable::create(node.id, builder->get_module(), FloatArray_T, false, FloatInitializer); - else - cur_value = builder->create_alloca(FloatArray_T); - break; - } - default: - error_exit("Variable type(not array) is not int or float"); +void +CminusfBuilder::visit(ASTVarDeclaration &node) { + Type *var_type; + if (node.type == TYPE_INT) + var_type = Type::get_int32_type(module.get()); + else + var_type = Type::get_float_type(module.get()); + if (node.num == nullptr) { + if (scope.in_global()) { + auto initializer = ConstantZero::get(var_type, module.get()); + auto var = GlobalVariable::create( + node.id, module.get(), var_type, false, initializer); + scope.push(node.id, var); + } else { + auto var = builder->create_alloca(var_type); + scope.push(node.id, var); } - assert(cur_value->get_type()->is_pointer_type() && "IF SEE THIS: API ERROR"); - } else { - // flat int or float type - switch (node.type) { - case TYPE_INT: - if (global) - cur_value = GlobalVariable::create(node.id, builder->get_module(), INT32_T, false, I32Initializer); - else - cur_value = builder->create_alloca(INT32_T); - break; - - case TYPE_FLOAT: - if (global) - cur_value = - GlobalVariable::create(node.id, builder->get_module(), FLOAT_T, false, FloatInitializer); - else { - /* Beautiful is better than ugly. - * Explicit is better than implicit. - * Simple is better than complex. - * Complex is better than complicated. - * Flat is better than nested. - * Sparse is better than dense. - * Readability counts. - * Special cases aren't special enough to break the rules. - * Although practicality beats purity. - * Errors should never pass silently. - * Unless explicitly silenced. - * In the face of ambiguity, refuse the temptation to guess. - * There should be one-- and preferably only one --obvious way to do it. - * Although that way may not be obvious at first unless you're Dutch. - * Now is better than never. - * Although never is often better than *right* now. - * If the implementation is hard to explain, it's a bad idea. - * If the implementation is easy to explain, it may be a good idea. - * Namespaces are one honking great idea -- let's do more of those! */ - // cur_value = builder->create_alloca(INT32_T); - cur_value = builder->create_alloca(FLOAT_T); - } - break; - default: - error_exit("Variable type(not array) is not int or float"); + auto *array_type = ArrayType::get(var_type, node.num->i_val); + if (scope.in_global()) { + auto initializer = ConstantZero::get(array_type, module.get()); + auto var = GlobalVariable::create( + node.id, module.get(), array_type, false, initializer); + scope.push(node.id, var); + } else { + auto var = builder->create_alloca(array_type); + scope.push(node.id, var); } } - - if (not scope.push(node.id, cur_value)) - error_exit("variable redefined: " + node.id); - LOG_DEBUG << "add entry: " << node.id << " " << cur_value; } -// Done -void CminusfBuilder::visit(ASTFunDeclaration &node) { +void +CminusfBuilder::visit(ASTFunDeclaration &node) { FunctionType *fun_type; Type *ret_type; std::vector param_types; @@ -229,46 +119,54 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) { ret_type = VOID_T; for (auto ¶m : node.params) { - //!TODO: Please accomplish param_types. - // - // First make function BB, which needs this param type, - // then set_insert_point, we can call accept to gen code, - switch (param->type) { - case TYPE_INT: - param_types.push_back(param->isarray ? INT32PTR_T : INT32_T); - break; - case TYPE_FLOAT: - param_types.push_back(param->isarray ? FLOATPTR_T : FLOAT_T); - break; - case TYPE_VOID: - if (not param_types.empty()) - error_exit("function parameters weird"); - break; + if (param->type == TYPE_INT) { + if (param->isarray) { + param_types.push_back(INT32PTR_T); + } else { + param_types.push_back(INT32_T); + } + } else { + if (param->isarray) { + param_types.push_back(FLOATPTR_T); + } else { + param_types.push_back(FLOAT_T); + } } } fun_type = FunctionType::get(ret_type, param_types); auto fun = Function::create(fun_type, node.id, module.get()); - cur_fun = fun; scope.push(node.id, fun); - + cur_fun = fun; auto funBB = BasicBlock::create(module.get(), "entry", fun); builder->set_insert_point(funBB); scope.enter(); - + pre_enter_scope = true; std::vector args; for (auto arg = fun->arg_begin(); arg != fun->arg_end(); arg++) { args.push_back(*arg); } - for (int i = 0; i < node.params.size(); ++i) { - //!TODO: You need to deal with params - // and store them in the scope. - cur_value = args[i]; - node.params[i]->accept(*this); + if (node.params[i]->isarray) { + Value *array_alloc; + if (node.params[i]->type == TYPE_INT) + array_alloc = builder->create_alloca(INT32PTR_T); + else + array_alloc = builder->create_alloca(FLOATPTR_T); + builder->create_store(args[i], array_alloc); + scope.push(node.params[i]->id, array_alloc); + } else { + Value *alloc; + if (node.params[i]->type == TYPE_INT) + alloc = builder->create_alloca(INT32_T); + else + alloc = builder->create_alloca(FLOAT_T); + builder->create_store(args[i], alloc); + scope.push(node.params[i]->id, alloc); + } } node.compound_stmt->accept(*this); - // default return value + // can't deal with return in both blocks if (builder->get_insert_block()->get_terminator() == nullptr) { if (cur_fun->get_return_type()->is_void_type()) builder->create_void_ret(); @@ -280,40 +178,17 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) { scope.exit(); } -// Done -void CminusfBuilder::visit(ASTParam &node) { - //!TODO: This function is empty now. - // If the parameter is int|float, copy and store them - auto param_value = cur_value; - switch (node.type) { - case TYPE_INT: { - if (node.isarray) - cur_value = builder->create_alloca(INT32PTR_T); - else - cur_value = builder->create_alloca(INT32_T); - break; - } - case TYPE_FLOAT: { - if (node.isarray) - cur_value = builder->create_alloca(FLOATPTR_T); - else - cur_value = builder->create_alloca(FLOAT_T); - break; - } - case TYPE_VOID: - return; - } - scope.push(node.id, cur_value); - builder->create_store(param_value, cur_value); -} +void +CminusfBuilder::visit(ASTParam &node) {} -// Done? -void CminusfBuilder::visit(ASTCompoundStmt &node) { - //!TODO: This function is not complete. - // You may need to add some code here - // to deal with complex statements. - - scope.enter(); +void +CminusfBuilder::visit(ASTCompoundStmt &node) { + bool need_exit_scope = !pre_enter_scope; + if (pre_enter_scope) { + pre_enter_scope = false; + } else { + scope.enter(); + } for (auto &decl : node.local_declarations) { decl->accept(*this); @@ -325,413 +200,303 @@ void CminusfBuilder::visit(ASTCompoundStmt &node) { break; } - scope.exit(); + if (need_exit_scope) { + scope.exit(); + } } -// Done -void CminusfBuilder::visit(ASTExpressionStmt &node) { - //!TODO: This function is empty now. - // Add some code here. - if (node.expression) +void +CminusfBuilder::visit(ASTExpressionStmt &node) { + if (node.expression != nullptr) node.expression->accept(*this); } -// Done -void CminusfBuilder::visit(ASTSelectionStmt &node) { - //!TODO: This function is empty now. - // Add some code here. - scope.enter(); +void +CminusfBuilder::visit(ASTSelectionStmt &node) { node.expression->accept(*this); - auto cond = cur_value; - cast_to_i1(cond); - - auto ifBB = BasicBlock::create(builder->get_module(), "", cur_fun); - auto endBB = BasicBlock::create(builder->get_module(), "", cur_fun); - if (node.else_statement) { - auto elseBB = BasicBlock::create(builder->get_module(), "", cur_fun); - builder->create_cond_br(cond, ifBB, elseBB); - - builder->set_insert_point(ifBB); - node.if_statement->accept(*this); - builder->create_br(endBB); - - builder->set_insert_point(elseBB); - node.else_statement->accept(*this); - builder->create_br(endBB); + auto ret_val = tmp_val; + auto trueBB = BasicBlock::create(module.get(), "", cur_fun); + BasicBlock *falseBB{}; + auto contBB = BasicBlock::create(module.get(), "", cur_fun); + Value *cond_val; + if (ret_val->get_type()->is_integer_type()) + cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0)); + else + cond_val = builder->create_fcmp_ne(ret_val, CONST_FP(0.)); - builder->set_insert_point(endBB); + if (node.else_statement == nullptr) { + builder->create_cond_br(cond_val, trueBB, contBB); } else { - builder->create_cond_br(cond, ifBB, endBB); + falseBB = BasicBlock::create(module.get(), "", cur_fun); + builder->create_cond_br(cond_val, trueBB, falseBB); + } + builder->set_insert_point(trueBB); + node.if_statement->accept(*this); - builder->set_insert_point(ifBB); - node.if_statement->accept(*this); - builder->create_br(endBB); + if (builder->get_insert_block()->get_terminator() == nullptr) + builder->create_br(contBB); - builder->set_insert_point(endBB); + if (node.else_statement == nullptr) { + // falseBB->erase_from_parent(); // did not clean up memory + } else { + builder->set_insert_point(falseBB); + node.else_statement->accept(*this); + if (builder->get_insert_block()->get_terminator() == nullptr) + builder->create_br(contBB); } - scope.exit(); -} -// Done -void CminusfBuilder::visit(ASTIterationStmt &node) { - //!TODO: This function is empty now. - // Add some code here. - scope.enter(); - auto HEAD = BasicBlock::create(builder->get_module(), "", cur_fun); - auto BODY = BasicBlock::create(builder->get_module(), "", cur_fun); - auto END = BasicBlock::create(builder->get_module(), "", cur_fun); - - builder->create_br(HEAD); + builder->set_insert_point(contBB); +} - builder->set_insert_point(HEAD); +void +CminusfBuilder::visit(ASTIterationStmt &node) { + auto exprBB = BasicBlock::create(module.get(), "", cur_fun); + if (builder->get_insert_block()->get_terminator() == nullptr) + builder->create_br(exprBB); + builder->set_insert_point(exprBB); node.expression->accept(*this); - auto cond = cur_value; - cast_to_i1(cond); - builder->create_cond_br(cond, BODY, END); + auto ret_val = tmp_val; + auto trueBB = BasicBlock::create(module.get(), "", cur_fun); + auto contBB = BasicBlock::create(module.get(), "", cur_fun); + Value *cond_val; + if (ret_val->get_type()->is_integer_type()) + cond_val = builder->create_icmp_ne(ret_val, CONST_INT(0)); + else + cond_val = builder->create_fcmp_ne(ret_val, CONST_FP(0.)); - builder->set_insert_point(BODY); + builder->create_cond_br(cond_val, trueBB, contBB); + builder->set_insert_point(trueBB); node.statement->accept(*this); - builder->create_br(HEAD); - - builder->set_insert_point(END); - scope.exit(); + if (builder->get_insert_block()->get_terminator() == nullptr) + builder->create_br(exprBB); + builder->set_insert_point(contBB); } -// Done -void CminusfBuilder::visit(ASTReturnStmt &node) { +void +CminusfBuilder::visit(ASTReturnStmt &node) { if (node.expression == nullptr) { builder->create_void_ret(); } else { - //!TODO: The given code is incomplete. - // You need to solve other return cases (e.g. return an integer). - // + auto fun_ret_type = cur_fun->get_function_type()->get_return_type(); node.expression->accept(*this); - // type cast - // return type can only be int, float or void - if (not Type::is_eq_type(cur_fun->get_return_type(), cur_value->get_type())) { - if (not cur_value->get_type()->is_integer_type() and not cur_value->get_type()->is_float_type()) - error_exit("unsupported return type"); - if (cur_value->get_type()->is_float_type()) - cur_value = builder->create_fptosi(cur_value, INT32_T); - else if (cur_fun->get_return_type()->is_float_type()) - cur_value = builder->create_sitofp(cur_value, FLOAT_T); + if (fun_ret_type != tmp_val->get_type()) { + if (fun_ret_type->is_integer_type()) + tmp_val = builder->create_fptosi(tmp_val, INT32_T); else - cur_value = builder->create_zext(cur_value, INT32_T); + tmp_val = builder->create_sitofp(tmp_val, FLOAT_T); } - builder->create_ret(cur_value); - LOG_DEBUG << "create ret:\n" << builder->get_module()->print(); + builder->create_ret(tmp_val); } } -// Done -// if LV is marked, return memory addr -// else return value stored inside -void CminusfBuilder::visit(ASTVar &node) { - //!TODO: This function is empty now. - // Add some code here. - // - // First it's pointer type, the pointed elements have 3 cases: - // 1. int or float - // 2. [i32 x n] or [float x n] - // 3. int* - auto memory = scope.find(node.id); - Value *addr; - if (memory == nullptr) - error_exit("variable " + node.id + " not declared"); - LOG_DEBUG << "find entry: " << node.id << " " << memory; - - assert(memory->get_type()->is_pointer_type()); - auto element_type = memory->get_type()->get_pointer_element_type(); - - if (node.expression) { // e.g. int a[10]; // mem is [i32 x 10]* - bool old_LV = LV; - LV = false; - node.expression->accept(*this); - LV = old_LV; - - // subscription type cast - if (not Type::is_eq_type(cur_value->get_type(), INT32_T)) { - if (Type::is_eq_type(cur_value->get_type(), FLOAT_T)) - cur_value = builder->create_fptosi(cur_value, INT32_T); - else - error_exit("bad type for subscription"); - } - - auto cond = builder->create_icmp_lt(cur_value, CONST_INT(0)); - auto except_func = scope.find("neg_idx_except"); - auto TBB = BasicBlock::create(builder->get_module(), "", cur_fun); - auto passBB = BasicBlock::create(builder->get_module(), "", cur_fun); - builder->create_cond_br(cond, TBB, passBB); - - builder->set_insert_point(TBB); - builder->create_call(except_func, {}); - builder->create_br(passBB); - - builder->set_insert_point(passBB); - - // Now the subscription is in cur_value, which is good value - // We should focus on the var type: - // - // assert it's pointer type - if (element_type->is_float_type() or element_type->is_integer_type()) { - // 1. int or float - error_exit("invalid types for array subscript"); - } else if (element_type->is_array_type()) { - // 2. [i32 x n] or [float x n] - // - // addr is actually &memory[0][cur_value] - addr = builder->create_gep(memory, {CONST_INT(0), cur_value}); - } else if (element_type->is_pointer_type()) { - // 3. int* - // what to do: - // - int**->int* - // - seek addr - // - // addr is actually &(*memory)[cur_value] - addr = builder->create_load(memory); - addr = builder->create_gep(addr, {cur_value}); - } - - // The logic for this part is the same as `int|float` without subscription. - // Cause we have subscription to find the particular element(int or float), - // we make `addr` its memory address. - - } else { // e.g. int a; // a is i32* - if (element_type->is_float_type() or element_type->is_integer_type()) { - // 1. int or float - // addr is the element's addr - addr = memory; +void +CminusfBuilder::visit(ASTVar &node) { + auto var = scope.find(node.id); + assert(var != nullptr); + auto is_int = + var->get_type()->get_pointer_element_type()->is_integer_type(); + auto is_float = + var->get_type()->get_pointer_element_type()->is_float_type(); + auto is_ptr = + var->get_type()->get_pointer_element_type()->is_pointer_type(); + bool should_return_lvalue = require_lvalue; + require_lvalue = false; + if (node.expression == nullptr) { + if (should_return_lvalue) { + tmp_val = var; + require_lvalue = false; } else { - if (LV) - error_exit("error: pointer or array type is not assignable"); - // For array* or pointer* type, the right-value behaviour is quite special, - // so treat them apart. - if (element_type->is_array_type()) { - // 2. [i32 x n] or [float x n] - // addr is the first element's address in the array - cur_value = builder->create_gep(memory, {CONST_INT(0), CONST_INT(0)}); - } else if (element_type->is_pointer_type()) { - // 3. int* - // addr is the content in the memory, which is actually pointer type - cur_value = builder->create_load(memory); + if (is_int || is_float || is_ptr) { + tmp_val = builder->create_load(var); + } else { + tmp_val = + builder->create_gep(var, {CONST_INT(0), CONST_INT(0)}); } - return; } - } - - if (LV) { - LOG_INFO << "directly return addr" << node.id; - cur_value = addr; } else { - LOG_INFO << "create load for var: " << node.id; - cur_value = builder->create_load(addr); + node.expression->accept(*this); + auto val = tmp_val; + Value *is_neg; + auto exceptBB = BasicBlock::create(module.get(), "", cur_fun); + auto contBB = BasicBlock::create(module.get(), "", cur_fun); + if (val->get_type()->is_float_type()) + val = builder->create_fptosi(val, INT32_T); + + is_neg = builder->create_icmp_lt(val, CONST_INT(0)); + + builder->create_cond_br(is_neg, exceptBB, contBB); + builder->set_insert_point(exceptBB); + auto neg_idx_except_fun = scope.find("neg_idx_except"); + builder->create_call(static_cast(neg_idx_except_fun), {}); + if (cur_fun->get_return_type()->is_void_type()) + builder->create_void_ret(); + else if (cur_fun->get_return_type()->is_float_type()) + builder->create_ret(CONST_FP(0.)); + else + builder->create_ret(CONST_INT(0)); + + builder->set_insert_point(contBB); + Value *tmp_ptr; + if (is_int || is_float) + tmp_ptr = builder->create_gep(var, {val}); + else if (is_ptr) { + auto array_load = builder->create_load(var); + tmp_ptr = builder->create_gep(array_load, {val}); + } else + tmp_ptr = builder->create_gep(var, {CONST_INT(0), val}); + if (should_return_lvalue) { + tmp_val = tmp_ptr; + require_lvalue = false; + } else { + tmp_val = builder->create_load(tmp_ptr); + } } } -// Done -void CminusfBuilder::visit(ASTAssignExpression &node) { - //!TODO: This function is empty now. - // Add some code here. - // - LV = true; - node.var->accept(*this); - LV = false; - auto addr = cur_value; +void +CminusfBuilder::visit(ASTAssignExpression &node) { node.expression->accept(*this); - - assert(addr->get_type()->get_pointer_element_type() != nullptr); - // type cast: left is a pointer type, pointed to i32 or float - if (not Type::is_eq_type(addr->get_type()->get_pointer_element_type(), cur_value->get_type())) { - if (cur_value->get_type()->is_float_type()) - cur_value = builder->create_fptosi(cur_value, INT32_T); - else if (addr->get_type()->get_pointer_element_type()->is_float_type()) - cur_value = builder->create_sitofp(cur_value, FLOAT_T); - else if (Type::is_eq_type(cur_value->get_type(), INT1_T)) - cur_value = builder->create_zext(cur_value, INT32_T); + auto expr_result = tmp_val; + require_lvalue = true; + node.var->accept(*this); + auto var_addr = tmp_val; + if (var_addr->get_type()->get_pointer_element_type() != + expr_result->get_type()) { + if (expr_result->get_type() == INT32_T) + expr_result = builder->create_sitofp(expr_result, FLOAT_T); else - error_exit("bad type for assignment"); + expr_result = builder->create_fptosi(expr_result, INT32_T); } - // gen code - builder->create_store(cur_value, addr); + builder->create_store(expr_result, var_addr); + tmp_val = expr_result; } -// Done -void CminusfBuilder::visit(ASTSimpleExpression &node) { - //!TODO: This function is empty now. - // Add some code here. - // - if (node.additive_expression_r) { +void +CminusfBuilder::visit(ASTSimpleExpression &node) { + if (node.additive_expression_r == nullptr) { node.additive_expression_l->accept(*this); - auto lvalue = cur_value; + } else { + node.additive_expression_l->accept(*this); + auto l_val = tmp_val; node.additive_expression_r->accept(*this); - auto rvalue = cur_value; - // check type - biop_type_check(lvalue, rvalue, "cmp"); - bool float_cmp = lvalue->get_type()->is_float_type(); + auto r_val = tmp_val; + bool is_int = promote(&*builder, &l_val, &r_val); + Value *cmp; switch (node.op) { - case OP_LE: { - if (float_cmp) - cur_value = builder->create_fcmp_le(lvalue, rvalue); + case OP_LT: + if (is_int) + cmp = builder->create_icmp_lt(l_val, r_val); else - cur_value = builder->create_icmp_le(lvalue, rvalue); + cmp = builder->create_fcmp_lt(l_val, r_val); break; - } - case OP_LT: { - if (float_cmp) - cur_value = builder->create_fcmp_lt(lvalue, rvalue); + case OP_LE: + if (is_int) + cmp = builder->create_icmp_le(l_val, r_val); else - cur_value = builder->create_icmp_lt(lvalue, rvalue); + cmp = builder->create_fcmp_le(l_val, r_val); break; - } - case OP_GT: { - if (float_cmp) - cur_value = builder->create_fcmp_gt(lvalue, rvalue); + case OP_GE: + if (is_int) + cmp = builder->create_icmp_ge(l_val, r_val); else - cur_value = builder->create_icmp_gt(lvalue, rvalue); + cmp = builder->create_fcmp_ge(l_val, r_val); break; - } - case OP_GE: { - if (float_cmp) - cur_value = builder->create_fcmp_ge(lvalue, rvalue); + case OP_GT: + if (is_int) + cmp = builder->create_icmp_gt(l_val, r_val); else - cur_value = builder->create_icmp_ge(lvalue, rvalue); + cmp = builder->create_fcmp_gt(l_val, r_val); break; - } - case OP_EQ: { - if (float_cmp) - cur_value = builder->create_fcmp_eq(lvalue, rvalue); + case OP_EQ: + if (is_int) + cmp = builder->create_icmp_eq(l_val, r_val); else - cur_value = builder->create_icmp_eq(lvalue, rvalue); + cmp = builder->create_fcmp_eq(l_val, r_val); break; - } - case OP_NEQ: { - if (float_cmp) - cur_value = builder->create_fcmp_ne(lvalue, rvalue); + case OP_NEQ: + if (is_int) + cmp = builder->create_icmp_ne(l_val, r_val); else - cur_value = builder->create_icmp_ne(lvalue, rvalue); + cmp = builder->create_fcmp_ne(l_val, r_val); break; - } } - } else - node.additive_expression_l->accept(*this); + + tmp_val = builder->create_zext(cmp, INT32_T); + } } -// Done -void CminusfBuilder::visit(ASTAdditiveExpression &node) { - //!TODO: This function is empty now. - // Add some code here. - // - if (node.additive_expression) { +void +CminusfBuilder::visit(ASTAdditiveExpression &node) { + if (node.additive_expression == nullptr) { + node.term->accept(*this); + } else { node.additive_expression->accept(*this); - auto lvalue = cur_value; + auto l_val = tmp_val; node.term->accept(*this); - auto rvalue = cur_value; - // check type - biop_type_check(lvalue, rvalue, "addop"); - bool float_type = lvalue->get_type()->is_float_type(); - // now left and right is the same type + auto r_val = tmp_val; + bool is_int = promote(&*builder, &l_val, &r_val); switch (node.op) { - case OP_PLUS: { - if (float_type) - cur_value = builder->create_fadd(lvalue, rvalue); + case OP_PLUS: + if (is_int) + tmp_val = builder->create_iadd(l_val, r_val); else - cur_value = builder->create_iadd(lvalue, rvalue); + tmp_val = builder->create_fadd(l_val, r_val); break; - } - case OP_MINUS: { - if (float_type) - cur_value = builder->create_fsub(lvalue, rvalue); + case OP_MINUS: + if (is_int) + tmp_val = builder->create_isub(l_val, r_val); else - cur_value = builder->create_isub(lvalue, rvalue); + tmp_val = builder->create_fsub(l_val, r_val); break; - } } - } else - node.term->accept(*this); + } } -// Done -void CminusfBuilder::visit(ASTTerm &node) { - //!TODO: This function is empty now. - // Add some code here. - if (node.term) { +void +CminusfBuilder::visit(ASTTerm &node) { + if (node.term == nullptr) { + node.factor->accept(*this); + } else { node.term->accept(*this); - auto lvalue = cur_value; + auto l_val = tmp_val; node.factor->accept(*this); - auto rvalue = cur_value; - // check type - biop_type_check(lvalue, rvalue, "mul"); - bool float_type = lvalue->get_type()->is_float_type(); - // now left and right is the same type + auto r_val = tmp_val; + bool is_int = promote(&*builder, &l_val, &r_val); switch (node.op) { - case OP_MUL: { - if (float_type) - cur_value = builder->create_fmul(lvalue, rvalue); + case OP_MUL: + if (is_int) + tmp_val = builder->create_imul(l_val, r_val); else - cur_value = builder->create_imul(lvalue, rvalue); + tmp_val = builder->create_fmul(l_val, r_val); break; - } - case OP_DIV: { - if (float_type) - cur_value = builder->create_fdiv(lvalue, rvalue); + case OP_DIV: + if (is_int) + tmp_val = builder->create_isdiv(l_val, r_val); else - cur_value = builder->create_isdiv(lvalue, rvalue); + tmp_val = builder->create_fdiv(l_val, r_val); break; - } } - } else - node.factor->accept(*this); + } } -// Done -void CminusfBuilder::visit(ASTCall &node) { - //!TODO: This function is empty now. - // Add some code here. - Function *func = static_cast(scope.find(node.id)); +void +CminusfBuilder::visit(ASTCall &node) { + auto fun = static_cast(scope.find(node.id)); std::vector args; - if (func == nullptr) - error_exit("function " + node.id + " not declared"); - if (node.args.size() != func->get_num_of_args()) - error_exit("expect " + std::to_string(func->get_num_of_args()) + " params, but " + - std::to_string(node.args.size()) + " is given"); - // check every argument - for (int i = 0; i != node.args.size(); ++i) { - // ith parameter's type - Type *param_type = func->get_function_type()->get_param_type(i); - node.args[i]->accept(*this); - - // type cast - if (not Type::is_eq_type(param_type, cur_value->get_type())) { - if (param_type->is_pointer_type()) { - // shouldn't need type cast for pointer, logically - if (param_type->get_pointer_element_type()->is_integer_type() or - param_type->get_pointer_element_type()->is_float_type()) - error_exit("BUG HERE: ASTVar return value is not int* or float*"); - else - error_exit("BUG HERE: function param needs weird pointer type"); - - } else if (param_type->is_integer_type() or param_type->is_float_type()) { - // need type cast between int and float - if (not cur_value->get_type()->is_integer_type() and not cur_value->get_type()->is_float_type()) - error_exit("unexpected type cast!"); - - if (param_type->is_float_type()) - cur_value = builder->create_sitofp(cur_value, FLOAT_T); - else if (param_type->is_integer_type()) - if (cur_value->get_type()->is_integer_type()) - cur_value = builder->create_zext(cur_value, INT32_T); - else - cur_value = builder->create_fptosi(cur_value, INT32_T); - else - error_exit("unexpected type cast!"); - - } else - error_exit("unexpected case when casting arguments for function call " + node.id); + auto param_type = fun->get_function_type()->param_begin(); + for (auto &arg : node.args) { + arg->accept(*this); + if (!tmp_val->get_type()->is_pointer_type() && + *param_type != tmp_val->get_type()) { + if (tmp_val->get_type()->is_integer_type()) + tmp_val = builder->create_sitofp(tmp_val, FLOAT_T); + else + tmp_val = builder->create_fptosi(tmp_val, INT32_T); } - - // now cur_value fits the param type - args.push_back(cur_value); + args.push_back(tmp_val); + param_type++; } - cur_value = builder->create_call(func, args); + + tmp_val = builder->create_call(static_cast(fun), args); } diff --git a/src/lightir/Instruction.cpp b/src/lightir/Instruction.cpp index e4cd6cd4c361c9f5f1e322e8274998fe5f2a3fcf..9eb9dcde10662a36a8340a75d68d12a6190d4baa 100644 --- a/src/lightir/Instruction.cpp +++ b/src/lightir/Instruction.cpp @@ -10,7 +10,10 @@ #include #include -Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent) +Instruction::Instruction(Type *ty, + OpID id, + unsigned num_ops, + BasicBlock *parent) : User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(parent) { parent_->add_instruction(this); } @@ -18,56 +21,75 @@ Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent Instruction::Instruction(Type *ty, OpID id, unsigned num_ops) : User(ty, "", num_ops), op_id_(id), num_ops_(num_ops), parent_(nullptr) {} -Function *Instruction::get_function() { return parent_->get_parent(); } +Function * +Instruction::get_function() { + return parent_->get_parent(); +} -Module *Instruction::get_module() { return parent_->get_module(); } +Module * +Instruction::get_module() { + return parent_->get_module(); +} -BinaryInst::BinaryInst(Type *ty, OpID id, Value *v1, Value *v2, BasicBlock *bb) : BaseInst(ty, id, 2, bb) { +BinaryInst::BinaryInst(Type *ty, OpID id, Value *v1, Value *v2, BasicBlock *bb) + : BaseInst(ty, id, 2, bb) { set_operand(0, v1); set_operand(1, v2); // assertValid(); } -void BinaryInst::assertValid() { +void +BinaryInst::assertValid() { assert(get_operand(0)->get_type()->is_integer_type()); assert(get_operand(1)->get_type()->is_integer_type()); - assert(static_cast(get_operand(0)->get_type())->get_num_bits() == - static_cast(get_operand(1)->get_type())->get_num_bits()); + assert( + static_cast(get_operand(0)->get_type()) + ->get_num_bits() == + static_cast(get_operand(1)->get_type())->get_num_bits()); } -BinaryInst *BinaryInst::create_add(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_add(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_int32_type(m), Instruction::add, v1, v2, bb); } -BinaryInst *BinaryInst::create_sub(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_sub(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_int32_type(m), Instruction::sub, v1, v2, bb); } -BinaryInst *BinaryInst::create_mul(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_mul(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_int32_type(m), Instruction::mul, v1, v2, bb); } -BinaryInst *BinaryInst::create_sdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_sdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_int32_type(m), Instruction::sdiv, v1, v2, bb); } -BinaryInst *BinaryInst::create_fadd(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_fadd(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_float_type(m), Instruction::fadd, v1, v2, bb); } -BinaryInst *BinaryInst::create_fsub(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_fsub(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_float_type(m), Instruction::fsub, v1, v2, bb); } -BinaryInst *BinaryInst::create_fmul(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_fmul(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_float_type(m), Instruction::fmul, v1, v2, bb); } -BinaryInst *BinaryInst::create_fdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) { +BinaryInst * +BinaryInst::create_fdiv(Value *v1, Value *v2, BasicBlock *bb, Module *m) { return create(Type::get_float_type(m), Instruction::fdiv, v1, v2, bb); } -std::string BinaryInst::print() { +std::string +BinaryInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -78,7 +100,8 @@ std::string BinaryInst::print() { instr_ir += " "; instr_ir += print_as_op(this->get_operand(0), false); instr_ir += ", "; - if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) { + if (Type::is_eq_type(this->get_operand(0)->get_type(), + this->get_operand(1)->get_type())) { instr_ir += print_as_op(this->get_operand(1), false); } else { instr_ir += print_as_op(this->get_operand(1), true); @@ -93,18 +116,27 @@ CmpInst::CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb) // assertValid(); } -void CmpInst::assertValid() { +void +CmpInst::assertValid() { assert(get_operand(0)->get_type()->is_integer_type()); assert(get_operand(1)->get_type()->is_integer_type()); - assert(static_cast(get_operand(0)->get_type())->get_num_bits() == - static_cast(get_operand(1)->get_type())->get_num_bits()); -} - -CmpInst *CmpInst::create_cmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, Module *m) { + assert( + static_cast(get_operand(0)->get_type()) + ->get_num_bits() == + static_cast(get_operand(1)->get_type())->get_num_bits()); +} + +CmpInst * +CmpInst::create_cmp(CmpOp op, + Value *lhs, + Value *rhs, + BasicBlock *bb, + Module *m) { return create(m->get_int1_type(), op, lhs, rhs, bb); } -std::string CmpInst::print() { +std::string +CmpInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -117,7 +149,8 @@ std::string CmpInst::print() { instr_ir += " "; instr_ir += print_as_op(this->get_operand(0), false); instr_ir += ", "; - if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) { + if (Type::is_eq_type(this->get_operand(0)->get_type(), + this->get_operand(1)->get_type())) { instr_ir += print_as_op(this->get_operand(1), false); } else { instr_ir += print_as_op(this->get_operand(1), true); @@ -132,16 +165,23 @@ FCmpInst::FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb) // assertValid(); } -void FCmpInst::assert_valid() { +void +FCmpInst::assert_valid() { assert(get_operand(0)->get_type()->is_float_type()); assert(get_operand(1)->get_type()->is_float_type()); } -FCmpInst *FCmpInst::create_fcmp(CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb, Module *m) { +FCmpInst * +FCmpInst::create_fcmp(CmpOp op, + Value *lhs, + Value *rhs, + BasicBlock *bb, + Module *m) { return create(m->get_int1_type(), op, lhs, rhs, bb); } -std::string FCmpInst::print() { +std::string +FCmpInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -154,7 +194,8 @@ std::string FCmpInst::print() { instr_ir += " "; instr_ir += print_as_op(this->get_operand(0), false); instr_ir += ","; - if (Type::is_eq_type(this->get_operand(0)->get_type(), this->get_operand(1)->get_type())) { + if (Type::is_eq_type(this->get_operand(0)->get_type(), + this->get_operand(1)->get_type())) { instr_ir += print_as_op(this->get_operand(1), false); } else { instr_ir += print_as_op(this->get_operand(1), true); @@ -163,7 +204,10 @@ std::string FCmpInst::print() { } CallInst::CallInst(Function *func, std::vector args, BasicBlock *bb) - : BaseInst(func->get_return_type(), Instruction::call, args.size() + 1, bb) { + : BaseInst(func->get_return_type(), + Instruction::call, + args.size() + 1, + bb) { assert(func->get_num_of_args() == args.size()); int num_ops = args.size() + 1; set_operand(0, func); @@ -172,13 +216,18 @@ CallInst::CallInst(Function *func, std::vector args, BasicBlock *bb) } } -CallInst *CallInst::create(Function *func, std::vector args, BasicBlock *bb) { +CallInst * +CallInst::create(Function *func, std::vector args, BasicBlock *bb) { return BaseInst::create(func, args, bb); } -FunctionType *CallInst::get_function_type() const { return static_cast(get_operand(0)->get_type()); } +FunctionType * +CallInst::get_function_type() const { + return static_cast(get_operand(0)->get_type()); +} -std::string CallInst::print() { +std::string +CallInst::print() { std::string instr_ir; if (!this->is_void()) { instr_ir += "%"; @@ -190,7 +239,8 @@ std::string CallInst::print() { instr_ir += this->get_function_type()->get_return_type()->print(); instr_ir += " "; - assert(dynamic_cast(this->get_operand(0)) && "Wrong call operand function"); + assert(dynamic_cast(this->get_operand(0)) && + "Wrong call operand function"); instr_ir += print_as_op(this->get_operand(0), false); instr_ir += "("; for (int i = 1; i < this->get_num_operand(); i++) { @@ -204,19 +254,32 @@ std::string CallInst::print() { return instr_ir; } -BranchInst::BranchInst(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BasicBlock *bb) - : BaseInst(Type::get_void_type(if_true->get_module()), Instruction::br, 3, bb) { +BranchInst::BranchInst(Value *cond, + BasicBlock *if_true, + BasicBlock *if_false, + BasicBlock *bb) + : BaseInst(Type::get_void_type(if_true->get_module()), + Instruction::br, + 3, + bb) { set_operand(0, cond); set_operand(1, if_true); set_operand(2, if_false); } BranchInst::BranchInst(BasicBlock *if_true, BasicBlock *bb) - : BaseInst(Type::get_void_type(if_true->get_module()), Instruction::br, 1, bb) { + : BaseInst(Type::get_void_type(if_true->get_module()), + Instruction::br, + 1, + bb) { set_operand(0, if_true); } -BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBlock *if_false, BasicBlock *bb) { +BranchInst * +BranchInst::create_cond_br(Value *cond, + BasicBlock *if_true, + BasicBlock *if_false, + BasicBlock *bb) { if_true->add_pre_basic_block(bb); if_false->add_pre_basic_block(bb); bb->add_succ_basic_block(if_false); @@ -225,16 +288,21 @@ BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBl return create(cond, if_true, if_false, bb); } -BranchInst *BranchInst::create_br(BasicBlock *if_true, BasicBlock *bb) { +BranchInst * +BranchInst::create_br(BasicBlock *if_true, BasicBlock *bb) { if_true->add_pre_basic_block(bb); bb->add_succ_basic_block(if_true); return create(if_true, bb); } -bool BranchInst::is_cond_br() const { return get_num_operand() == 3; } +bool +BranchInst::is_cond_br() const { + return get_num_operand() == 3; +} -std::string BranchInst::print() { +std::string +BranchInst::print() { std::string instr_ir; instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += " "; @@ -250,20 +318,36 @@ std::string BranchInst::print() { } ReturnInst::ReturnInst(Value *val, BasicBlock *bb) - : BaseInst(Type::get_void_type(bb->get_module()), Instruction::ret, 1, bb) { + : BaseInst(Type::get_void_type(bb->get_module()), + Instruction::ret, + 1, + bb) { set_operand(0, val); } ReturnInst::ReturnInst(BasicBlock *bb) - : BaseInst(Type::get_void_type(bb->get_module()), Instruction::ret, 0, bb) {} + : BaseInst(Type::get_void_type(bb->get_module()), + Instruction::ret, + 0, + bb) {} -ReturnInst *ReturnInst::create_ret(Value *val, BasicBlock *bb) { return create(val, bb); } +ReturnInst * +ReturnInst::create_ret(Value *val, BasicBlock *bb) { + return create(val, bb); +} -ReturnInst *ReturnInst::create_void_ret(BasicBlock *bb) { return create(bb); } +ReturnInst * +ReturnInst::create_void_ret(BasicBlock *bb) { + return create(bb); +} -bool ReturnInst::is_void_ret() const { return get_num_operand() == 0; } +bool +ReturnInst::is_void_ret() const { + return get_num_operand() == 0; +} -std::string ReturnInst::print() { +std::string +ReturnInst::print() { std::string instr_ir; instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += " "; @@ -278,7 +362,9 @@ std::string ReturnInst::print() { return instr_ir; } -GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector idxs, BasicBlock *bb) +GetElementPtrInst::GetElementPtrInst(Value *ptr, + std::vector idxs, + BasicBlock *bb) : BaseInst(PointerType::get(get_element_type(ptr, idxs)), Instruction::getelementptr, 1 + idxs.size(), @@ -290,10 +376,12 @@ GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector idxs, Basi element_ty_ = get_element_type(ptr, idxs); } -Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector idxs) { +Type * +GetElementPtrInst::get_element_type(Value *ptr, std::vector idxs) { Type *ty = ptr->get_type()->get_pointer_element_type(); - assert("GetElementPtrInst ptr is wrong type" && - (ty->is_array_type() || ty->is_integer_type() || ty->is_float_type())); + assert( + "GetElementPtrInst ptr is wrong type" && + (ty->is_array_type() || ty->is_integer_type() || ty->is_float_type())); if (ty->is_array_type()) { ArrayType *arr_ty = static_cast(ty); for (int i = 1; i < idxs.size(); i++) { @@ -309,13 +397,20 @@ Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector idxs) return ty; } -Type *GetElementPtrInst::get_element_type() const { return element_ty_; } +Type * +GetElementPtrInst::get_element_type() const { + return element_ty_; +} -GetElementPtrInst *GetElementPtrInst::create_gep(Value *ptr, std::vector idxs, BasicBlock *bb) { +GetElementPtrInst * +GetElementPtrInst::create_gep(Value *ptr, + std::vector idxs, + BasicBlock *bb) { return create(ptr, idxs, bb); } -std::string GetElementPtrInst::print() { +std::string +GetElementPtrInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -323,7 +418,8 @@ std::string GetElementPtrInst::print() { instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += " "; assert(this->get_operand(0)->get_type()->is_pointer_type()); - instr_ir += this->get_operand(0)->get_type()->get_pointer_element_type()->print(); + instr_ir += + this->get_operand(0)->get_type()->get_pointer_element_type()->print(); instr_ir += ", "; for (int i = 0; i < this->get_num_operand(); i++) { if (i > 0) @@ -336,14 +432,21 @@ std::string GetElementPtrInst::print() { } StoreInst::StoreInst(Value *val, Value *ptr, BasicBlock *bb) - : BaseInst(Type::get_void_type(bb->get_module()), Instruction::store, 2, bb) { + : BaseInst(Type::get_void_type(bb->get_module()), + Instruction::store, + 2, + bb) { set_operand(0, val); set_operand(1, ptr); } -StoreInst *StoreInst::create_store(Value *val, Value *ptr, BasicBlock *bb) { return create(val, ptr, bb); } +StoreInst * +StoreInst::create_store(Value *val, Value *ptr, BasicBlock *bb) { + return create(val, ptr, bb); +} -std::string StoreInst::print() { +std::string +StoreInst::print() { std::string instr_ir; instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += " "; @@ -355,19 +458,27 @@ std::string StoreInst::print() { return instr_ir; } -LoadInst::LoadInst(Type *ty, Value *ptr, BasicBlock *bb) : BaseInst(ty, Instruction::load, 1, bb) { +LoadInst::LoadInst(Type *ty, Value *ptr, BasicBlock *bb) + : BaseInst(ty, Instruction::load, 1, bb) { assert(ptr->get_type()->is_pointer_type()); - assert(ty == static_cast(ptr->get_type())->get_element_type()); + assert(ty == + static_cast(ptr->get_type())->get_element_type()); set_operand(0, ptr); } -LoadInst *LoadInst::create_load(Type *ty, Value *ptr, BasicBlock *bb) { return create(ty, ptr, bb); } +LoadInst * +LoadInst::create_load(Type *ty, Value *ptr, BasicBlock *bb) { + return create(ty, ptr, bb); +} -Type *LoadInst::get_load_type() const { - return static_cast(get_operand(0)->get_type())->get_element_type(); +Type * +LoadInst::get_load_type() const { + return static_cast(get_operand(0)->get_type()) + ->get_element_type(); } -std::string LoadInst::print() { +std::string +LoadInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -375,7 +486,8 @@ std::string LoadInst::print() { instr_ir += this->get_module()->get_instr_op_name(this->get_instr_type()); instr_ir += " "; assert(this->get_operand(0)->get_type()->is_pointer_type()); - instr_ir += this->get_operand(0)->get_type()->get_pointer_element_type()->print(); + instr_ir += + this->get_operand(0)->get_type()->get_pointer_element_type()->print(); instr_ir += ","; instr_ir += " "; instr_ir += print_as_op(this->get_operand(0), true); @@ -383,13 +495,21 @@ std::string LoadInst::print() { } AllocaInst::AllocaInst(Type *ty, BasicBlock *bb) - : BaseInst(PointerType::get(ty), Instruction::alloca, 0, bb), alloca_ty_(ty) {} + : BaseInst(PointerType::get(ty), Instruction::alloca, 0, bb) + , alloca_ty_(ty) {} -AllocaInst *AllocaInst::create_alloca(Type *ty, BasicBlock *bb) { return create(ty, bb); } +AllocaInst * +AllocaInst::create_alloca(Type *ty, BasicBlock *bb) { + return create(ty, bb); +} -Type *AllocaInst::get_alloca_type() const { return alloca_ty_; } +Type * +AllocaInst::get_alloca_type() const { + return alloca_ty_; +} -std::string AllocaInst::print() { +std::string +AllocaInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -400,15 +520,23 @@ std::string AllocaInst::print() { return instr_ir; } -ZextInst::ZextInst(OpID op, Value *val, Type *ty, BasicBlock *bb) : BaseInst(ty, op, 1, bb), dest_ty_(ty) { +ZextInst::ZextInst(OpID op, Value *val, Type *ty, BasicBlock *bb) + : BaseInst(ty, op, 1, bb), dest_ty_(ty) { set_operand(0, val); } -ZextInst *ZextInst::create_zext(Value *val, Type *ty, BasicBlock *bb) { return create(Instruction::zext, val, ty, bb); } +ZextInst * +ZextInst::create_zext(Value *val, Type *ty, BasicBlock *bb) { + return create(Instruction::zext, val, ty, bb); +} -Type *ZextInst::get_dest_type() const { return dest_ty_; } +Type * +ZextInst::get_dest_type() const { + return dest_ty_; +} -std::string ZextInst::print() { +std::string +ZextInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -428,13 +556,18 @@ FpToSiInst::FpToSiInst(OpID op, Value *val, Type *ty, BasicBlock *bb) set_operand(0, val); } -FpToSiInst *FpToSiInst::create_fptosi(Value *val, Type *ty, BasicBlock *bb) { +FpToSiInst * +FpToSiInst::create_fptosi(Value *val, Type *ty, BasicBlock *bb) { return create(Instruction::fptosi, val, ty, bb); } -Type *FpToSiInst::get_dest_type() const { return dest_ty_; } +Type * +FpToSiInst::get_dest_type() const { + return dest_ty_; +} -std::string FpToSiInst::print() { +std::string +FpToSiInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -454,13 +587,18 @@ SiToFpInst::SiToFpInst(OpID op, Value *val, Type *ty, BasicBlock *bb) set_operand(0, val); } -SiToFpInst *SiToFpInst::create_sitofp(Value *val, Type *ty, BasicBlock *bb) { +SiToFpInst * +SiToFpInst::create_sitofp(Value *val, Type *ty, BasicBlock *bb) { return create(Instruction::sitofp, val, ty, bb); } -Type *SiToFpInst::get_dest_type() const { return dest_ty_; } +Type * +SiToFpInst::get_dest_type() const { + return dest_ty_; +} -std::string SiToFpInst::print() { +std::string +SiToFpInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -475,7 +613,11 @@ std::string SiToFpInst::print() { return instr_ir; } -PhiInst::PhiInst(OpID op, std::vector vals, std::vector val_bbs, Type *ty, BasicBlock *bb) +PhiInst::PhiInst(OpID op, + std::vector vals, + std::vector val_bbs, + Type *ty, + BasicBlock *bb) : BaseInst(ty, op, 2 * vals.size()) { for (int i = 0; i < vals.size(); i++) { set_operand(2 * i, vals[i]); @@ -484,13 +626,15 @@ PhiInst::PhiInst(OpID op, std::vector vals, std::vector v this->set_parent(bb); } -PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb) { +PhiInst * +PhiInst::create_phi(Type *ty, BasicBlock *bb) { std::vector vals; std::vector val_bbs; return create(Instruction::phi, vals, val_bbs, ty, bb); } -std::string PhiInst::print() { +std::string +PhiInst::print() { std::string instr_ir; instr_ir += "%"; instr_ir += this->get_name(); @@ -508,9 +652,12 @@ std::string PhiInst::print() { instr_ir += print_as_op(this->get_operand(2 * i + 1), false); instr_ir += " ]"; } - if (this->get_num_operand() / 2 < this->get_parent()->get_pre_basic_blocks().size()) { + if (this->get_num_operand() / 2 < + this->get_parent()->get_pre_basic_blocks().size()) { for (auto pre_bb : this->get_parent()->get_pre_basic_blocks()) { - if (std::find(this->get_operands().begin(), this->get_operands().end(), static_cast(pre_bb)) == + if (std::find(this->get_operands().begin(), + this->get_operands().end(), + static_cast(pre_bb)) == this->get_operands().end()) { // find a pre_bb is not in phi instr_ir += ", [ undef, " + print_as_op(pre_bb, false) + " ]"; diff --git a/src/optimization/GVN.cpp b/src/optimization/GVN.cpp index 5f76ab405010d1a57629ed21d7bb60a2f386110b..db437dc140b4feb1293064570958229f68c46307 100644 --- a/src/optimization/GVN.cpp +++ b/src/optimization/GVN.cpp @@ -209,6 +209,17 @@ print_partitions(const GVN::partitions &p) { } } // namespace utils +GVN::partitions +GVN::join_helper(BasicBlock *pre1, BasicBlock *pre2) { + assert(not _TOP[pre1] or not _TOP[pre2] && "should flow here, not jump"); + + if (_TOP[pre1]) + return pout_[pre2]; + else if (_TOP[pre2]) + return pout_[pre1]; + return join(pout_[pre1], pout_[pre2]); +} + GVN::partitions GVN::join(const partitions &P1, const partitions &P2) { // TODO: do intersection pair-wise @@ -228,7 +239,6 @@ std::shared_ptr GVN::intersect(std::shared_ptr ci, std::shared_ptr cj) { // TODO - // If no common members, return null auto c = createCongruenceClass(); std::set intersection; @@ -240,10 +250,21 @@ GVN::intersect(std::shared_ptr ci, c->members_ = intersection; if (ci->index_ == cj->index_) c->index_ = ci->index_; - if (ci->leader_ == cj->leader_) - c->leader_ = cj->leader_; - /* if (*ci == *cj) - * return ci; */ + if (ci->value_expr_ == cj->value_expr_) + c->value_expr_ = ci->value_expr_; + if (ci->value_phi_ and cj->value_phi_ and + *ci->value_phi_ == *cj->value_phi_) + c->value_phi_ = ci->value_phi_; + + // if (c->members_.size() or c->value_expr_ or c->value_phi_) // not empty + // ?? + // What if the ve is nullptr? + if (c->members_.size()) // not empty + if (c->index_ == 0) { + c->index_ = new_number(); + c->value_phi_ = + PhiExpression::create(ci->value_expr_, cj->value_expr_); + } return c; } @@ -251,30 +272,31 @@ GVN::intersect(std::shared_ptr ci, void GVN::detectEquivalences() { bool changed; + std::cout << "all the instruction address:" << std::endl; + for (auto &bb : func_->get_basic_blocks()) { + for (auto &instr : bb.get_instructions()) + std::cout << &instr << "\t" << instr.print() << std::endl; + } // initialize pout with top for (auto &bb : func_->get_basic_blocks()) { - // pin_[&bb].clear(); - // pout_[&bb].clear(); - for (auto &instr : bb.get_instructions()) - _TOP[&instr] = true; + _TOP[&bb] = true; } - // modify entry block auto Entry = func_->get_entry_block(); - _TOP[&*Entry->get_instructions().begin()] = false; + _TOP[Entry] = false; + pin_[Entry].clear(); - pout_[Entry].clear(); // pout_[Entry] = transferFunction(Entry); + pout_[Entry] = transferFunction(Entry); // iterate until converge do { changed = false; - // see the pseudo code in documentation - for (auto &_bb : - func_->get_basic_blocks()) { // you might need to visit the - // blocks in depth-first order + for (auto &_bb : func_->get_basic_blocks()) { auto bb = &_bb; - // get PIN of bb by predecessor(s) + if (bb == Entry) + continue; + // get PIN of bb from predecessor(s) auto pre_bbs_ = bb->get_pre_basic_blocks(); if (bb != Entry) { // only update PIN for blocks that are not Entry @@ -283,12 +305,12 @@ GVN::detectEquivalences() { case 2: { auto pre_1 = *pre_bbs_.begin(); auto pre_2 = *(++pre_bbs_.begin()); - pin_[bb] = join(pin_[pre_1], pin_[pre_2]); + pin_[bb] = join_helper(pre_1, pre_2); break; } case 1: { auto pre = *(pre_bbs_.begin()); - pin_[bb] = clone(pin_[pre]); + pin_[bb] = pout_[pre]; break; } default: @@ -297,82 +319,246 @@ GVN::detectEquivalences() { abort(); } } - auto part = pin_[bb]; - // iterate through all instructions in the block - for (auto &instr : bb->get_instructions()) { - // ?? - if (not instr.is_phi()) - part = transferFunction(&instr, instr.get_operand(1), part); - } - // and the phi instruction in all the successors - for (auto succ : bb->get_succ_basic_blocks()) { - for (auto &instr : succ->get_instructions()) - if (instr.is_phi()) { - Instruction *pretend; - // ?? - part = transferFunction( - pretend, instr.get_operand(1), part); - } - } + auto part = transferFunction(bb); // check changes in pout changed |= not(part == pout_[bb]); pout_[bb] = part; + _TOP[bb] = false; } } while (changed); } shared_ptr -GVN::valueExpr(Instruction *instr) { +GVN::valueExpr(Instruction *instr, partitions *part) { // TODO - return {}; + // ?? should use part? + std::string err{"Undefined"}; + std::cout << instr->print() << std::endl; + + if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) { + auto op1 = instr->get_operand(0); + auto op2 = instr->get_operand(1); + auto op1_const = dynamic_cast(op1); + auto op2_const = dynamic_cast(op2); + if (op1_const and op2_const) { + // both are constant number, so: + // constant fold! + return ConstantExpression::create( + folder_->compute(instr, op1_const, op2_const)); + } else { + // both none constant + auto op1_instr = dynamic_cast(op1); + auto op2_instr = dynamic_cast(op2); + assert((op1_instr or op1_const) and (op2_instr or op2_const) && + "must be this case"); + return BinaryExpression::create( + instr->get_instr_type(), + (op1_const ? ConstantExpression::create(op1_const) + : valueExpr(op1_instr)), + (op2_const ? ConstantExpression::create(op2_const) + : valueExpr(op2_instr))); + } + } else if (instr->is_phi()) { + err = "phi"; + } else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) { + auto op = instr->get_operand(0); + auto op_const = dynamic_cast(op); + auto op_instr = dynamic_cast(op); + assert(op_instr or op_const); + + // get dest type + auto instr_fp2si = dynamic_cast(instr); + auto instr_si2fp = dynamic_cast(instr); + auto instr_zext = dynamic_cast(instr); + Type *dest_type = nullptr; + if (instr_fp2si) + dest_type = instr_fp2si->get_dest_type(); + else if (instr_si2fp) + dest_type = instr_si2fp->get_dest_type(); + else if (instr_zext) + dest_type = instr_zext->get_dest_type(); + else + err = "cast"; + + if (dest_type) { + if (op_const) + return ConstantExpression::create( + folder_->compute(instr, op_const)); + else + return CastExpression::create( + instr->get_instr_type(), valueExpr(op_instr), dest_type); + } + } else if (instr->is_gep()) { + auto operands = instr->get_operands(); + auto ptr = operands[0]; + std::vector> idxs; + // check for base address + assert(not dynamic_cast(ptr) and + dynamic_cast(ptr) && + "base address should only be from instruction"); + // set idxes + for (int i = 1; i < operands.size(); i++) { + if (dynamic_cast(operands[i])) + idxs.push_back(ConstantExpression::create( + dynamic_cast(operands[i]))); + else { + assert(dynamic_cast(operands[i])); + idxs.push_back( + valueExpr(dynamic_cast(operands[i]))); + } + } + return GEPExpression::create( + valueExpr(dynamic_cast(ptr)), idxs); + } else if (instr->is_load() or instr->is_alloca() or instr->is_call()) { + return UniqueExpression::create(instr); + } + + std::cerr << "Undefined case: " << err << std::endl; + abort(); } // instruction of the form `x = e`, mostly x is just e (SSA), // but for copy stmt x is a phi instruction in the successor. // Phi values (not copy stmt) should be handled in detectEquiv +// +// assert the x is an instruction that can generate a new value +// /// \param bb basic block in which the transfer function is called GVN::partitions GVN::transferFunction(Instruction *x, Value *e, partitions pin) { - partitions pout = clone(pin); + partitions pout = pin; + // TODO: deal with copy-stmt case + // ?? deal with copy statement + auto e_instr = dynamic_cast(e); + auto e_const = dynamic_cast(e); + assert((not e or e_instr or e_const) && + "A value must be from an instruction or constant"); + // erase the old record for x + std::set::iterator it; + for (auto c : pin) + if ((it = std::find(c->members_.begin(), c->members_.end(), x)) != + c->members_.end()) { + c->members_.erase(it); + } + // TODO: get different ValueExpr by Instruction::OpID, modify pout - std::set::iterator iter; + // ?? + // get ve and vpf + shared_ptr ve; + if (e) { + if (e_const) + ve = ConstantExpression::create(e_const); + else + ve = valueExpr(e_instr, &pin); + } else + ve = valueExpr(x, &pin); + auto vpf = valuePhiFunc(ve, curr_bb); + for (auto c : pout) { - if ((iter = std::find(c->members_.begin(), c->members_.end(), x)) != - c->members_.end()) { - // static_cast(x))) != c->members_.end()) { - c->members_.erase(iter); + if (ve == c->value_expr_ or (vpf and vpf == c->value_phi_)) { + c->value_expr_ = ve; + c->members_.insert(x); + } else { + auto c = createCongruenceClass(new_number()); + c->members_.insert(x); + c->value_expr_ = ve; + c->value_phi_ = vpf; + pout.insert(c); } } - auto ve = valueExpr(x); - auto vpf = valuePhiFunc(ve, pin); - /* pout.insert({}); - * auto c = CongruenceClass(new_number()); - * c.leader_ = e; */ + /* // first version: ignore ve and vpf + * // and only update index, leader and members + * auto c = createCongruenceClass(new_number()); + * c->leader_ = x; + * c->members_.insert(x); + * pout.insert(c); */ return pout; } +/* + * read the pin for the block and then execute transferFunction() for all + * instructions inside. + */ GVN::partitions GVN::transferFunction(BasicBlock *bb) { - partitions pout = clone(pin_[bb]); - // ?? - return pout; + curr_bb = bb; + int res; + auto part = pin_[bb]; + /* LOG_INFO << "transferFunction(bb=" << bb->get_name() << ")\n"; + * LOG_INFO << "pin:\n"; + * utils::print_partitions(pin_[bb]); + * LOG_INFO << "pout before:\n"; + * utils::print_partitions(pout_[bb]); */ + + // iterate through all instructions in the block + for (auto &instr : bb->get_instructions()) { + // ?? what about orther instructions? Are they all ok? + if (not instr.is_phi() and not instr.is_void()) + part = transferFunction(&instr, nullptr, part); + } + // and the phi instruction in all the successors + for (auto succ : bb->get_succ_basic_blocks()) { + for (auto &instr : succ->get_instructions()) { + if (instr.is_phi()) { + if ((res = pretend_copy_stmt(&instr, bb) == -1)) + continue; + part = transferFunction(&instr, instr.get_operand(res), part); + } + } + } + /* LOG_INFO << "pout after:\n"; + * utils::print_partitions(part); + * std::cout << std::endl; */ + return part; } shared_ptr -GVN::valuePhiFunc(shared_ptr ve, const partitions &P) { +GVN::valuePhiFunc(shared_ptr ve, BasicBlock *bb) { // TODO - return {}; + if (ve->get_expr_type() != Expression::e_bin) + return nullptr; + auto ve_bin = static_cast(ve.get()); + if (ve_bin->lhs_->get_expr_type() != Expression::e_phi or + ve_bin->rhs_->get_expr_type() != Expression::e_phi) + return nullptr; + + // get 2 phi expressions + auto lhs = static_cast(ve_bin->lhs_.get()); + auto rhs = static_cast(ve_bin->rhs_.get()); + // get 2 predecessors + auto pre_bbs_ = bb->get_pre_basic_blocks(); + auto pre_1 = *pre_bbs_.begin(); + auto pre_2 = *(++pre_bbs_.begin()); + + // try to get the merged value expression + auto vl_merge = BinaryExpression::create(ve_bin->op_, lhs->lhs_, rhs->lhs_); + auto vr_merge = BinaryExpression::create(ve_bin->op_, lhs->rhs_, rhs->rhs_); + auto vi = getVN(pout_[pre_1], vl_merge); + auto vj = getVN(pout_[pre_2], vr_merge); + if (vi == nullptr) + vi = valuePhiFunc(vl_merge, pre_1); + if (vj == nullptr) + vj = valuePhiFunc(vr_merge, pre_2); + + if (vi and vj) + return PhiExpression::create(vi, vj); + else + return nullptr; } shared_ptr GVN::getVN(const partitions &pout, shared_ptr ve) { // TODO: return what? + /* for (auto c : pout) { + * if (c->value_expr_ == ve) + * return ve; + * } */ for (auto it = pout.begin(); it != pout.end(); it++) if ((*it)->value_expr_ and *(*it)->value_expr_ == *ve) - return {}; + return ve; return nullptr; } @@ -490,6 +676,12 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) { return equiv_as(lhs, rhs); case Expression::e_phi: return equiv_as(lhs, rhs); + case Expression::e_cast: + return equiv_as(lhs, rhs); + case Expression::e_gep: + return equiv_as(lhs, rhs); + case Expression::e_unique: + return equiv_as(lhs, rhs); } } @@ -517,12 +709,35 @@ operator==(const GVN::partitions &p1, const GVN::partitions &p2) { // cannot direct compare??? if (p1.size() != p2.size()) return false; - return std::equal(p1.begin(), p1.end(), p2.begin(), p2.end()); + auto it1 = p1.begin(); + auto it2 = p2.begin(); + for (; it1 != p1.end(); ++it1, ++it2) + if (not(**it1 == **it2)) + return false; + return true; } -// only compare index +// only compare members bool CongruenceClass::operator==(const CongruenceClass &other) const { // TODO: which fields need to be compared? - return index_ == other.index_; + if (members_.size() != other.members_.size()) + return false; + return members_ == other.members_; +} + +int +GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) { + auto phi = static_cast(instr); + // res = phi [op1, name1], [op2, name2] + // ^0 ^1 ^2 ^3 + if (phi->get_operand(1)->get_name() == bb->get_name()) { + // pretend copy statement: + // `res = op1` + return 0; + } else if (phi->get_operand(3)->get_name() == bb->get_name()) { + // `res = op2` + return 2; + } + return -1; } diff --git a/tests/4-ir-opt/testcases/GVN/functional/.gitignore b/tests/4-ir-opt/testcases/GVN/functional/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fd645655007deecf92e33c8371dd762c0251d69a --- /dev/null +++ b/tests/4-ir-opt/testcases/GVN/functional/.gitignore @@ -0,0 +1 @@ +*.ll