Commit a116d0f5 authored by 李晓奇's avatar 李晓奇

fix the cmp expression bug, modify lab3 script

parent bb6e3f98
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "DeadCode.h" #include "DeadCode.h"
#include "FuncInfo.h" #include "FuncInfo.h"
#include "Function.h" #include "Function.h"
#include "IRprinter.h"
#include "Instruction.h" #include "Instruction.h"
#include "Module.h" #include "Module.h"
#include "PassManager.hpp" #include "PassManager.hpp"
...@@ -19,6 +20,8 @@ ...@@ -19,6 +20,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#define __DEBUG__
class GVN; class GVN;
namespace GVNExpression { namespace GVNExpression {
...@@ -45,6 +48,7 @@ class Expression { ...@@ -45,6 +48,7 @@ class Expression {
enum gvn_expr_t { enum gvn_expr_t {
e_constant, e_constant,
e_bin, e_bin,
e_cmp,
e_phi, e_phi,
e_cast, e_cast,
e_gep, e_gep,
...@@ -111,14 +115,66 @@ class BinaryExpression : public Expression { ...@@ -111,14 +115,66 @@ class BinaryExpression : public Expression {
BinaryExpression(Instruction::OpID op, BinaryExpression(Instruction::OpID op,
std::shared_ptr<Expression> lhs, std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs) std::shared_ptr<Expression> rhs,
: Expression(e_bin), op_(op), lhs_(lhs), rhs_(rhs) {} gvn_expr_t tp = e_bin)
: Expression(tp), op_(op), lhs_(lhs), rhs_(rhs) {}
private: protected:
Instruction::OpID op_; Instruction::OpID op_;
std::shared_ptr<Expression> lhs_, rhs_; std::shared_ptr<Expression> lhs_, rhs_;
}; };
// cmp expression, inherited from binary-expression
class CmpExpression : public BinaryExpression {
friend class ::GVN;
typedef union {
CmpInst::CmpOp int_cmp_op;
FCmpInst::CmpOp float_cmp_op;
} cmp_op_i_or_f;
public:
static std::shared_ptr<CmpExpression> create(
Instruction::OpID op,
cmp_op_i_or_f cmpop,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs) {
return std::make_shared<CmpExpression>(op, cmpop, lhs, rhs);
}
virtual std::string print() {
std::string ret{"(" + Instruction::get_instr_op_name(op_) + " "};
ret += op_ == Instruction::cmp ? print_cmp_type(cmpop_.int_cmp_op)
: print_fcmp_type(cmpop_.float_cmp_op);
ret += " " + lhs_->print() + " " + rhs_->print() + ")";
return ret;
}
bool equiv(const CmpExpression *other) const {
if (not(op_ == other->op_))
return false;
if (op_ == Instruction::cmp) {
if (cmpop_.int_cmp_op != other->cmpop_.int_cmp_op)
return false;
} else {
// op_ == Instruction::fcmp
if (cmpop_.float_cmp_op != other->cmpop_.float_cmp_op)
return false;
}
return *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_;
}
CmpExpression(Instruction::OpID op,
cmp_op_i_or_f cmpop,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs)
: BinaryExpression(op, lhs, rhs, e_cmp), cmpop_(cmpop) {
assert(op == Instruction::cmp or
op == Instruction::fcmp && "Wrong instruction type!");
}
private:
cmp_op_i_or_f cmpop_;
};
class PhiExpression : public Expression { class PhiExpression : public Expression {
friend class ::GVN; friend class ::GVN;
......
...@@ -295,6 +295,7 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) { ...@@ -295,6 +295,7 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
// It will assign start index for each instruction, including copy statement. // It will assign start index for each instruction, including copy statement.
// the logic here: // the logic here:
// - use the same traverse order as main run // - use the same traverse order as main run
// - use the same parameters as transferFunction(x, e)
// - assign an incremental number for each instruction // - assign an incremental number for each instruction
void void
GVN::assign_start_idx_() { GVN::assign_start_idx_() {
...@@ -333,6 +334,9 @@ GVN::assign_start_idx_() { ...@@ -333,6 +334,9 @@ GVN::assign_start_idx_() {
// this is necessary, cause the initialization for Entry block will use it! // this is necessary, cause the initialization for Entry block will use it!
reset_number(); reset_number();
} }
// Add global variables and arguments into pin, these will flow to the end and
// pursue there.
void void
GVN::deal_with_entry(BasicBlock *Entry) { GVN::deal_with_entry(BasicBlock *Entry) {
// add congruence class for global variables // add congruence class for global variables
...@@ -360,11 +364,13 @@ GVN::detectEquivalences() { ...@@ -360,11 +364,13 @@ GVN::detectEquivalences() {
int times = 0; int times = 0;
bool changed; bool changed;
// DEBUG part // DEBUG part
#ifdef __DEBUG__
std::cout << "all the instruction address:" << std::endl; std::cout << "all the instruction address:" << std::endl;
for (auto &bb : func_->get_basic_blocks()) { for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions()) for (auto &instr : bb.get_instructions())
std::cout << &instr << "\t" << instr.print() << std::endl; std::cout << &instr << "\t" << instr.print() << std::endl;
} }
#endif
assign_start_idx_(); assign_start_idx_();
// initialize pout with top // initialize pout with top
for (auto &bb : func_->get_basic_blocks()) { for (auto &bb : func_->get_basic_blocks()) {
...@@ -377,7 +383,9 @@ GVN::detectEquivalences() { ...@@ -377,7 +383,9 @@ GVN::detectEquivalences() {
// iterate until converge // iterate until converge
do { do {
changed = false; changed = false;
#ifdef __DEBUG__
std::cout << ++times << "th iteration" << std::endl; std::cout << ++times << "th iteration" << std::endl;
#endif
for (auto &_bb : func_->get_basic_blocks()) { for (auto &_bb : func_->get_basic_blocks()) {
auto bb = &_bb; auto bb = &_bb;
if (bb == Entry) if (bb == Entry)
...@@ -399,10 +407,13 @@ GVN::detectEquivalences() { ...@@ -399,10 +407,13 @@ GVN::detectEquivalences() {
pin_[bb] = clone(pout_[pre]); pin_[bb] = clone(pout_[pre]);
break; break;
} }
default: default: {
LOG_ERROR << "block has count of predecessors: " std::cerr << "In function " << func_->get_name() << ", "
<< " block " << bb->get_name()
<< " has count of predecessors: "
<< pre_bbs_.size(); << pre_bbs_.size();
abort(); continue;
}
} }
} }
auto part = transferFunction(bb); auto part = transferFunction(bb);
...@@ -412,11 +423,12 @@ GVN::detectEquivalences() { ...@@ -412,11 +423,12 @@ GVN::detectEquivalences() {
pout_[bb] = clone(part); pout_[bb] = clone(part);
_TOP[bb] = false; _TOP[bb] = false;
/* std::cout << "//-------\n" #ifdef __DEBUG__
* << "//after transferFunction(basic-block=" std::cout << "//-------\n"
* << bb->get_name() << "), all pout:" << std::endl; << "//after transferFunction(basic-block="
*/ << bb->get_name() << "), all pout:" << std::endl;
// dump_tmp(*func_); dump_tmp(*func_);
#endif
} }
// reset value number // reset value number
} while (changed); } while (changed);
...@@ -534,9 +546,22 @@ GVN::valueExpr(Instruction *instr, const partitions &part) { ...@@ -534,9 +546,22 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
res = valueExpr_core_(instr, part, 2, true); res = valueExpr_core_(instr, part, 2, true);
if (res.size() == 1) // constant fold if (res.size() == 1) // constant fold
return res[0]; return res[0];
else else {
return BinaryExpression::create( if (instr->isBinary())
instr->get_instr_type(), res[0], res[1]); return BinaryExpression::create(
instr->get_instr_type(), res[0], res[1]);
else {
CmpExpression::cmp_op_i_or_f cmpop;
if (instr->is_cmp())
cmpop.int_cmp_op =
static_cast<CmpInst *>(instr)->get_cmp_op();
else
cmpop.float_cmp_op =
static_cast<FCmpInst *>(instr)->get_cmp_op();
return CmpExpression::create(
instr->get_instr_type(), cmpop, res[0], res[1]);
}
}
} else if (instr->is_phi()) { } else if (instr->is_phi()) {
err = "phi"; err = "phi";
} else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) { } else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
...@@ -601,9 +626,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) { ...@@ -601,9 +626,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
partitions pout = clone(pin); partitions pout = clone(pin);
// TODO: deal with copy-stmt case // TODO: deal with copy-stmt case
// if e is not null, then we are handling copy-stmt
// but will e is the global? Or will global variable be in phi?
// ??
auto e_instr = dynamic_cast<Instruction *>(e); auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e); auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) && // auto e_global = dynamic_cast<GlobalVariable *>(e);
auto e_argument = dynamic_cast<Argument *>(e);
assert((not e or e_instr or e_const or e_argument) &&
"A value must be from an instruction or constant"); "A value must be from an instruction or constant");
// erase the old record for x // erase the old record for x
std::set<Value *>::iterator it; std::set<Value *>::iterator it;
...@@ -622,8 +652,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) { ...@@ -622,8 +652,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
if (e) { if (e) {
if (e_const) if (e_const)
ve = ConstantExpression::create(e_const); ve = ConstantExpression::create(e_const);
else else if (e_instr)
ve = valueExpr(e_instr, pout); ve = valueExpr(e_instr, pout);
else if (e_argument) {
ve = search_ve(e_argument, pout);
assert(ve && "argument-expression should be there");
} else {
abort();
}
} else } else
ve = valueExpr(instr, pout); ve = valueExpr(instr, pout);
auto vpf = valuePhiFunc(ve, curr_bb, instr); auto vpf = valuePhiFunc(ve, curr_bb, instr);
...@@ -680,9 +716,11 @@ GVN::transferFunction(BasicBlock *bb) { ...@@ -680,9 +716,11 @@ GVN::transferFunction(BasicBlock *bb) {
} }
} }
} }
#ifdef __DEBUG__
std::cout << "-------\n" std::cout << "-------\n"
<< "for basic block " << bb->get_name() << ", pout:" << std::endl; << "for basic block " << bb->get_name() << ", pout:" << std::endl;
utils::print_partitions(part); utils::print_partitions(part);
#endif
return part; return part;
} }
...@@ -700,9 +738,13 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve, ...@@ -700,9 +738,13 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve,
auto pre_2 = *(++pre_bbs_.begin()); auto pre_2 = *(++pre_bbs_.begin());
// check expression form // check expression form
if (ve->get_expr_type() != Expression::e_bin) if (ve->get_expr_type() != Expression::e_bin and
ve->get_expr_type() != Expression::e_cmp)
return nullptr; return nullptr;
auto ve_bin = static_cast<BinaryExpression *>(ve.get()); auto ve_bin = static_cast<BinaryExpression *>(ve.get());
auto ve_cmp = dynamic_cast<CmpExpression *>(ve.get());
assert(ve->get_expr_type() != Expression::e_cmp or ve_cmp);
if (ve_bin->lhs_->get_expr_type() != Expression::e_phi and if (ve_bin->lhs_->get_expr_type() != Expression::e_phi and
ve_bin->rhs_->get_expr_type() != Expression::e_phi) ve_bin->rhs_->get_expr_type() != Expression::e_phi)
return nullptr; return nullptr;
...@@ -718,20 +760,41 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve, ...@@ -718,20 +760,41 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve,
if (ve_bin->lhs_->get_expr_type() == Expression::e_phi and if (ve_bin->lhs_->get_expr_type() == Expression::e_phi and
ve_bin->rhs_->get_expr_type() == Expression::e_phi) { ve_bin->rhs_->get_expr_type() == Expression::e_phi) {
// try to get the merged value expression // try to get the merged value expression
vl_merge = if (ve_cmp) {
BinaryExpression::create(ve_bin->op_, lhs_phi->lhs_, rhs_phi->lhs_); vl_merge = CmpExpression::create(
vr_merge = ve_bin->op_, ve_cmp->cmpop_, lhs_phi->lhs_, rhs_phi->lhs_);
BinaryExpression::create(ve_bin->op_, lhs_phi->rhs_, rhs_phi->rhs_); vr_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, lhs_phi->rhs_, rhs_phi->rhs_);
} else {
vl_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->lhs_, rhs_phi->lhs_);
vr_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->rhs_, rhs_phi->rhs_);
}
} else if (ve_bin->lhs_->get_expr_type() == Expression::e_phi) { } else if (ve_bin->lhs_->get_expr_type() == Expression::e_phi) {
vl_merge = if (ve_cmp) {
BinaryExpression::create(ve_bin->op_, lhs_phi->lhs_, ve_bin->rhs_); vl_merge = CmpExpression::create(
vr_merge = ve_bin->op_, ve_cmp->cmpop_, lhs_phi->lhs_, ve_bin->rhs_);
BinaryExpression::create(ve_bin->op_, lhs_phi->rhs_, ve_bin->rhs_); vr_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, lhs_phi->rhs_, ve_bin->rhs_);
} else {
vl_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->lhs_, ve_bin->rhs_);
vr_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->rhs_, ve_bin->rhs_);
}
} else { } else {
vl_merge = if (ve_cmp) {
BinaryExpression::create(ve_bin->op_, ve_bin->lhs_, rhs_phi->lhs_); vl_merge = CmpExpression::create(
vr_merge = ve_bin->op_, ve_cmp->cmpop_, ve_bin->lhs_, rhs_phi->lhs_);
BinaryExpression::create(ve_bin->op_, ve_bin->lhs_, rhs_phi->rhs_); vr_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, ve_bin->lhs_, rhs_phi->rhs_);
} else {
vl_merge = BinaryExpression::create(
ve_bin->op_, ve_bin->lhs_, rhs_phi->lhs_);
vr_merge = BinaryExpression::create(
ve_bin->op_, ve_bin->lhs_, rhs_phi->rhs_);
}
} }
// constant fold // constant fold
...@@ -920,6 +983,8 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) { ...@@ -920,6 +983,8 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
return equiv_as<ConstantExpression>(lhs, rhs); return equiv_as<ConstantExpression>(lhs, rhs);
case Expression::e_bin: case Expression::e_bin:
return equiv_as<BinaryExpression>(lhs, rhs); return equiv_as<BinaryExpression>(lhs, rhs);
case Expression::e_cmp:
return equiv_as<CmpExpression>(lhs, rhs);
case Expression::e_phi: case Expression::e_phi:
return equiv_as<PhiExpression>(lhs, rhs); return equiv_as<PhiExpression>(lhs, rhs);
case Expression::e_cast: case Expression::e_cast:
...@@ -986,11 +1051,11 @@ GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) { ...@@ -986,11 +1051,11 @@ GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) {
// res = phi [op1, name1], [op2, name2] // res = phi [op1, name1], [op2, name2]
// ^0 ^1 ^2 ^3 // ^0 ^1 ^2 ^3
// if (phi->get_operand(1)->get_name() == bb->get_name()) { // if (phi->get_operand(1)->get_name() == bb->get_name()) {
if (static_cast<BasicBlock*>(phi->get_operand(1)) == bb) { if (static_cast<BasicBlock *>(phi->get_operand(1)) == bb) {
// pretend copy statement: // pretend copy statement:
// `res = op1` // `res = op1`
return 0; return 0;
} else if (static_cast<BasicBlock*>(phi->get_operand(3)) == bb) { } else if (static_cast<BasicBlock *>(phi->get_operand(3)) == bb) {
// `res = op2` // `res = op2`
return 2; return 2;
} }
......
...@@ -132,7 +132,7 @@ def eval(): ...@@ -132,7 +132,7 @@ def eval():
COMMAND = [TEST_PATH] COMMAND = [TEST_PATH]
try: try:
result = subprocess.run([EXE_PATH, TEST_PATH + ".cminus"], stderr=subprocess.PIPE, timeout=1) result = subprocess.run([EXE_PATH, "-mem2reg", "-gvn", TEST_PATH + ".cminus"], stderr=subprocess.PIPE, timeout=1)
except Exception as _: except Exception as _:
f.write('\tFail\n') f.write('\tFail\n')
continue continue
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment