#pragma once #include "BasicBlock.h" #include "Constant.h" #include "DeadCode.h" #include "FuncInfo.h" #include "Function.h" #include "IRprinter.h" #include "Instruction.h" #include "Module.h" #include "PassManager.hpp" #include "Value.h" #include #include #include #include #include #include #include #include #include // #define __DEBUG__ class GVN; namespace GVNExpression { // fold the constant value class ConstFolder { public: ConstFolder(Module *m) : module_(m) {} Constant *compute(Instruction *instr, Constant *value1, Constant *value2); Constant *compute(Instruction *instr, Constant *value1); private: Module *module_; }; /** * for constructor of class derived from `Expression`, we make it public * because `std::make_shared` needs the constructor to be publicly available, * but you should call the static factory method `create` instead the * constructor itself to get the desired data */ class Expression { public: // TODO: you need to extend expression types according to testcases enum gvn_expr_t { e_constant, e_bin, e_cmp, e_phi, e_cast, e_gep, e_unique, e_global, e_argument, e_call }; Expression(gvn_expr_t t) : expr_type(t) {} virtual ~Expression() = default; virtual std::string print() = 0; gvn_expr_t get_expr_type() const { return expr_type; } protected: gvn_expr_t expr_type; }; bool operator==(const std::shared_ptr &lhs, const std::shared_ptr &rhs); bool operator==(const GVNExpression::Expression &lhs, const GVNExpression::Expression &rhs); class ConstantExpression : public Expression { public: static std::shared_ptr create(Constant *c) { return std::make_shared(c); } virtual std::string print() { return c_->print(); } // we leverage the fact that constants in lightIR have unique addresses bool equiv(const ConstantExpression *other) const { return c_ == other->c_; } ConstantExpression(Constant *c) : Expression(e_constant), c_(c) {} Constant *get_val() { return c_; } private: Constant *c_; }; // arithmetic expression class BinaryExpression : public Expression { friend class ::GVN; public: static std::shared_ptr create( Instruction::OpID op, std::shared_ptr lhs, std::shared_ptr rhs) { return std::make_shared(op, lhs, rhs); } virtual std::string print() { return "(" + Instruction::get_instr_op_name(op_) + " " + lhs_->print() + " " + rhs_->print() + ")"; } bool equiv(const BinaryExpression *other) const { if (op_ == other->op_ and *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_) return true; else return false; } BinaryExpression(Instruction::OpID op, std::shared_ptr lhs, std::shared_ptr rhs, gvn_expr_t tp = e_bin) : Expression(tp), op_(op), lhs_(lhs), rhs_(rhs) {} protected: Instruction::OpID op_; std::shared_ptr lhs_, rhs_; }; // cmp expression, inherited from binary-expression class CmpExpression : public BinaryExpression { friend class ::GVN; typedef union { CmpInst::CmpOp int_cmp_op; FCmpInst::CmpOp float_cmp_op; } cmp_op_i_or_f; public: static std::shared_ptr create( Instruction::OpID op, cmp_op_i_or_f cmpop, std::shared_ptr lhs, std::shared_ptr rhs) { return std::make_shared(op, cmpop, lhs, rhs); } virtual std::string print() { std::string ret{"(" + Instruction::get_instr_op_name(op_) + " "}; ret += op_ == Instruction::cmp ? print_cmp_type(cmpop_.int_cmp_op) : print_fcmp_type(cmpop_.float_cmp_op); ret += " " + lhs_->print() + " " + rhs_->print() + ")"; return ret; } bool equiv(const CmpExpression *other) const { if (not(op_ == other->op_)) return false; if (op_ == Instruction::cmp) { if (cmpop_.int_cmp_op != other->cmpop_.int_cmp_op) return false; } else { // op_ == Instruction::fcmp if (cmpop_.float_cmp_op != other->cmpop_.float_cmp_op) return false; } return *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_; } CmpExpression(Instruction::OpID op, cmp_op_i_or_f cmpop, std::shared_ptr lhs, std::shared_ptr rhs) : BinaryExpression(op, lhs, rhs, e_cmp), cmpop_(cmpop) { assert(op == Instruction::cmp or op == Instruction::fcmp && "Wrong instruction type!"); } private: cmp_op_i_or_f cmpop_; }; class PhiExpression : public Expression { friend class ::GVN; public: static std::shared_ptr create( std::shared_ptr lhs, std::shared_ptr rhs) { return std::make_shared(lhs, rhs); } virtual std::string print() { return "(phi " + lhs_->print() + " " + rhs_->print() + ")"; } bool equiv(const PhiExpression *other) const { if (*lhs_ == *other->lhs_ and *rhs_ == *other->rhs_) return true; else return false; } PhiExpression(std::shared_ptr lhs, std::shared_ptr rhs) : Expression(e_phi), lhs_(lhs), rhs_(rhs) {} private: std::shared_ptr lhs_, rhs_; }; // type cast expression class CastExpression : public Expression { public: static std::shared_ptr create( Instruction::OpID op, std::shared_ptr src, Type *dest_type) { return std::make_shared(op, src, dest_type); } virtual std::string print() { return "(" + dest_ty_->print() + " " + Instruction::get_instr_op_name(op_) + " " + src_->print() + ")"; } bool equiv(const CastExpression *other) const { return op_ == other->op_ and src_ == other->src_ and dest_ty_ == other->dest_ty_; } CastExpression(Instruction::OpID op, std::shared_ptr src, Type *dest_type) : Expression(e_cast), op_(op), src_(src), dest_ty_(dest_type) {} private: Instruction::OpID op_; std::shared_ptr src_; Type *dest_ty_; }; // type cast expression class GEPExpression : public Expression { public: static std::shared_ptr create( std::shared_ptr ptr, std::vector> &idxs) { return std::make_shared(ptr, idxs); } virtual std::string print() { std::string ret = "(GEP " + ptr_->print(); for (auto idx : idxs_) ret += " " + idx->print(); return ret + ")"; } bool equiv(const GEPExpression *other) const { if (idxs_.size() != other->idxs_.size()) return false; for (int i = 0; i != idxs_.size(); ++i) if (not(idxs_[i] == other->idxs_[i])) return false; return ptr_ == other->ptr_; } GEPExpression(std::shared_ptr ptr, std::vector> &idxs) : Expression(e_gep), ptr_(ptr), idxs_(idxs) {} private: std::shared_ptr ptr_; std::vector> idxs_; }; class UniqueExpression : public Expression { public: static std::shared_ptr create(Instruction *instr, size_t index) { return std::make_shared(instr, index); } // virtual std::string print() { return "(UNIQUE " + instr_->print() + ")"; // } virtual std::string print() { return "v" + std::to_string(index_); } bool equiv(const UniqueExpression *other) const { return instr_ == other->instr_; } UniqueExpression(Instruction *instr, size_t index, gvn_expr_t tp = e_unique) : Expression(tp), instr_(instr), index_(index) {} protected: Instruction *instr_; size_t index_; }; // global variables class GlobalVarExpression : public Expression { public: static std::shared_ptr create(GlobalVariable *g) { return std::make_shared(g); } virtual std::string print() { return "<" + g_->get_name() + ">"; } bool equiv(const GlobalVarExpression *other) const { return g_ == other->g_; } GlobalVarExpression(GlobalVariable *g) : Expression(e_global), g_(g) {} private: GlobalVariable *g_; }; // this is the arguments passed to the current function, but not the arguments // given to a function call. class ArgExpression : public Expression { public: static std::shared_ptr create(Argument *a) { return std::make_shared(a); } virtual std::string print() { return "[" + a_->get_name() + "]"; } bool equiv(const ArgExpression *other) const { return a_ == other->a_; } ArgExpression(Argument *a) : Expression(e_argument), a_(a) {} private: Argument *a_; }; // function call class CallExpression : public UniqueExpression { public: static std::shared_ptr create( Instruction *instr, size_t index, bool pure, std::vector> &args) { return std::make_shared(instr, index, pure, args); } virtual std::string print() { std::string ret = "{ " + f_->get_name(); for (auto arg : args_) ret += " " + arg->print(); return ret + " }"; } bool equiv(const CallExpression *other) const { // single instruction, should be same if (instr_ == other->instr_) return true; if (not(pure_ and f_ == other->f_)) return false; for (int i = 0; i < args_.size(); i++) if (not(*args_[i] == *other->args_[i])) return false; return true; } CallExpression(Instruction *instr, size_t index, bool pure, std::vector> &args) : UniqueExpression(instr, index, e_call) , pure_(pure) , f_(static_cast(instr->get_operand(0))) , args_(args) {} private: bool pure_; Function *f_; std::vector> args_; }; } // namespace GVNExpression /** * Congruence class in each partitions * note: for constant propagation, you might need to add other fields * and for load/store redundancy detection, you most certainly need to modify * the class */ struct CongruenceClass { size_t index_; // representative of the congruence class, used to replace all the members // (except itself) when analysis is done Value *leader_; // value expression in congruence class std::shared_ptr value_expr_; // value φ-function is an annotation of the congruence class std::shared_ptr value_phi_; // equivalent variables in one congruence class std::set members_; CongruenceClass(size_t index) : index_(index), leader_{}, value_expr_{}, value_phi_{}, members_{} {} bool operator<(const CongruenceClass &other) const { return this->index_ < other.index_; } bool operator==(const CongruenceClass &other) const; }; namespace std { template <> // overload std::less for std::shared_ptr, i.e. how to sort the // congruence classes struct less> { bool operator()(const std::shared_ptr &a, const std::shared_ptr &b) const { // nullptrs should never appear in partitions, so we just dereference it return *a < *b; } }; } // namespace std class GVN : public Pass { public: using partitions = std::set>; GVN(Module *m, bool dump_json) : Pass(m), dump_json_(dump_json) {} // pass start void run() override; // init for pass metadata; void initPerFunction(); // fill the following functions according to Pseudocode, **you might need to // add more arguments** void detectEquivalences(); partitions join(const partitions &P1, const partitions &P2); std::shared_ptr intersect( std::shared_ptr, std::shared_ptr); partitions transferFunction(Instruction *x, Value *e, partitions pin); partitions transferFunction(BasicBlock *bb); std::shared_ptr valuePhiFunc( std::shared_ptr, BasicBlock *bb, Instruction *instr); std::shared_ptr valueExpr( Instruction *instr, const partitions &part); std::shared_ptr getVN( const partitions &pout, std::shared_ptr ve); // replace cc members with leader void replace_cc_members(); // note: be careful when to use copy constructor or clone partitions clone(const partitions &p); // create congruence class helper std::shared_ptr createCongruenceClass(size_t index = 0) { return std::make_shared(index); } private: bool dump_json_; std::uint64_t next_value_number_; Function *func_; std::map pin_, pout_; std::unique_ptr func_info_; std::unique_ptr folder_; std::unique_ptr dce_; // self add member // std::uint64_t start_number_; BasicBlock *curr_bb; std::map _TOP; std::map, size_t> start_idx_; std::map> global_map_; std::map> arg_map_; // // self add function // void add_map_(); std::uint64_t new_number() { return next_value_number_++; } void deal_with_entry(BasicBlock *Entry); static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb); std::shared_ptr search_ve( Value *v, const partitions &part); partitions join_helper(BasicBlock *pre1, BasicBlock *pre2); std::vector> valueExpr_core_( Instruction *instr, const partitions &part, const size_t count, bool fold_ = false, int start_index_ = 0); Constant *constFold_core(const size_t count, const partitions &part, Instruction *instr, const std::vector &operands); void assign_start_idx_(); void dump_tmp(Function &); void reset_number() { next_value_number_ = 1; } }; bool operator==(const GVN::partitions &p1, const GVN::partitions &p2);