diff --git a/include/optimization/GVN.h b/include/optimization/GVN.h
index 8cc26eabb40566bd1963b2d3fe242c712497e027..4de1168bb4fe3320781c14e9095bc473c0f1f604 100644
--- a/include/optimization/GVN.h
+++ b/include/optimization/GVN.h
@@ -4,6 +4,7 @@
 #include "DeadCode.h"
 #include "FuncInfo.h"
 #include "Function.h"
+#include "IRprinter.h"
 #include "Instruction.h"
 #include "Module.h"
 #include "PassManager.hpp"
@@ -19,6 +20,9 @@
 #include <utility>
 #include <vector>
 
+// Enables the debug dumps in GVN.cpp that are wrapped in `#ifdef GVN_DEBUG`.
+#define GVN_DEBUG
+
 class GVN;
 
 namespace GVNExpression {
@@ -45,6 +49,7 @@ class Expression {
     enum gvn_expr_t {
         e_constant,
         e_bin,
+        e_cmp,
         e_phi,
         e_cast,
         e_gep,
@@ -111,14 +116,74 @@ class BinaryExpression : public Expression {
 
     BinaryExpression(Instruction::OpID op,
                      std::shared_ptr<Expression> lhs,
-                     std::shared_ptr<Expression> rhs)
-        : Expression(e_bin), op_(op), lhs_(lhs), rhs_(rhs) {}
+                     std::shared_ptr<Expression> rhs,
+                     gvn_expr_t tp = e_bin)
+        : Expression(tp), op_(op), lhs_(lhs), rhs_(rhs) {}
 
-  private:
+  protected:
     Instruction::OpID op_;
     std::shared_ptr<Expression> lhs_, rhs_;
 };
 
+// Value expression for a comparison (`cmp` / `fcmp`), derived from
+// BinaryExpression. Besides the opcode and the two operand expressions it
+// keeps the comparison predicate, so two compares are equivalent only when
+// opcode, predicate and both operands agree (see equiv()).
+class CmpExpression : public BinaryExpression {
+    friend class ::GVN;
+    // Comparison predicate. Only the member selected by op_ is meaningful:
+    // int_cmp_op for Instruction::cmp, float_cmp_op for Instruction::fcmp.
+    typedef union {
+        CmpInst::CmpOp int_cmp_op;
+        FCmpInst::CmpOp float_cmp_op;
+    } cmp_op_i_or_f;
+
+  public:
+    static std::shared_ptr<CmpExpression> create(
+        Instruction::OpID op,
+        cmp_op_i_or_f cmpop,
+        std::shared_ptr<Expression> lhs,
+        std::shared_ptr<Expression> rhs) {
+        return std::make_shared<CmpExpression>(op, cmpop, lhs, rhs);
+    }
+    // Renders the expression as "(<op-name> <predicate> <lhs> <rhs>)".
+    virtual std::string print() {
+        std::string ret{"(" + Instruction::get_instr_op_name(op_) + " "};
+        ret += op_ == Instruction::cmp ? print_cmp_type(cmpop_.int_cmp_op)
+                                       : print_fcmp_type(cmpop_.float_cmp_op);
+        ret += " " + lhs_->print() + " " + rhs_->print() + ")";
+        return ret;
+    }
+
+    // Structural equality: same opcode, same predicate, equal operands.
+    bool equiv(const CmpExpression *other) const {
+        if (not(op_ == other->op_))
+            return false;
+        if (op_ == Instruction::cmp) {
+            if (cmpop_.int_cmp_op != other->cmpop_.int_cmp_op)
+                return false;
+        } else {
+            // op_ == Instruction::fcmp
+            if (cmpop_.float_cmp_op != other->cmpop_.float_cmp_op)
+                return false;
+        }
+        return *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_;
+    }
+
+    // op must be Instruction::cmp or Instruction::fcmp (checked by assert).
+    CmpExpression(Instruction::OpID op,
+                  cmp_op_i_or_f cmpop,
+                  std::shared_ptr<Expression> lhs,
+                  std::shared_ptr<Expression> rhs)
+        : BinaryExpression(op, lhs, rhs, e_cmp), cmpop_(cmpop) {
+        assert((op == Instruction::cmp or op == Instruction::fcmp) &&
+               "Wrong instruction type!");
+    }
+
+  private:
+    cmp_op_i_or_f cmpop_;
+};
+
 class PhiExpression : public Expression {
     friend class ::GVN;
 
diff --git a/src/optimization/GVN.cpp b/src/optimization/GVN.cpp
index 7fcae18d83724489540f05b7b8568b9c0bfc3a95..378a382c7d2696a994374b10e229528a6d4eb272 100644
--- a/src/optimization/GVN.cpp
+++ b/src/optimization/GVN.cpp
@@ -295,6 +295,7 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
 // It will assign start index for each instruction, including copy statement.
 // the logic here:
 // - use the same traverse order as main run
+// - use the same parameters as transferFunction(x, e)
 // - assign an incremental number for each instruction
 void
 GVN::assign_start_idx_() {
@@ -333,6 +334,9 @@ GVN::assign_start_idx_() {
     // this is necessary, cause the initialization for Entry block will use it!
     reset_number();
 }
+
+// Add global variables and arguments into pin, these will flow to the end and
+// pursue there.
 void
 GVN::deal_with_entry(BasicBlock *Entry) {
     // add congruence class for global variables
@@ -360,11 +364,13 @@ GVN::detectEquivalences() {
     int times = 0;
     bool changed;
     // DEBUG part
+#ifdef GVN_DEBUG
     std::cout << "all the instruction address:" << std::endl;
     for (auto &bb : func_->get_basic_blocks()) {
         for (auto &instr : bb.get_instructions())
             std::cout << &instr << "\t" << instr.print() << std::endl;
     }
+#endif
     assign_start_idx_();
     // initialize pout with top
     for (auto &bb : func_->get_basic_blocks()) {
@@ -377,7 +383,9 @@ GVN::detectEquivalences() {
     // iterate until converge
     do {
         changed = false;
+#ifdef GVN_DEBUG
         std::cout << ++times << "th iteration" << std::endl;
+#endif
         for (auto &_bb : func_->get_basic_blocks()) {
             auto bb = &_bb;
             if (bb == Entry)
@@ -399,10 +407,13 @@ GVN::detectEquivalences() {
                     pin_[bb] = clone(pout_[pre]);
                     break;
                 }
-                default:
-                    LOG_ERROR << "block has count of predecessors: "
-                              << pre_bbs_.size();
-                    abort();
+                default: {
+                    std::cerr << "In function " << func_->get_name()
+                              << ", block " << bb->get_name()
+                              << " has count of predecessors: "
+                              << pre_bbs_.size() << std::endl;
+                    continue; // skip this block for the current pass
+                }
                 }
             }
             auto part = transferFunction(bb);
@@ -412,11 +423,12 @@ GVN::detectEquivalences() {
             pout_[bb] = clone(part);
             _TOP[bb] = false;
 
-            /* std::cout << "//-------\n"
-             * << "//after transferFunction(basic-block="
-             * << bb->get_name() << "), all pout:" << std::endl;
-             */
-            // dump_tmp(*func_);
+#ifdef GVN_DEBUG
+            std::cout << "//-------\n"
+                      << "//after transferFunction(basic-block="
+                      << bb->get_name() << "), all pout:" << std::endl;
+            dump_tmp(*func_);
+#endif
         }
         // reset value number
     } while (changed);
@@ -534,9 +546,22 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
         res = valueExpr_core_(instr, part, 2, true);
         if (res.size() == 1) // constant fold
             return res[0];
-        else
-            return BinaryExpression::create(
-                instr->get_instr_type(), res[0], res[1]);
+        else {
+            if (instr->isBinary())
+                return BinaryExpression::create(
+                    instr->get_instr_type(), res[0], res[1]);
+            else {
+                CmpExpression::cmp_op_i_or_f cmpop;
+                if (instr->is_cmp())
+                    cmpop.int_cmp_op =
+                        static_cast<CmpInst *>(instr)->get_cmp_op();
+                else
+                    cmpop.float_cmp_op =
+                        static_cast<FCmpInst *>(instr)->get_cmp_op();
+                return CmpExpression::create(
+                    instr->get_instr_type(), cmpop, res[0], res[1]);
+            }
+        }
     } else if (instr->is_phi()) {
         err = "phi";
     } else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
@@ -601,9 +626,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
     partitions pout = clone(pin);
 
     // TODO: deal with copy-stmt case
+    // if e is not null, then we are handling copy-stmt
+    // but can e be a global? Or can a global variable appear in a phi?
+    // ??
     auto e_instr = dynamic_cast<Instruction *>(e);
     auto e_const = dynamic_cast<Constant *>(e);
-    assert((not e or e_instr or e_const) &&
-           "A value must be from an instruction or constant");
+    // auto e_global = dynamic_cast<GlobalVariable *>(e);
+    auto e_argument = dynamic_cast<Argument *>(e);
+    assert((not e or e_instr or e_const or e_argument) &&
+           "A value must be from an instruction, constant or argument");
     // erase the old record for x
     std::set<Value *>::iterator it;
@@ -622,8 +652,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
     if (e) {
         if (e_const)
             ve = ConstantExpression::create(e_const);
-        else
+        else if (e_instr)
             ve = valueExpr(e_instr, pout);
+        else if (e_argument) {
+            ve = search_ve(e_argument, pout);
+            assert(ve && "argument-expression should be there");
+        } else {
+            abort();
+        }
     } else
         ve = valueExpr(instr, pout);
     auto vpf = valuePhiFunc(ve, curr_bb, instr);
@@ -680,9 +716,11 @@ GVN::transferFunction(BasicBlock *bb) {
             }
         }
     }
+#ifdef GVN_DEBUG
     std::cout << "-------\n"
               << "for basic block " << bb->get_name() << ", pout:" << std::endl;
     utils::print_partitions(part);
+#endif
     return part;
 }
 
@@ -700,9 +738,13 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve,
     auto pre_2 = *(++pre_bbs_.begin());
 
     // check expression form
-    if (ve->get_expr_type() != Expression::e_bin)
+    if (ve->get_expr_type() != Expression::e_bin and
+        ve->get_expr_type() != Expression::e_cmp)
         return nullptr;
     auto ve_bin = static_cast<BinaryExpression *>(ve.get());
+    auto ve_cmp = dynamic_cast<CmpExpression *>(ve.get());
+    assert(ve->get_expr_type() != Expression::e_cmp or ve_cmp);
+
     if (ve_bin->lhs_->get_expr_type() != Expression::e_phi and
         ve_bin->rhs_->get_expr_type() != Expression::e_phi)
         return nullptr;
@@ -718,20 +760,41 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve,
     if (ve_bin->lhs_->get_expr_type() == Expression::e_phi and
         ve_bin->rhs_->get_expr_type() == Expression::e_phi) {
         // try to get the merged value expression
-        vl_merge =
-            BinaryExpression::create(ve_bin->op_, lhs_phi->lhs_, rhs_phi->lhs_);
-        vr_merge =
-            BinaryExpression::create(ve_bin->op_, lhs_phi->rhs_, rhs_phi->rhs_);
+        if (ve_cmp) {
+            vl_merge = CmpExpression::create(
+                ve_bin->op_, ve_cmp->cmpop_, lhs_phi->lhs_, rhs_phi->lhs_);
+            vr_merge = CmpExpression::create(
+                ve_bin->op_, ve_cmp->cmpop_, lhs_phi->rhs_, rhs_phi->rhs_);
+        } else {
+            vl_merge = BinaryExpression::create(
+                ve_bin->op_, lhs_phi->lhs_, rhs_phi->lhs_);
+            vr_merge = BinaryExpression::create(
+                ve_bin->op_, lhs_phi->rhs_, rhs_phi->rhs_);
+        }
     } else if (ve_bin->lhs_->get_expr_type() == Expression::e_phi) {
-        vl_merge =
-            BinaryExpression::create(ve_bin->op_, lhs_phi->lhs_, ve_bin->rhs_);
-        vr_merge =
-            BinaryExpression::create(ve_bin->op_, lhs_phi->rhs_, ve_bin->rhs_);
+        if (ve_cmp) {
+            vl_merge = CmpExpression::create(
+                ve_bin->op_, ve_cmp->cmpop_, lhs_phi->lhs_, ve_bin->rhs_);
+            vr_merge = CmpExpression::create(
+                ve_bin->op_, ve_cmp->cmpop_, lhs_phi->rhs_, ve_bin->rhs_);
+        } else {
+            vl_merge = BinaryExpression::create(
+                ve_bin->op_, lhs_phi->lhs_, ve_bin->rhs_);
+            vr_merge = BinaryExpression::create(
+                ve_bin->op_, lhs_phi->rhs_, ve_bin->rhs_);
+        }
     } else {
-        vl_merge =
-            BinaryExpression::create(ve_bin->op_, ve_bin->lhs_, rhs_phi->lhs_);
-        vr_merge =
-            BinaryExpression::create(ve_bin->op_, ve_bin->lhs_, rhs_phi->rhs_);
+        if (ve_cmp) {
+            vl_merge = CmpExpression::create(
+                ve_bin->op_, ve_cmp->cmpop_, ve_bin->lhs_, rhs_phi->lhs_);
+            vr_merge = CmpExpression::create(
+                ve_bin->op_, ve_cmp->cmpop_, ve_bin->lhs_, rhs_phi->rhs_);
+        } else {
+            vl_merge = BinaryExpression::create(
+                ve_bin->op_, ve_bin->lhs_, rhs_phi->lhs_);
+            vr_merge = BinaryExpression::create(
+                ve_bin->op_, ve_bin->lhs_, rhs_phi->rhs_);
+        }
     }
 
     // constant fold
@@ -920,6 +983,8 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
         return equiv_as<ConstantExpression>(lhs, rhs);
     case Expression::e_bin:
         return equiv_as<BinaryExpression>(lhs, rhs);
+    case Expression::e_cmp:
+        return equiv_as<CmpExpression>(lhs, rhs);
     case Expression::e_phi:
         return equiv_as<PhiExpression>(lhs, rhs);
     case Expression::e_cast:
@@ -986,11 +1051,11 @@ GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) {
     // res = phi [op1, name1], [op2, name2]
     //            ^0   ^1       ^2   ^3
     // if (phi->get_operand(1)->get_name() == bb->get_name()) {
-    if (static_cast<BasicBlock*>(phi->get_operand(1)) == bb) {
+    if (static_cast<BasicBlock *>(phi->get_operand(1)) == bb) {
         // pretend copy statement:
         // `res = op1`
         return 0;
-    } else if (static_cast<BasicBlock*>(phi->get_operand(3)) == bb) {
+    } else if (static_cast<BasicBlock *>(phi->get_operand(3)) == bb) {
         // `res = op2`
         return 2;
     }
diff --git a/tests/3-ir-gen/eval.py b/tests/3-ir-gen/eval.py
index ca9c246132233fa7b69c855a147cd534075d71b9..b2c92eff52010ea4923ffd3521d427954d346da1 100755
--- a/tests/3-ir-gen/eval.py
+++ b/tests/3-ir-gen/eval.py
@@ -132,7 +132,7 @@ def eval():
         COMMAND = [TEST_PATH]
 
         try:
-            result = subprocess.run([EXE_PATH, TEST_PATH + ".cminus"], stderr=subprocess.PIPE, timeout=1)
+            result = subprocess.run([EXE_PATH, "-mem2reg", "-gvn", TEST_PATH + ".cminus"], stderr=subprocess.PIPE, timeout=1)
         except Exception as _:
             f.write('\tFail\n')
             continue