Commit a116d0f5 authored by 李晓奇's avatar 李晓奇

fix the cmp expression bug, modify lab3 script

parent bb6e3f98
......@@ -4,6 +4,7 @@
#include "DeadCode.h"
#include "FuncInfo.h"
#include "Function.h"
#include "IRprinter.h"
#include "Instruction.h"
#include "Module.h"
#include "PassManager.hpp"
......@@ -19,6 +20,8 @@
#include <utility>
#include <vector>
#define __DEBUG__
class GVN;
namespace GVNExpression {
......@@ -45,6 +48,7 @@ class Expression {
enum gvn_expr_t {
e_constant,
e_bin,
e_cmp,
e_phi,
e_cast,
e_gep,
......@@ -111,14 +115,66 @@ class BinaryExpression : public Expression {
BinaryExpression(Instruction::OpID op,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs)
: Expression(e_bin), op_(op), lhs_(lhs), rhs_(rhs) {}
std::shared_ptr<Expression> rhs,
gvn_expr_t tp = e_bin)
: Expression(tp), op_(op), lhs_(lhs), rhs_(rhs) {}
private:
protected:
Instruction::OpID op_;
std::shared_ptr<Expression> lhs_, rhs_;
};
// cmp expression, inherited from binary-expression
class CmpExpression : public BinaryExpression {
friend class ::GVN;
typedef union {
CmpInst::CmpOp int_cmp_op;
FCmpInst::CmpOp float_cmp_op;
} cmp_op_i_or_f;
public:
static std::shared_ptr<CmpExpression> create(
Instruction::OpID op,
cmp_op_i_or_f cmpop,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs) {
return std::make_shared<CmpExpression>(op, cmpop, lhs, rhs);
}
virtual std::string print() {
std::string ret{"(" + Instruction::get_instr_op_name(op_) + " "};
ret += op_ == Instruction::cmp ? print_cmp_type(cmpop_.int_cmp_op)
: print_fcmp_type(cmpop_.float_cmp_op);
ret += " " + lhs_->print() + " " + rhs_->print() + ")";
return ret;
}
bool equiv(const CmpExpression *other) const {
if (not(op_ == other->op_))
return false;
if (op_ == Instruction::cmp) {
if (cmpop_.int_cmp_op != other->cmpop_.int_cmp_op)
return false;
} else {
// op_ == Instruction::fcmp
if (cmpop_.float_cmp_op != other->cmpop_.float_cmp_op)
return false;
}
return *lhs_ == *other->lhs_ and *rhs_ == *other->rhs_;
}
CmpExpression(Instruction::OpID op,
cmp_op_i_or_f cmpop,
std::shared_ptr<Expression> lhs,
std::shared_ptr<Expression> rhs)
: BinaryExpression(op, lhs, rhs, e_cmp), cmpop_(cmpop) {
assert(op == Instruction::cmp or
op == Instruction::fcmp && "Wrong instruction type!");
}
private:
cmp_op_i_or_f cmpop_;
};
class PhiExpression : public Expression {
friend class ::GVN;
......
......@@ -295,6 +295,7 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
// It will assign start index for each instruction, including copy statement.
// the logic here:
// - use the same traverse order as main run
// - use the same parameters as transferFunction(x, e)
// - assign an incremental number for each instruction
void
GVN::assign_start_idx_() {
......@@ -333,6 +334,9 @@ GVN::assign_start_idx_() {
// this is necessary, cause the initialization for Entry block will use it!
reset_number();
}
// Add global variables and arguments into pin, these will flow to the end and
// pursue there.
void
GVN::deal_with_entry(BasicBlock *Entry) {
// add congruence class for global variables
......@@ -360,11 +364,13 @@ GVN::detectEquivalences() {
int times = 0;
bool changed;
// DEBUG part
#ifdef __DEBUG__
std::cout << "all the instruction address:" << std::endl;
for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions())
std::cout << &instr << "\t" << instr.print() << std::endl;
}
#endif
assign_start_idx_();
// initialize pout with top
for (auto &bb : func_->get_basic_blocks()) {
......@@ -377,7 +383,9 @@ GVN::detectEquivalences() {
// iterate until converge
do {
changed = false;
#ifdef __DEBUG__
std::cout << ++times << "th iteration" << std::endl;
#endif
for (auto &_bb : func_->get_basic_blocks()) {
auto bb = &_bb;
if (bb == Entry)
......@@ -399,10 +407,13 @@ GVN::detectEquivalences() {
pin_[bb] = clone(pout_[pre]);
break;
}
default:
LOG_ERROR << "block has count of predecessors: "
default: {
std::cerr << "In function " << func_->get_name() << ", "
<< " block " << bb->get_name()
<< " has count of predecessors: "
<< pre_bbs_.size();
abort();
continue;
}
}
}
auto part = transferFunction(bb);
......@@ -412,11 +423,12 @@ GVN::detectEquivalences() {
pout_[bb] = clone(part);
_TOP[bb] = false;
/* std::cout << "//-------\n"
* << "//after transferFunction(basic-block="
* << bb->get_name() << "), all pout:" << std::endl;
*/
// dump_tmp(*func_);
#ifdef __DEBUG__
std::cout << "//-------\n"
<< "//after transferFunction(basic-block="
<< bb->get_name() << "), all pout:" << std::endl;
dump_tmp(*func_);
#endif
}
// reset value number
} while (changed);
......@@ -534,9 +546,22 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
res = valueExpr_core_(instr, part, 2, true);
if (res.size() == 1) // constant fold
return res[0];
else
return BinaryExpression::create(
instr->get_instr_type(), res[0], res[1]);
else {
if (instr->isBinary())
return BinaryExpression::create(
instr->get_instr_type(), res[0], res[1]);
else {
CmpExpression::cmp_op_i_or_f cmpop;
if (instr->is_cmp())
cmpop.int_cmp_op =
static_cast<CmpInst *>(instr)->get_cmp_op();
else
cmpop.float_cmp_op =
static_cast<FCmpInst *>(instr)->get_cmp_op();
return CmpExpression::create(
instr->get_instr_type(), cmpop, res[0], res[1]);
}
}
} else if (instr->is_phi()) {
err = "phi";
} else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
......@@ -601,9 +626,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
partitions pout = clone(pin);
// TODO: deal with copy-stmt case
// if e is not null, then we are handling copy-stmt
// but will e is the global? Or will global variable be in phi?
// ??
auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) &&
// auto e_global = dynamic_cast<GlobalVariable *>(e);
auto e_argument = dynamic_cast<Argument *>(e);
assert((not e or e_instr or e_const or e_argument) &&
"A value must be from an instruction or constant");
// erase the old record for x
std::set<Value *>::iterator it;
......@@ -622,8 +652,14 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
if (e) {
if (e_const)
ve = ConstantExpression::create(e_const);
else
else if (e_instr)
ve = valueExpr(e_instr, pout);
else if (e_argument) {
ve = search_ve(e_argument, pout);
assert(ve && "argument-expression should be there");
} else {
abort();
}
} else
ve = valueExpr(instr, pout);
auto vpf = valuePhiFunc(ve, curr_bb, instr);
......@@ -680,9 +716,11 @@ GVN::transferFunction(BasicBlock *bb) {
}
}
}
#ifdef __DEBUG__
std::cout << "-------\n"
<< "for basic block " << bb->get_name() << ", pout:" << std::endl;
utils::print_partitions(part);
#endif
return part;
}
......@@ -700,9 +738,13 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve,
auto pre_2 = *(++pre_bbs_.begin());
// check expression form
if (ve->get_expr_type() != Expression::e_bin)
if (ve->get_expr_type() != Expression::e_bin and
ve->get_expr_type() != Expression::e_cmp)
return nullptr;
auto ve_bin = static_cast<BinaryExpression *>(ve.get());
auto ve_cmp = dynamic_cast<CmpExpression *>(ve.get());
assert(ve->get_expr_type() != Expression::e_cmp or ve_cmp);
if (ve_bin->lhs_->get_expr_type() != Expression::e_phi and
ve_bin->rhs_->get_expr_type() != Expression::e_phi)
return nullptr;
......@@ -718,20 +760,41 @@ GVN::valuePhiFunc(shared_ptr<Expression> ve,
if (ve_bin->lhs_->get_expr_type() == Expression::e_phi and
ve_bin->rhs_->get_expr_type() == Expression::e_phi) {
// try to get the merged value expression
vl_merge =
BinaryExpression::create(ve_bin->op_, lhs_phi->lhs_, rhs_phi->lhs_);
vr_merge =
BinaryExpression::create(ve_bin->op_, lhs_phi->rhs_, rhs_phi->rhs_);
if (ve_cmp) {
vl_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, lhs_phi->lhs_, rhs_phi->lhs_);
vr_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, lhs_phi->rhs_, rhs_phi->rhs_);
} else {
vl_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->lhs_, rhs_phi->lhs_);
vr_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->rhs_, rhs_phi->rhs_);
}
} else if (ve_bin->lhs_->get_expr_type() == Expression::e_phi) {
vl_merge =
BinaryExpression::create(ve_bin->op_, lhs_phi->lhs_, ve_bin->rhs_);
vr_merge =
BinaryExpression::create(ve_bin->op_, lhs_phi->rhs_, ve_bin->rhs_);
if (ve_cmp) {
vl_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, lhs_phi->lhs_, ve_bin->rhs_);
vr_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, lhs_phi->rhs_, ve_bin->rhs_);
} else {
vl_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->lhs_, ve_bin->rhs_);
vr_merge = BinaryExpression::create(
ve_bin->op_, lhs_phi->rhs_, ve_bin->rhs_);
}
} else {
vl_merge =
BinaryExpression::create(ve_bin->op_, ve_bin->lhs_, rhs_phi->lhs_);
vr_merge =
BinaryExpression::create(ve_bin->op_, ve_bin->lhs_, rhs_phi->rhs_);
if (ve_cmp) {
vl_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, ve_bin->lhs_, rhs_phi->lhs_);
vr_merge = CmpExpression::create(
ve_bin->op_, ve_cmp->cmpop_, ve_bin->lhs_, rhs_phi->rhs_);
} else {
vl_merge = BinaryExpression::create(
ve_bin->op_, ve_bin->lhs_, rhs_phi->lhs_);
vr_merge = BinaryExpression::create(
ve_bin->op_, ve_bin->lhs_, rhs_phi->rhs_);
}
}
// constant fold
......@@ -920,6 +983,8 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
return equiv_as<ConstantExpression>(lhs, rhs);
case Expression::e_bin:
return equiv_as<BinaryExpression>(lhs, rhs);
case Expression::e_cmp:
return equiv_as<CmpExpression>(lhs, rhs);
case Expression::e_phi:
return equiv_as<PhiExpression>(lhs, rhs);
case Expression::e_cast:
......@@ -986,11 +1051,11 @@ GVN::pretend_copy_stmt(Instruction *instr, BasicBlock *bb) {
// res = phi [op1, name1], [op2, name2]
// ^0 ^1 ^2 ^3
// if (phi->get_operand(1)->get_name() == bb->get_name()) {
if (static_cast<BasicBlock*>(phi->get_operand(1)) == bb) {
if (static_cast<BasicBlock *>(phi->get_operand(1)) == bb) {
// pretend copy statement:
// `res = op1`
return 0;
} else if (static_cast<BasicBlock*>(phi->get_operand(3)) == bb) {
} else if (static_cast<BasicBlock *>(phi->get_operand(3)) == bb) {
// `res = op2`
return 2;
}
......
......@@ -132,7 +132,7 @@ def eval():
COMMAND = [TEST_PATH]
try:
result = subprocess.run([EXE_PATH, TEST_PATH + ".cminus"], stderr=subprocess.PIPE, timeout=1)
result = subprocess.run([EXE_PATH, "-mem2reg", "-gvn", TEST_PATH + ".cminus"], stderr=subprocess.PIPE, timeout=1)
except Exception as _:
f.write('\tFail\n')
continue
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment