diff --git a/include/optimization/GVN.h b/include/optimization/GVN.h index f4872d33136a450725eb8a3cb9e2d02630336826..8cc26eabb40566bd1963b2d3fe242c712497e027 100644 --- a/include/optimization/GVN.h +++ b/include/optimization/GVN.h @@ -42,13 +42,23 @@ class ConstFolder { class Expression { public: // TODO: you need to extend expression types according to testcases - enum gvn_expr_t { e_constant, e_bin, e_phi, e_cast, e_gep, e_unique }; + enum gvn_expr_t { + e_constant, + e_bin, + e_phi, + e_cast, + e_gep, + e_unique, + e_global, + e_argument, + e_call + }; Expression(gvn_expr_t t) : expr_type(t) {} virtual ~Expression() = default; virtual std::string print() = 0; gvn_expr_t get_expr_type() const { return expr_type; } - private: + protected: gvn_expr_t expr_type; }; @@ -198,7 +208,6 @@ class GEPExpression : public Expression { std::vector> idxs_; }; -// unique expression: not equal to any one else class UniqueExpression : public Expression { public: static std::shared_ptr create(Instruction *instr, @@ -213,13 +222,88 @@ class UniqueExpression : public Expression { return instr_ == other->instr_; } - UniqueExpression(Instruction *instr, size_t index) - : Expression(e_unique), instr_(instr), index_(index) {} + UniqueExpression(Instruction *instr, size_t index, gvn_expr_t tp = e_unique) + : Expression(tp), instr_(instr), index_(index) {} - private: + protected: Instruction *instr_; size_t index_; }; + +// global variables +class GlobalVarExpression : public Expression { + public: + static std::shared_ptr create(GlobalVariable *g) { + return std::make_shared(g); + } + virtual std::string print() { return "<" + g_->get_name() + ">"; } + bool equiv(const GlobalVarExpression *other) const { + return g_ == other->g_; + } + GlobalVarExpression(GlobalVariable *g) : Expression(e_global), g_(g) {} + + private: + GlobalVariable *g_; +}; + +// this is the arguments passed to the current function, but not the arguments +// given to a function call. +class ArgExpression : public Expression { + public: + static std::shared_ptr create(Argument *a) { + return std::make_shared(a); + } + virtual std::string print() { return "[" + a_->get_name() + "]"; } + bool equiv(const ArgExpression *other) const { return a_ == other->a_; } + ArgExpression(Argument *a) : Expression(e_argument), a_(a) {} + + private: + Argument *a_; +}; + +// function call +class CallExpression : public UniqueExpression { + public: + static std::shared_ptr create( + Instruction *instr, + size_t index, + bool pure, + std::vector> &args) { + return std::make_shared(instr, index, pure, args); + } + virtual std::string print() { + std::string ret = "{ " + f_->get_name(); + for (auto arg : args_) + ret += " " + arg->print(); + return ret + " }"; + } + + bool equiv(const CallExpression *other) const { + // single instruction, should be same + if (instr_ == other->instr_) + return true; + if (not(pure_ and f_ == other->f_)) + return false; + for (int i = 0; i < args_.size(); i++) + if (not(*args_[i] == *other->args_[i])) + return false; + return true; + } + + CallExpression(Instruction *instr, + size_t index, + bool pure, + std::vector> &args) + : UniqueExpression(instr, index, e_call) + , pure_(pure) + , f_(static_cast(instr->get_operand(0))) + , args_(args) {} + + private: + bool pure_; + Function *f_; + std::vector> args_; +}; } // namespace GVNExpression /** @@ -304,7 +388,7 @@ class GVN : public Pass { private: bool dump_json_; - std::uint64_t next_value_number_ = 1; + std::uint64_t next_value_number_; Function *func_; std::map pin_, pout_; std::unique_ptr func_info_; @@ -312,30 +396,38 @@ class GVN : public Pass { std::unique_ptr dce_; // self add member - std::map _TOP; - partitions join_helper(BasicBlock *pre1, BasicBlock *pre2); + // std::uint64_t start_number_; BasicBlock *curr_bb; + std::map _TOP; std::map, size_t> start_idx_; + std::map> + global_map_; + std::map> arg_map_; // // self add function // + void add_map_(); std::uint64_t new_number() { return next_value_number_++; } + void deal_with_entry(BasicBlock *Entry); static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb); std::shared_ptr search_ve( Value *v, const partitions &part); - + partitions join_helper(BasicBlock *pre1, BasicBlock *pre2); std::vector> valueExpr_core_( Instruction *instr, const partitions &part, const size_t count, - bool fold_ = true); + bool fold_ = false, + int start_index_ = 0); Constant *constFold_core(const size_t count, const partitions &part, Instruction *instr, const std::vector &operands); void assign_start_idx_(); void dump_tmp(Function &); + void reset_number() { next_value_number_ = 1; } }; bool operator==(const GVN::partitions &p1, const GVN::partitions &p2); diff --git a/src/optimization/GVN.cpp b/src/optimization/GVN.cpp index 10e6eca8e1e22ccc876f843a3f8d3f75ee0284e5..1accee1c77bbfb1986a1d79a6364594cc44c5a40 100644 --- a/src/optimization/GVN.cpp +++ b/src/optimization/GVN.cpp @@ -199,8 +199,6 @@ dump_bb2partition(const std::map &map) { static void print_partitions(const GVN::partitions &p) { if (p.empty()) { - // to del - // std::cout << "empty partitions\n"; LOG_DEBUG << "empty partitions\n"; return; } @@ -208,8 +206,6 @@ print_partitions(const GVN::partitions &p) { for (auto &cc : p) log += print_congruence_class(*cc); LOG_DEBUG << log; // please don't use std::cout - // to del - // std::cout << log; } } // namespace utils @@ -258,7 +254,6 @@ GVN::intersect(shared_ptr ci, shared_ptr cj) { *ci->value_phi_ == *cj->value_phi_) c->value_phi_ = ci->value_phi_; - // ?? // What if the ve is nullptr? if (c->members_.size()) // not empty { @@ -280,7 +275,9 @@ GVN::intersect(shared_ptr ci, shared_ptr cj) { c->index_ = exact_idx; c->value_expr_ = c->value_phi_ = PhiExpression::create(ci->value_expr_, cj->value_expr_); + } else if (ci->value_expr_->get_expr_type() == Expression::e_call) { } + // ?? c->leader_ = *c->members_.begin(); } @@ -288,21 +285,36 @@ GVN::intersect(shared_ptr ci, shared_ptr cj) { return c; } -// assign start index for each instruction, including copy statement -// ther logic here: -// - use the same traver order as main run +// This should be called before dealing with any function. +// It will assign start index for each instruction, including copy statement. +// the logic here: +// - use the same traverse order as main run // - assign an incremental number for each instruction void GVN::assign_start_idx_() { int res; + // this should be the travers order + // + reset_number(); + // first global_map_ + next_value_number_ += global_map_.size(); + // then arguments + next_value_number_ += func_->get_num_of_args(); + // to del + /* auto arg_it = func_->arg_begin(); + * for (; arg_it != func_->arg_end(); ++arg_it) { + * start_idx_[std::make_pair(nullptr, *arg_it)] = new_number(); + * } */ + // then instructions for (auto &bb : func_->get_basic_blocks()) { for (auto &instr : bb.get_instructions()) { if (not instr.is_phi() and not instr.is_void()) + // real instructions start_idx_[std::make_pair(&instr, nullptr)] = new_number(); } - // and the phi instruction in all the successors for (auto succ : bb.get_succ_basic_blocks()) { for (auto &instr : succ->get_instructions()) { + // and the phi instruction in all the successors if (instr.is_phi()) { if ((res = pretend_copy_stmt(&instr, &bb)) == -1) continue; @@ -312,13 +324,36 @@ GVN::assign_start_idx_() { } } } - next_value_number_ = 1; + // this is necessary, cause the initialization for Entry block will use it! + reset_number(); } +void +GVN::deal_with_entry(BasicBlock *Entry) { + // add congruence class for global variables + for (auto item : global_map_) { + auto c = createCongruenceClass(new_number()); + c->leader_ = item.first; + c->members_.insert(item.first); + c->value_expr_ = item.second; + pin_[Entry].insert(c); + } + for (auto arg : func_->get_args()) { + auto c = createCongruenceClass(new_number()); + c->leader_ = arg; + c->members_.insert(arg); + c->value_expr_ = arg_map_[arg]; + pin_[Entry].insert(c); + } + // order matters! + _TOP[Entry] = false; + pout_[Entry] = transferFunction(Entry); +} void GVN::detectEquivalences() { int times = 0; bool changed; + // DEBUG part std::cout << "all the instruction address:" << std::endl; for (auto &bb : func_->get_basic_blocks()) { for (auto &instr : bb.get_instructions()) @@ -331,10 +366,7 @@ GVN::detectEquivalences() { } auto Entry = func_->get_entry_block(); - _TOP[Entry] = false; - - pin_[Entry].clear(); - pout_[Entry] = transferFunction(Entry); + deal_with_entry(Entry); // iterate until converge do { @@ -376,7 +408,8 @@ GVN::detectEquivalences() { /* std::cout << "//-------\n" * << "//after transferFunction(basic-block=" - * << bb->get_name() << "), all pout:" << std::endl; */ + * << bb->get_name() << "), all pout:" << std::endl; + */ // dump_tmp(*func_); } // reset value number @@ -387,6 +420,10 @@ GVN::detectEquivalences() { // or return nullptr shared_ptr GVN::search_ve(Value *v, const GVN::partitions &part) { + if (dynamic_cast(v)) + return global_map_.at(dynamic_cast(v)); + if (dynamic_cast(v)) + return arg_map_.at(dynamic_cast(v)); for (auto c : part) { if (std::find(c->members_.begin(), c->members_.end(), v) != c->members_.end()) { @@ -441,7 +478,8 @@ std::vector> GVN::valueExpr_core_(Instruction *instr, const partitions &part, const size_t count, - bool fold_) { + bool fold_, + int start_idx) { assert(not(fold_ and count > 2)); Value *v; @@ -462,7 +500,7 @@ GVN::valueExpr_core_(Instruction *instr, // - try to find expression that already exists // - take care of constant for (int i = 0; i != count; i++) { - v = operands[i]; + v = operands[i + start_idx]; v_const = dynamic_cast(v); if (v_const) { ret.push_back(ConstantExpression::create(v_const)); @@ -478,7 +516,6 @@ GVN::valueExpr_core_(Instruction *instr, shared_ptr GVN::valueExpr(Instruction *instr, const partitions &part) { // TODO - // ?? should use part? std::string err{"Undefined"}; std::vector> res; @@ -488,7 +525,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) { return tmp; if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) { - res = valueExpr_core_(instr, part, 2); + res = valueExpr_core_(instr, part, 2, true); if (res.size() == 1) // constant fold return res[0]; else @@ -497,7 +534,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) { } else if (instr->is_phi()) { err = "phi"; } else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) { - res = valueExpr_core_(instr, part, 1); + res = valueExpr_core_(instr, part, 1, true); if (res[0]->get_expr_type() == Expression::e_constant) return res[0]; Type *dest_type; @@ -522,11 +559,18 @@ GVN::valueExpr(Instruction *instr, const partitions &part) { instr->get_instr_type(), res[0], dest_type); } } else if (instr->is_gep()) { - res = valueExpr_core_(instr, part, instr->get_operands().size(), false); + res = valueExpr_core_(instr, part, instr->get_operands().size()); auto ptr = res[0]; res.erase(res.begin()); return GEPExpression::create(ptr, res); - } else if (instr->is_load() or instr->is_alloca() or instr->is_call()) { + } else if (instr->is_call()) { + // 0 is the function*, the arguments start from 1 + auto f = static_cast( + static_cast(instr)->get_operand(0)); + res = valueExpr_core_(instr, part, f->get_num_of_args(), false, 1); + return CallExpression::create( + instr, next_value_number_, func_info_->is_pure_function(f), res); + } else if (instr->is_load() or instr->is_alloca()) { auto ret = search_ve(instr, part); if (ret) return ret; @@ -551,7 +595,6 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) { partitions pout = clone(pin); // TODO: deal with copy-stmt case - // ?? deal with copy statement auto e_instr = dynamic_cast(e); auto e_const = dynamic_cast(e); assert((not e or e_instr or e_const) && @@ -618,7 +661,6 @@ GVN::transferFunction(BasicBlock *bb) { // iterate through all instructions in the block for (auto &instr : bb->get_instructions()) { - // ?? what about orther instructions? Are they all ok? if (not instr.is_phi() and not instr.is_void()) part = transferFunction(&instr, nullptr, part); } @@ -745,7 +787,7 @@ void GVN::replace_cc_members() { for (auto &[_bb, part] : pout_) { auto bb = _bb; // workaround: structured bindings can't be captured - // in C++17 + // in C++17 for (auto &cc : part) { if (cc->index_ == 0) continue; @@ -780,6 +822,18 @@ GVN::replace_cc_members() { return; } +// should be called only once and before the detectEquivalences() starts +void +GVN::add_map_() { + // add map for global variables + for (auto &glb : m_->get_global_variable()) + global_map_[&glb] = GlobalVarExpression::create(&glb); + // add map for function parameters + for (auto &f : m_->get_functions()) + for (auto arg : f.get_args()) + arg_map_[arg] = ArgExpression::create(arg); +} + // top-level function, done for you void GVN::run() { @@ -795,6 +849,7 @@ GVN::run() { dce_ = std::make_unique(m_); dce_->run(); // let dce take care of some dead phis with undef + add_map_(); for (auto &f : m_->get_functions()) { if (f.get_basic_blocks().empty()) continue; @@ -823,7 +878,6 @@ GVN::run() { gvn_json << "},"; } replace_cc_members(); // don't delete instructions, just replace - // them } dce_->run(); // let dce do that for us if (dump_json_) @@ -868,6 +922,14 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) { return equiv_as(lhs, rhs); case Expression::e_unique: return equiv_as(lhs, rhs); + case Expression::e_global: + return equiv_as(lhs, rhs); + case Expression::e_argument: + return equiv_as(lhs, rhs); + case Expression::e_call: + return equiv_as(lhs, rhs); + default: + abort(); } } @@ -893,7 +955,6 @@ GVN::clone(const partitions &p) { bool operator==(const GVN::partitions &p1, const GVN::partitions &p2) { // TODO: how to compare partitions? - // cannot direct compare??? if (p1.size() != p2.size()) return false; auto it1 = p1.begin();