#include "GVN.h"

#include "BasicBlock.h"
#include "Constant.h"
#include "DeadCode.h"
#include "FuncInfo.h"
#include "Function.h"
#include "Instruction.h"
#include "logging.hpp"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace GVNExpression;
using std::string_literals::operator""s;
using std::shared_ptr;

static auto get_const_int_value = [](Value *v) { return dynamic_cast<ConstantInt *>(v)->get_value(); };
static auto get_const_fp_value = [](Value *v) { return dynamic_cast<ConstantFP *>(v)->get_value(); };

// Constant Propagation helper, folders are done for you.
//
// Folds a binary / compare instruction whose two operands are constants.
// Returns nullptr when the instruction cannot (or must not) be folded:
// unsupported opcode or predicate, and integer division whose result is
// undefined (x / 0, INT_MIN / -1) -- folding those would be UB in the compiler.
Constant *ConstFolder::compute(Instruction *instr, Constant *value1, Constant *value2) {
    auto op = instr->get_instr_type();
    switch (op) {
    case Instruction::add:
        return ConstantInt::get(get_const_int_value(value1) + get_const_int_value(value2), module_);
    case Instruction::sub:
        return ConstantInt::get(get_const_int_value(value1) - get_const_int_value(value2), module_);
    case Instruction::mul:
        return ConstantInt::get(get_const_int_value(value1) * get_const_int_value(value2), module_);
    case Instruction::sdiv: {
        auto lhs = get_const_int_value(value1);
        auto rhs = get_const_int_value(value2);
        if (rhs == 0 or (rhs == -1 and lhs == std::numeric_limits<decltype(lhs)>::min()))
            return nullptr; // leave the trap/UB to run time
        return ConstantInt::get(lhs / rhs, module_);
    }
    case Instruction::fadd:
        return ConstantFP::get(get_const_fp_value(value1) + get_const_fp_value(value2), module_);
    case Instruction::fsub:
        return ConstantFP::get(get_const_fp_value(value1) - get_const_fp_value(value2), module_);
    case Instruction::fmul:
        return ConstantFP::get(get_const_fp_value(value1) * get_const_fp_value(value2), module_);
    case Instruction::fdiv:
        return ConstantFP::get(get_const_fp_value(value1) / get_const_fp_value(value2), module_);
    case Instruction::cmp:
        switch (dynamic_cast<CmpInst *>(instr)->get_cmp_op()) {
        case CmpInst::EQ:
            return ConstantInt::get(get_const_int_value(value1) == get_const_int_value(value2), module_);
        case CmpInst::NE:
            return ConstantInt::get(get_const_int_value(value1) != get_const_int_value(value2), module_);
        case CmpInst::GT:
            return ConstantInt::get(get_const_int_value(value1) > get_const_int_value(value2), module_);
        case CmpInst::GE:
            return ConstantInt::get(get_const_int_value(value1) >= get_const_int_value(value2), module_);
        case CmpInst::LT:
            return ConstantInt::get(get_const_int_value(value1) < get_const_int_value(value2), module_);
        case CmpInst::LE:
            return ConstantInt::get(get_const_int_value(value1) <= get_const_int_value(value2), module_);
        }
        // unknown predicate: must not fall through into the fcmp case below
        return nullptr;
    case Instruction::fcmp:
        switch (dynamic_cast<FCmpInst *>(instr)->get_cmp_op()) {
        case FCmpInst::EQ:
            return ConstantInt::get(get_const_fp_value(value1) == get_const_fp_value(value2), module_);
        case FCmpInst::NE:
            return ConstantInt::get(get_const_fp_value(value1) != get_const_fp_value(value2), module_);
        case FCmpInst::GT:
            return ConstantInt::get(get_const_fp_value(value1) > get_const_fp_value(value2), module_);
        case FCmpInst::GE:
            return ConstantInt::get(get_const_fp_value(value1) >= get_const_fp_value(value2), module_);
        case FCmpInst::LT:
            return ConstantInt::get(get_const_fp_value(value1) < get_const_fp_value(value2), module_);
        case FCmpInst::LE:
            return ConstantInt::get(get_const_fp_value(value1) <= get_const_fp_value(value2), module_);
        }
        return nullptr;
    default:
        return nullptr;
    }
}

// Folds a unary conversion (sitofp / fptosi / zext) of a constant.
// Returns nullptr for any other opcode.
Constant *ConstFolder::compute(Instruction *instr, Constant *value1) {
    auto op = instr->get_instr_type();
    switch (op) {
    case Instruction::sitofp:
        return ConstantFP::get(static_cast<float>(get_const_int_value(value1)), module_);
    case Instruction::fptosi:
        return ConstantInt::get(static_cast<int>(get_const_fp_value(value1)), module_);
    case Instruction::zext:
        return ConstantInt::get(static_cast<int>(get_const_int_value(value1)), module_);
    default:
        return nullptr;
    }
}

namespace utils {
// Human-readable dump of one congruence class (null fields print as "nullptr").
static std::string print_congruence_class(const CongruenceClass &cc) {
    std::stringstream ss;
    if (cc.index_ == 0) {
        ss << "top class\n";
        return ss.str();
    }
    ss << "\nindex: " << cc.index_ << "\nleader: " << (cc.leader_ ? cc.leader_->print() : "nullptr"s)
       << "\nvalue phi: " << (cc.value_phi_ ? cc.value_phi_->print() : "nullptr"s)
       << "\nvalue expr: " << (cc.value_expr_ ? cc.value_expr_->print() : "nullptr"s) << "\nmembers: {";
    for (auto &member : cc.members_)
        ss << member->print() << "; ";
    ss << "}\n";
    return ss.str();
}

// JSON-ish list of the members of a class: constants by their printed form,
// other values as "%name".
static std::string dump_cc_json(const CongruenceClass &cc) {
    std::string json = "[";
    for (auto member : cc.members_) {
        if (dynamic_cast<Constant *>(member))
            json += member->print() + ", ";
        else
            json += "\"%" + member->get_name() + "\", ";
    }
    json += "]";
    return json;
}

static std::string dump_partition_json(const GVN::partitions &p) {
    std::string json = "[";
    for (auto &cc : p)
        json += dump_cc_json(*cc) + ", ";
    json += "]";
    return json;
}

static std::string dump_bb2partition(const std::map<BasicBlock *, GVN::partitions> &map) {
    std::string json = "{";
    for (const auto &[bb, p] : map)
        json += "\"" + bb->get_name() + "\": " + dump_partition_json(p) + ",";
    json += "}";
    return json;
}

// logging utility for you
static void print_partitions(const GVN::partitions &p) {
    if (p.empty()) {
        LOG_DEBUG << "empty partitions\n";
        return;
    }
    std::string log;
    for (auto &cc : p)
        log += print_congruence_class(*cc);
    LOG_DEBUG << log; // please don't use std::cout
}
} // namespace utils

// The most informative expression a class carries: its value expression if set,
// otherwise its value phi-function (possibly null as well).
static shared_ptr<Expression> class_value(const CongruenceClass &cc) {
    if (cc.value_expr_)
        return cc.value_expr_;
    return cc.value_phi_;
}

// Expression of the class in `part` that has `v` as a member; nullptr when `part`
// is null, no class contains `v`, or that class carries no expression.
static shared_ptr<Expression> class_expr_of(Value *v, const GVN::partitions *part) {
    if (part == nullptr)
        return nullptr;
    for (auto &cc : *part)
        if (cc->members_.count(v))
            return class_value(*cc);
    return nullptr;
}

// Join of two predecessors' POUT where TOP (not yet reached) is the identity.
GVN::partitions GVN::join_helper(BasicBlock *pre1, BasicBlock *pre2) {
    assert((not _TOP[pre1] or not _TOP[pre2]) && "should flow here, not jump");
    if (_TOP[pre1])
        return pout_[pre2];
    if (_TOP[pre2])
        return pout_[pre1];
    return join(pout_[pre1], pout_[pre2]);
}

// Pair-wise intersection of every class of P1 with every class of P2;
// empty intersections are dropped.
GVN::partitions GVN::join(const partitions &P1, const partitions &P2) {
    partitions P{};
    for (auto &c1 : P1)
        for (auto &c2 : P2) {
            auto c = intersect(c1, c2);
            if (not c->members_.empty())
                P.insert(c);
        }
    return P;
}

// Intersection of two classes coming from two different predecessors.
// Members are the common members. Index / value expression / value phi are kept
// when both sides agree. A non-empty class whose sides had different indexes gets
// a fresh number and, when both sides carry an expression, the value phi
// phi(expr(ci), expr(cj)). A leader is always chosen among the members.
std::shared_ptr<CongruenceClass> GVN::intersect(std::shared_ptr<CongruenceClass> ci,
                                                std::shared_ptr<CongruenceClass> cj) {
    auto c = createCongruenceClass();
    std::set_intersection(ci->members_.begin(), ci->members_.end(), cj->members_.begin(), cj->members_.end(),
                          std::inserter(c->members_, c->members_.begin()));

    if (ci->index_ == cj->index_)
        c->index_ = ci->index_;
    if (ci->value_expr_ == cj->value_expr_)
        c->value_expr_ = ci->value_expr_;
    if (ci->value_phi_ and cj->value_phi_ and *ci->value_phi_ == *cj->value_phi_)
        c->value_phi_ = ci->value_phi_;

    if (c->members_.empty())
        return c; // discarded by join

    if (c->index_ == 0) {
        c->index_ = new_number();
        auto lhs = class_value(*ci);
        auto rhs = class_value(*cj);
        // a phi over a missing expression could not be printed or compared safely
        if (lhs and rhs)
            c->value_phi_ = PhiExpression::create(lhs, rhs);
    }

    // keep a predecessor's leader when it survived the intersection
    if (ci->leader_ and c->members_.count(ci->leader_))
        c->leader_ = ci->leader_;
    else if (cj->leader_ and c->members_.count(cj->leader_))
        c->leader_ = cj->leader_;
    else
        c->leader_ = *c->members_.begin();
    return c;
}

// Iterative data-flow over the blocks of func_ until no POUT changes.
// PIN of a block is the join of the POUTs of its reached (non-TOP) predecessors;
// a block none of whose predecessors is reached yet stays TOP.
void GVN::detectEquivalences() {
    {
        std::stringstream ss;
        ss << "all the instruction address:\n";
        for (auto &bb : func_->get_basic_blocks())
            for (auto &instr : bb.get_instructions())
                ss << &instr << "\t" << instr.print() << "\n";
        LOG_DEBUG << ss.str();
    }

    // initialize pout with top
    for (auto &bb : func_->get_basic_blocks())
        _TOP[&bb] = true;
    auto entry = func_->get_entry_block();
    _TOP[entry] = false;
    pin_[entry].clear(); // the PIN of the entry is always empty
    pout_[entry] = transferFunction(entry);

    // iterate until converge
    bool changed = false;
    do {
        changed = false;
        for (auto &bb_ref : func_->get_basic_blocks()) {
            auto bb = &bb_ref;
            if (bb == entry)
                continue;

            // TOP is the identity of join, so unreached predecessors are skipped
            partitions pin;
            bool reached = false;
            for (auto pre : bb->get_pre_basic_blocks()) {
                if (_TOP[pre])
                    continue;
                if (reached)
                    pin = join(pin, pout_[pre]);
                else
                    pin = pout_[pre];
                reached = true;
            }
            if (not reached)
                continue; // unreachable so far: keep the block TOP
            pin_[bb] = std::move(pin);

            auto part = transferFunction(bb);
            // check changes in pout; leaving TOP is a change for the successors too
            if (not(part == pout_[bb]) or _TOP[bb])
                changed = true;
            pout_[bb] = std::move(part);
            _TOP[bb] = false;
        }
    } while (changed);
}

// Value expression of `instr`. Operands that are instructions are described by
// their congruence class in `part` when known (this is how phi operands become
// PhiExpressions for valuePhiFunc), otherwise they are evaluated recursively.
// Aborts on instruction kinds it does not model.
shared_ptr<Expression> GVN::valueExpr(Instruction *instr, partitions *part) {
    std::string err{"Undefined"};
    LOG_DEBUG << "valueExpr: " + instr->print();

    auto operand_expr = [this, part](Value *v) -> shared_ptr<Expression> {
        if (auto c = dynamic_cast<Constant *>(v))
            return ConstantExpression::create(c);
        auto v_instr = dynamic_cast<Instruction *>(v);
        if (v_instr == nullptr) {
            // NOTE(review): function arguments / globals are not modelled yet
            LOG_ERROR << "unsupported operand kind: " + v->get_name();
            abort();
        }
        if (auto known = class_expr_of(v_instr, part))
            return known;
        return valueExpr(v_instr, part);
    };

    if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) {
        auto op1 = instr->get_operand(0);
        auto op2 = instr->get_operand(1);
        auto op1_const = dynamic_cast<Constant *>(op1);
        auto op2_const = dynamic_cast<Constant *>(op2);
        if (op1_const and op2_const) {
            // both are constant numbers: constant fold when legal
            if (auto folded = folder_->compute(instr, op1_const, op2_const))
                return ConstantExpression::create(folded);
            // not foldable (e.g. division by zero): keep it symbolic
        }
        return BinaryExpression::create(instr->get_instr_type(), operand_expr(op1), operand_expr(op2));
    }

    if (instr->is_phi()) {
        // a phi reached as a copy source / operand: reuse what its class knows
        if (auto known = class_expr_of(instr, part))
            return known;
        return UniqueExpression::create(instr);
    }

    if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
        Type *dest_type = nullptr;
        if (auto instr_fp2si = dynamic_cast<FpToSiInst *>(instr))
            dest_type = instr_fp2si->get_dest_type();
        else if (auto instr_si2fp = dynamic_cast<SiToFpInst *>(instr))
            dest_type = instr_si2fp->get_dest_type();
        else if (auto instr_zext = dynamic_cast<ZextInst *>(instr))
            dest_type = instr_zext->get_dest_type();

        if (dest_type) {
            auto op = instr->get_operand(0);
            if (auto op_const = dynamic_cast<Constant *>(op)) {
                if (auto folded = folder_->compute(instr, op_const))
                    return ConstantExpression::create(folded);
            }
            return CastExpression::create(instr->get_instr_type(), operand_expr(op), dest_type);
        }
        err = "cast";
    } else if (instr->is_gep()) {
        const auto &operands = instr->get_operands();
        std::vector<shared_ptr<Expression>> idxs;
        for (size_t i = 1; i < operands.size(); ++i)
            idxs.push_back(operand_expr(operands[i]));
        return GEPExpression::create(operand_expr(operands[0]), idxs);
    } else if (instr->is_load() or instr->is_alloca() or instr->is_call()) {
        // values we cannot reason about are only equal to themselves
        return UniqueExpression::create(instr);
    }

    std::cerr << "Undefined case: " << err << '\n';
    abort();
}

// instruction of the form `x = e`, mostly x is just e (SSA),
// but for copy stmt x is a phi instruction in the successor.
// Phi values (not copy stmt) should be handled in detectEquiv
//
// assert the x is an instruction that can generate a new value
//
// `e == nullptr` means "x's own expression". The classes inside `pin` are
// mutated in place, so the caller must own them (transferFunction(BasicBlock*)
// clones PIN first).
GVN::partitions GVN::transferFunction(Instruction *x, Value *e, partitions pin) {
    auto e_instr = dynamic_cast<Instruction *>(e);
    auto e_const = dynamic_cast<Constant *>(e);
    assert((not e or e_instr or e_const) && "A value must be from an instruction or constant");

    partitions pout = std::move(pin);

    // erase the old record for x (it can flow back in through a loop edge);
    // a class must not keep x as leader once x left it
    for (auto &c : pout) {
        if (c->members_.erase(x) and c->leader_ == x)
            c->leader_ = c->members_.empty() ? nullptr : *c->members_.begin();
    }

    shared_ptr<Expression> ve;
    if (e_const)
        ve = ConstantExpression::create(e_const);
    else if (e_instr)
        ve = valueExpr(e_instr, &pout);
    else
        ve = valueExpr(x, &pout);
    auto vpf = valuePhiFunc(ve, curr_bb);

    // join the first class that has the same value expression or value phi
    for (auto &c : pout) {
        bool same_value = ve == c->value_expr_;
        // dereference: comparing two shared_ptr<PhiExpression> would compare addresses
        bool same_phi = vpf and c->value_phi_ and *vpf == *c->value_phi_;
        if (same_value or same_phi) {
            c->value_expr_ = ve;
            c->members_.insert(x);
            if (c->leader_ == nullptr)
                c->leader_ = x;
            return pout;
        }
    }

    // otherwise x starts a new class of its own
    auto c = createCongruenceClass(new_number());
    c->leader_ = x;
    c->members_.insert(x);
    c->value_expr_ = ve;
    c->value_phi_ = vpf;
    pout.insert(c);
    return pout;
}

/*
 * read the pin for the block and then execute transferFunction() for all
 * instructions inside, then for the copy statements of the successors' phis.
 */
GVN::partitions GVN::transferFunction(BasicBlock *bb) {
    curr_bb = bb;
    // own copies of the classes: pin_[bb] shares them with the predecessors' pout
    auto part = clone(pin_[bb]);

    // phis of bb itself are handled as copy statements in the predecessors;
    // void instructions define no value
    for (auto &instr : bb->get_instructions()) {
        if (not instr.is_phi() and not instr.is_void())
            part = transferFunction(&instr, nullptr, std::move(part));
    }

    // and the phi instruction in all the successors
    for (auto succ : bb->get_succ_basic_blocks()) {
        for (auto &instr : succ->get_instructions()) {
            if (not instr.is_phi())
                continue;
            int res = pretend_copy_stmt(&instr, bb);
            if (res == -1)
                continue;
            part = transferFunction(&instr, instr.get_operand(res), std::move(part));
        }
    }
    return part;
}

// Value phi-function of `ve` in block `bb`: for ve = phi(a1,a2) op phi(b1,b2) it
// looks up (a1 op b1) in the first and (a2 op b2) in the second predecessor
// (recursively building phis when not found). Returns nullptr when ve is not of
// that shape, bb does not have exactly two predecessors, or a side is unknown.
shared_ptr<PhiExpression> GVN::valuePhiFunc(shared_ptr<Expression> ve, BasicBlock *bb) {
    if (not ve or ve->get_expr_type() != Expression::e_bin)
        return nullptr;
    auto ve_bin = static_cast<BinaryExpression *>(ve.get());
    if (not ve_bin->lhs_ or not ve_bin->rhs_ or ve_bin->lhs_->get_expr_type() != Expression::e_phi or
        ve_bin->rhs_->get_expr_type() != Expression::e_phi)
        return nullptr;
    // get 2 phi expressions
    auto lhs = static_cast<PhiExpression *>(ve_bin->lhs_.get());
    auto rhs = static_cast<PhiExpression *>(ve_bin->rhs_.get());

    // get 2 predecessors (same order as used for the join in detectEquivalences)
    const auto &pre_bbs = bb->get_pre_basic_blocks();
    if (pre_bbs.size() != 2)
        return nullptr;
    auto pre_1 = *pre_bbs.begin();
    auto pre_2 = *std::next(pre_bbs.begin());

    // try to get the merged value expression
    auto vl_merge = BinaryExpression::create(ve_bin->op_, lhs->lhs_, rhs->lhs_);
    auto vr_merge = BinaryExpression::create(ve_bin->op_, lhs->rhs_, rhs->rhs_);
    shared_ptr<Expression> vi = getVN(pout_[pre_1], vl_merge);
    shared_ptr<Expression> vj = getVN(pout_[pre_2], vr_merge);
    if (not vi)
        vi = valuePhiFunc(vl_merge, pre_1);
    if (not vj)
        vj = valuePhiFunc(vr_merge, pre_2);

    if (vi and vj)
        return PhiExpression::create(vi, vj);
    return nullptr;
}

// The value expression of the class of `pout` structurally equal to `ve`,
// or nullptr when no class has it.
shared_ptr<Expression> GVN::getVN(const partitions &pout, shared_ptr<Expression> ve) {
    if (not ve)
        return nullptr;
    for (auto &cc : pout)
        if (cc->value_expr_ and *cc->value_expr_ == *ve)
            return cc->value_expr_;
    return nullptr;
}

// Reset per-function analysis state.
void GVN::initPerFunction() {
    next_value_number_ = 1;
    pin_.clear();
    pout_.clear();
    _TOP.clear();
}

// Replace, block by block, the uses of every class member by the class leader.
void GVN::replace_cc_members() {
    for (auto &[_bb, part] : pout_) {
        auto bb = _bb; // workaround: structured bindings can't be captured in C++17
        for (auto &cc : part) {
            if (cc->index_ == 0 or cc->leader_ == nullptr)
                continue;
            // if you are planning to do constant propagation, leaders should be
            // set to constant at some point
            for (auto &member : cc->members_) {
                bool member_is_phi = dynamic_cast<PhiInst *>(member);
                bool value_phi = cc->value_phi_ != nullptr;
                if (member != cc->leader_ and (value_phi or !member_is_phi)) {
                    // only replace the members if users are in the same block
                    // as bb
                    member->replace_use_with_when(cc->leader_, [bb](User *user) {
                        if (auto instr = dynamic_cast<Instruction *>(user)) {
                            auto parent = instr->get_parent();
                            auto &bb_pre = parent->get_pre_basic_blocks();
                            if (instr->is_phi()) // as copy stmt, the phi belongs to this block
                                return std::find(bb_pre.begin(), bb_pre.end(), bb) != bb_pre.end();
                            else
                                return parent == bb;
                        }
                        return false;
                    });
                }
            }
        }
    }
}

// top-level function, done for you
void GVN::run() {
    std::ofstream gvn_json;
    if (dump_json_) {
        gvn_json.open("gvn.json", std::ios::out);
        gvn_json << "[";
    }

    folder_ = std::make_unique<ConstFolder>(m_);
    func_info_ = std::make_unique<FuncInfo>(m_);
    func_info_->run();
    dce_ = std::make_unique<DeadCode>(m_);
    dce_->run(); // let dce take care of some dead phis with undef

    for (auto &f : m_->get_functions()) {
        if (f.get_basic_blocks().empty())
            continue;
        func_ = &f;
        initPerFunction();
        LOG_INFO << "Processing " << f.get_name();
        detectEquivalences();
        LOG_INFO << "===============pin=========================\n";
        for (auto &[bb, part] : pin_) {
            LOG_INFO << "\n===============bb: " << bb->get_name() << "=========================\npartitionIn: ";
            for (auto &cc : part)
                LOG_INFO << utils::print_congruence_class(*cc);
        }
        LOG_INFO << "\n===============pout=========================\n";
        for (auto &[bb, part] : pout_) {
            LOG_INFO << "\n=====bb: " << bb->get_name() << "=====\npartitionOut: ";
            for (auto &cc : part)
                LOG_INFO << utils::print_congruence_class(*cc);
        }
        if (dump_json_) {
            gvn_json << "{\n\"function\": ";
            gvn_json << "\"" << f.get_name() << "\", ";
            gvn_json << "\n\"pout\": " << utils::dump_bb2partition(pout_);
            gvn_json << "},";
        }
        replace_cc_members(); // don't delete instructions, just replace them
    }
    dce_->run(); // let dce do that for us
    if (dump_json_)
        gvn_json << "]";
}

template <typename T>
static bool equiv_as(const Expression &lhs, const Expression &rhs) {
    // we use static_cast because we are very sure that both operands are
    // actually T, not other types.
    return static_cast<const T *>(&lhs)->equiv(static_cast<const T *>(&rhs));
}

// Structural equality of expressions of the same kind.
bool GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
    if (lhs.get_expr_type() != rhs.get_expr_type())
        return false;
    switch (lhs.get_expr_type()) {
    case Expression::e_constant:
        return equiv_as<ConstantExpression>(lhs, rhs);
    case Expression::e_bin:
        return equiv_as<BinaryExpression>(lhs, rhs);
    case Expression::e_phi:
        return equiv_as<PhiExpression>(lhs, rhs);
    case Expression::e_cast:
        return equiv_as<CastExpression>(lhs, rhs);
    case Expression::e_gep:
        return equiv_as<GEPExpression>(lhs, rhs);
    case Expression::e_unique:
        return equiv_as<UniqueExpression>(lhs, rhs);
    }
    return false; // unknown kind: never equal
}

// Two null pointers are equal; a null and a non-null are not; otherwise structural.
bool GVNExpression::operator==(const shared_ptr<Expression> &lhs, const shared_ptr<Expression> &rhs) {
    if (not lhs or not rhs)
        return not lhs and not rhs;
    return *lhs == *rhs;
}

// Deep copy: the returned partitions own fresh CongruenceClass objects.
GVN::partitions GVN::clone(const partitions &p) {
    partitions data;
    for (auto &cc : p)
        data.insert(std::make_shared<CongruenceClass>(*cc));
    return data;
}

// Partitions are equal when they hold the same multiset of member sets; the
// order classes are stored in (by index) is irrelevant.
bool operator==(const GVN::partitions &p1, const GVN::partitions &p2) {
    if (p1.size() != p2.size())
        return false;
    auto member_sets = [](const GVN::partitions &p) {
        std::vector<std::set<Value *>> sets;
        sets.reserve(p.size());
        for (auto &cc : p)
            sets.push_back(cc->members_);
        std::sort(sets.begin(), sets.end());
        return sets;
    };
    return member_sets(p1) == member_sets(p2);
}

// only compare members
bool CongruenceClass::operator==(const CongruenceClass &other) const { return members_ == other.members_; }

// Index of the incoming value of phi `instr` that flows from `bb`, or -1.
// Operands are laid out as [val0, bb0, val1, bb1, ...]:
// res = phi [op1, name1], [op2, name2]
//            ^0   ^1        ^2   ^3
int GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) {
    auto phi = static_cast<PhiInst *>(instr);
    const auto &operands = phi->get_operands();
    for (size_t i = 0; i + 1 < operands.size(); i += 2) {
        // compare block identity, not names
        if (operands[i + 1] == bb)
            return static_cast<int>(i);
    }
    return -1;
}