Commit bd7b44fb authored by 李晓奇's avatar 李晓奇

can pass functional

parent f268905c
...@@ -42,13 +42,23 @@ class ConstFolder { ...@@ -42,13 +42,23 @@ class ConstFolder {
class Expression { class Expression {
public: public:
// TODO: you need to extend expression types according to testcases // TODO: you need to extend expression types according to testcases
enum gvn_expr_t { e_constant, e_bin, e_phi, e_cast, e_gep, e_unique }; enum gvn_expr_t {
e_constant,
e_bin,
e_phi,
e_cast,
e_gep,
e_unique,
e_global,
e_argument,
e_call
};
Expression(gvn_expr_t t) : expr_type(t) {} Expression(gvn_expr_t t) : expr_type(t) {}
virtual ~Expression() = default; virtual ~Expression() = default;
virtual std::string print() = 0; virtual std::string print() = 0;
gvn_expr_t get_expr_type() const { return expr_type; } gvn_expr_t get_expr_type() const { return expr_type; }
private: protected:
gvn_expr_t expr_type; gvn_expr_t expr_type;
}; };
...@@ -198,7 +208,6 @@ class GEPExpression : public Expression { ...@@ -198,7 +208,6 @@ class GEPExpression : public Expression {
std::vector<std::shared_ptr<Expression>> idxs_; std::vector<std::shared_ptr<Expression>> idxs_;
}; };
// unique expression: not equal to any one else
class UniqueExpression : public Expression { class UniqueExpression : public Expression {
public: public:
static std::shared_ptr<UniqueExpression> create(Instruction *instr, static std::shared_ptr<UniqueExpression> create(Instruction *instr,
...@@ -213,13 +222,88 @@ class UniqueExpression : public Expression { ...@@ -213,13 +222,88 @@ class UniqueExpression : public Expression {
return instr_ == other->instr_; return instr_ == other->instr_;
} }
UniqueExpression(Instruction *instr, size_t index) UniqueExpression(Instruction *instr, size_t index, gvn_expr_t tp = e_unique)
: Expression(e_unique), instr_(instr), index_(index) {} : Expression(tp), instr_(instr), index_(index) {}
private: protected:
Instruction *instr_; Instruction *instr_;
size_t index_; size_t index_;
}; };
// global variables
class GlobalVarExpression : public Expression {
public:
static std::shared_ptr<GlobalVarExpression> create(GlobalVariable *g) {
return std::make_shared<GlobalVarExpression>(g);
}
virtual std::string print() { return "<" + g_->get_name() + ">"; }
bool equiv(const GlobalVarExpression *other) const {
return g_ == other->g_;
}
GlobalVarExpression(GlobalVariable *g) : Expression(e_global), g_(g) {}
private:
GlobalVariable *g_;
};
// this is the arguments passed to the current function, but not the arguments
// given to a function call.
class ArgExpression : public Expression {
public:
static std::shared_ptr<ArgExpression> create(Argument *a) {
return std::make_shared<ArgExpression>(a);
}
virtual std::string print() { return "[" + a_->get_name() + "]"; }
bool equiv(const ArgExpression *other) const { return a_ == other->a_; }
ArgExpression(Argument *a) : Expression(e_argument), a_(a) {}
private:
Argument *a_;
};
// function call
class CallExpression : public UniqueExpression {
public:
static std::shared_ptr<CallExpression> create(
Instruction *instr,
size_t index,
bool pure,
std::vector<std::shared_ptr<Expression>> &args) {
return std::make_shared<CallExpression>(instr, index, pure, args);
}
virtual std::string print() {
std::string ret = "{ " + f_->get_name();
for (auto arg : args_)
ret += " " + arg->print();
return ret + " }";
}
bool equiv(const CallExpression *other) const {
// single instruction, should be same
if (instr_ == other->instr_)
return true;
if (not(pure_ and f_ == other->f_))
return false;
for (int i = 0; i < args_.size(); i++)
if (not(*args_[i] == *other->args_[i]))
return false;
return true;
}
CallExpression(Instruction *instr,
size_t index,
bool pure,
std::vector<std::shared_ptr<Expression>> &args)
: UniqueExpression(instr, index, e_call)
, pure_(pure)
, f_(static_cast<Function *>(instr->get_operand(0)))
, args_(args) {}
private:
bool pure_;
Function *f_;
std::vector<std::shared_ptr<Expression>> args_;
};
} // namespace GVNExpression } // namespace GVNExpression
/** /**
...@@ -304,7 +388,7 @@ class GVN : public Pass { ...@@ -304,7 +388,7 @@ class GVN : public Pass {
private: private:
bool dump_json_; bool dump_json_;
std::uint64_t next_value_number_ = 1; std::uint64_t next_value_number_;
Function *func_; Function *func_;
std::map<BasicBlock *, partitions> pin_, pout_; std::map<BasicBlock *, partitions> pin_, pout_;
std::unique_ptr<FuncInfo> func_info_; std::unique_ptr<FuncInfo> func_info_;
...@@ -312,30 +396,38 @@ class GVN : public Pass { ...@@ -312,30 +396,38 @@ class GVN : public Pass {
std::unique_ptr<DeadCode> dce_; std::unique_ptr<DeadCode> dce_;
// self add member // self add member
std::map<BasicBlock *, bool> _TOP; // std::uint64_t start_number_;
partitions join_helper(BasicBlock *pre1, BasicBlock *pre2);
BasicBlock *curr_bb; BasicBlock *curr_bb;
std::map<BasicBlock *, bool> _TOP;
std::map<std::pair<Instruction *, Value *>, size_t> start_idx_; std::map<std::pair<Instruction *, Value *>, size_t> start_idx_;
std::map<GlobalVariable *,
std::shared_ptr<GVNExpression::GlobalVarExpression>>
global_map_;
std::map<Argument *, std::shared_ptr<GVNExpression::Expression>> arg_map_;
// //
// self add function // self add function
// //
void add_map_();
std::uint64_t new_number() { return next_value_number_++; } std::uint64_t new_number() { return next_value_number_++; }
void deal_with_entry(BasicBlock *Entry);
static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb); static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb);
std::shared_ptr<GVNExpression::Expression> search_ve( std::shared_ptr<GVNExpression::Expression> search_ve(
Value *v, Value *v,
const partitions &part); const partitions &part);
partitions join_helper(BasicBlock *pre1, BasicBlock *pre2);
std::vector<std::shared_ptr<GVNExpression::Expression>> valueExpr_core_( std::vector<std::shared_ptr<GVNExpression::Expression>> valueExpr_core_(
Instruction *instr, Instruction *instr,
const partitions &part, const partitions &part,
const size_t count, const size_t count,
bool fold_ = true); bool fold_ = false,
int start_index_ = 0);
Constant *constFold_core(const size_t count, Constant *constFold_core(const size_t count,
const partitions &part, const partitions &part,
Instruction *instr, Instruction *instr,
const std::vector<Value *> &operands); const std::vector<Value *> &operands);
void assign_start_idx_(); void assign_start_idx_();
void dump_tmp(Function &); void dump_tmp(Function &);
void reset_number() { next_value_number_ = 1; }
}; };
bool operator==(const GVN::partitions &p1, const GVN::partitions &p2); bool operator==(const GVN::partitions &p1, const GVN::partitions &p2);
...@@ -199,8 +199,6 @@ dump_bb2partition(const std::map<BasicBlock *, GVN::partitions> &map) { ...@@ -199,8 +199,6 @@ dump_bb2partition(const std::map<BasicBlock *, GVN::partitions> &map) {
static void static void
print_partitions(const GVN::partitions &p) { print_partitions(const GVN::partitions &p) {
if (p.empty()) { if (p.empty()) {
// to del
// std::cout << "empty partitions\n";
LOG_DEBUG << "empty partitions\n"; LOG_DEBUG << "empty partitions\n";
return; return;
} }
...@@ -208,8 +206,6 @@ print_partitions(const GVN::partitions &p) { ...@@ -208,8 +206,6 @@ print_partitions(const GVN::partitions &p) {
for (auto &cc : p) for (auto &cc : p)
log += print_congruence_class(*cc); log += print_congruence_class(*cc);
LOG_DEBUG << log; // please don't use std::cout LOG_DEBUG << log; // please don't use std::cout
// to del
// std::cout << log;
} }
} // namespace utils } // namespace utils
...@@ -258,7 +254,6 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) { ...@@ -258,7 +254,6 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
*ci->value_phi_ == *cj->value_phi_) *ci->value_phi_ == *cj->value_phi_)
c->value_phi_ = ci->value_phi_; c->value_phi_ = ci->value_phi_;
// ??
// What if the ve is nullptr? // What if the ve is nullptr?
if (c->members_.size()) // not empty if (c->members_.size()) // not empty
{ {
...@@ -280,7 +275,9 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) { ...@@ -280,7 +275,9 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
c->index_ = exact_idx; c->index_ = exact_idx;
c->value_expr_ = c->value_phi_ = c->value_expr_ = c->value_phi_ =
PhiExpression::create(ci->value_expr_, cj->value_expr_); PhiExpression::create(ci->value_expr_, cj->value_expr_);
} else if (ci->value_expr_->get_expr_type() == Expression::e_call) {
} }
// ?? // ??
c->leader_ = *c->members_.begin(); c->leader_ = *c->members_.begin();
} }
...@@ -288,21 +285,36 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) { ...@@ -288,21 +285,36 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
return c; return c;
} }
// assign start index for each instruction, including copy statement // This should be called before dealing with any function.
// ther logic here: // It will assign start index for each instruction, including copy statement.
// - use the same traver order as main run // the logic here:
// - use the same traverse order as main run
// - assign an incremental number for each instruction // - assign an incremental number for each instruction
void void
GVN::assign_start_idx_() { GVN::assign_start_idx_() {
int res; int res;
// this should be the travers order
//
reset_number();
// first global_map_
next_value_number_ += global_map_.size();
// then arguments
next_value_number_ += func_->get_num_of_args();
// to del
/* auto arg_it = func_->arg_begin();
* for (; arg_it != func_->arg_end(); ++arg_it) {
* start_idx_[std::make_pair(nullptr, *arg_it)] = new_number();
* } */
// then instructions
for (auto &bb : func_->get_basic_blocks()) { for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions()) { for (auto &instr : bb.get_instructions()) {
if (not instr.is_phi() and not instr.is_void()) if (not instr.is_phi() and not instr.is_void())
// real instructions
start_idx_[std::make_pair(&instr, nullptr)] = new_number(); start_idx_[std::make_pair(&instr, nullptr)] = new_number();
} }
// and the phi instruction in all the successors
for (auto succ : bb.get_succ_basic_blocks()) { for (auto succ : bb.get_succ_basic_blocks()) {
for (auto &instr : succ->get_instructions()) { for (auto &instr : succ->get_instructions()) {
// and the phi instruction in all the successors
if (instr.is_phi()) { if (instr.is_phi()) {
if ((res = pretend_copy_stmt(&instr, &bb)) == -1) if ((res = pretend_copy_stmt(&instr, &bb)) == -1)
continue; continue;
...@@ -312,13 +324,36 @@ GVN::assign_start_idx_() { ...@@ -312,13 +324,36 @@ GVN::assign_start_idx_() {
} }
} }
} }
next_value_number_ = 1; // this is necessary, cause the initialization for Entry block will use it!
reset_number();
} }
void
GVN::deal_with_entry(BasicBlock *Entry) {
// add congruence class for global variables
for (auto item : global_map_) {
auto c = createCongruenceClass(new_number());
c->leader_ = item.first;
c->members_.insert(item.first);
c->value_expr_ = item.second;
pin_[Entry].insert(c);
}
for (auto arg : func_->get_args()) {
auto c = createCongruenceClass(new_number());
c->leader_ = arg;
c->members_.insert(arg);
c->value_expr_ = arg_map_[arg];
pin_[Entry].insert(c);
}
// order matters!
_TOP[Entry] = false;
pout_[Entry] = transferFunction(Entry);
}
void void
GVN::detectEquivalences() { GVN::detectEquivalences() {
int times = 0; int times = 0;
bool changed; bool changed;
// DEBUG part
std::cout << "all the instruction address:" << std::endl; std::cout << "all the instruction address:" << std::endl;
for (auto &bb : func_->get_basic_blocks()) { for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions()) for (auto &instr : bb.get_instructions())
...@@ -331,10 +366,7 @@ GVN::detectEquivalences() { ...@@ -331,10 +366,7 @@ GVN::detectEquivalences() {
} }
auto Entry = func_->get_entry_block(); auto Entry = func_->get_entry_block();
_TOP[Entry] = false; deal_with_entry(Entry);
pin_[Entry].clear();
pout_[Entry] = transferFunction(Entry);
// iterate until converge // iterate until converge
do { do {
...@@ -376,7 +408,8 @@ GVN::detectEquivalences() { ...@@ -376,7 +408,8 @@ GVN::detectEquivalences() {
/* std::cout << "//-------\n" /* std::cout << "//-------\n"
* << "//after transferFunction(basic-block=" * << "//after transferFunction(basic-block="
* << bb->get_name() << "), all pout:" << std::endl; */ * << bb->get_name() << "), all pout:" << std::endl;
*/
// dump_tmp(*func_); // dump_tmp(*func_);
} }
// reset value number // reset value number
...@@ -387,6 +420,10 @@ GVN::detectEquivalences() { ...@@ -387,6 +420,10 @@ GVN::detectEquivalences() {
// or return nullptr // or return nullptr
shared_ptr<Expression> shared_ptr<Expression>
GVN::search_ve(Value *v, const GVN::partitions &part) { GVN::search_ve(Value *v, const GVN::partitions &part) {
if (dynamic_cast<GlobalVariable *>(v))
return global_map_.at(dynamic_cast<GlobalVariable *>(v));
if (dynamic_cast<Argument *>(v))
return arg_map_.at(dynamic_cast<Argument *>(v));
for (auto c : part) { for (auto c : part) {
if (std::find(c->members_.begin(), c->members_.end(), v) != if (std::find(c->members_.begin(), c->members_.end(), v) !=
c->members_.end()) { c->members_.end()) {
...@@ -441,7 +478,8 @@ std::vector<shared_ptr<Expression>> ...@@ -441,7 +478,8 @@ std::vector<shared_ptr<Expression>>
GVN::valueExpr_core_(Instruction *instr, GVN::valueExpr_core_(Instruction *instr,
const partitions &part, const partitions &part,
const size_t count, const size_t count,
bool fold_) { bool fold_,
int start_idx) {
assert(not(fold_ and count > 2)); assert(not(fold_ and count > 2));
Value *v; Value *v;
...@@ -462,7 +500,7 @@ GVN::valueExpr_core_(Instruction *instr, ...@@ -462,7 +500,7 @@ GVN::valueExpr_core_(Instruction *instr,
// - try to find expression that already exists // - try to find expression that already exists
// - take care of constant // - take care of constant
for (int i = 0; i != count; i++) { for (int i = 0; i != count; i++) {
v = operands[i]; v = operands[i + start_idx];
v_const = dynamic_cast<Constant *>(v); v_const = dynamic_cast<Constant *>(v);
if (v_const) { if (v_const) {
ret.push_back(ConstantExpression::create(v_const)); ret.push_back(ConstantExpression::create(v_const));
...@@ -478,7 +516,6 @@ GVN::valueExpr_core_(Instruction *instr, ...@@ -478,7 +516,6 @@ GVN::valueExpr_core_(Instruction *instr,
shared_ptr<Expression> shared_ptr<Expression>
GVN::valueExpr(Instruction *instr, const partitions &part) { GVN::valueExpr(Instruction *instr, const partitions &part) {
// TODO // TODO
// ?? should use part?
std::string err{"Undefined"}; std::string err{"Undefined"};
std::vector<shared_ptr<Expression>> res; std::vector<shared_ptr<Expression>> res;
...@@ -488,7 +525,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) { ...@@ -488,7 +525,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
return tmp; return tmp;
if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) { if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) {
res = valueExpr_core_(instr, part, 2); res = valueExpr_core_(instr, part, 2, true);
if (res.size() == 1) // constant fold if (res.size() == 1) // constant fold
return res[0]; return res[0];
else else
...@@ -497,7 +534,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) { ...@@ -497,7 +534,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
} else if (instr->is_phi()) { } else if (instr->is_phi()) {
err = "phi"; err = "phi";
} else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) { } else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
res = valueExpr_core_(instr, part, 1); res = valueExpr_core_(instr, part, 1, true);
if (res[0]->get_expr_type() == Expression::e_constant) if (res[0]->get_expr_type() == Expression::e_constant)
return res[0]; return res[0];
Type *dest_type; Type *dest_type;
...@@ -522,11 +559,18 @@ GVN::valueExpr(Instruction *instr, const partitions &part) { ...@@ -522,11 +559,18 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
instr->get_instr_type(), res[0], dest_type); instr->get_instr_type(), res[0], dest_type);
} }
} else if (instr->is_gep()) { } else if (instr->is_gep()) {
res = valueExpr_core_(instr, part, instr->get_operands().size(), false); res = valueExpr_core_(instr, part, instr->get_operands().size());
auto ptr = res[0]; auto ptr = res[0];
res.erase(res.begin()); res.erase(res.begin());
return GEPExpression::create(ptr, res); return GEPExpression::create(ptr, res);
} else if (instr->is_load() or instr->is_alloca() or instr->is_call()) { } else if (instr->is_call()) {
// 0 is the function*, the arguments start from 1
auto f = static_cast<Function *>(
static_cast<CallInst *>(instr)->get_operand(0));
res = valueExpr_core_(instr, part, f->get_num_of_args(), false, 1);
return CallExpression::create(
instr, next_value_number_, func_info_->is_pure_function(f), res);
} else if (instr->is_load() or instr->is_alloca()) {
auto ret = search_ve(instr, part); auto ret = search_ve(instr, part);
if (ret) if (ret)
return ret; return ret;
...@@ -551,7 +595,6 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) { ...@@ -551,7 +595,6 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
partitions pout = clone(pin); partitions pout = clone(pin);
// TODO: deal with copy-stmt case // TODO: deal with copy-stmt case
// ?? deal with copy statement
auto e_instr = dynamic_cast<Instruction *>(e); auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e); auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) && assert((not e or e_instr or e_const) &&
...@@ -618,7 +661,6 @@ GVN::transferFunction(BasicBlock *bb) { ...@@ -618,7 +661,6 @@ GVN::transferFunction(BasicBlock *bb) {
// iterate through all instructions in the block // iterate through all instructions in the block
for (auto &instr : bb->get_instructions()) { for (auto &instr : bb->get_instructions()) {
// ?? what about orther instructions? Are they all ok?
if (not instr.is_phi() and not instr.is_void()) if (not instr.is_phi() and not instr.is_void())
part = transferFunction(&instr, nullptr, part); part = transferFunction(&instr, nullptr, part);
} }
...@@ -745,7 +787,7 @@ void ...@@ -745,7 +787,7 @@ void
GVN::replace_cc_members() { GVN::replace_cc_members() {
for (auto &[_bb, part] : pout_) { for (auto &[_bb, part] : pout_) {
auto bb = _bb; // workaround: structured bindings can't be captured auto bb = _bb; // workaround: structured bindings can't be captured
// in C++17 // in C++17
for (auto &cc : part) { for (auto &cc : part) {
if (cc->index_ == 0) if (cc->index_ == 0)
continue; continue;
...@@ -780,6 +822,18 @@ GVN::replace_cc_members() { ...@@ -780,6 +822,18 @@ GVN::replace_cc_members() {
return; return;
} }
// should be called only once and before the detectEquivalences() starts
void
GVN::add_map_() {
// add map for global variables
for (auto &glb : m_->get_global_variable())
global_map_[&glb] = GlobalVarExpression::create(&glb);
// add map for function parameters
for (auto &f : m_->get_functions())
for (auto arg : f.get_args())
arg_map_[arg] = ArgExpression::create(arg);
}
// top-level function, done for you // top-level function, done for you
void void
GVN::run() { GVN::run() {
...@@ -795,6 +849,7 @@ GVN::run() { ...@@ -795,6 +849,7 @@ GVN::run() {
dce_ = std::make_unique<DeadCode>(m_); dce_ = std::make_unique<DeadCode>(m_);
dce_->run(); // let dce take care of some dead phis with undef dce_->run(); // let dce take care of some dead phis with undef
add_map_();
for (auto &f : m_->get_functions()) { for (auto &f : m_->get_functions()) {
if (f.get_basic_blocks().empty()) if (f.get_basic_blocks().empty())
continue; continue;
...@@ -823,7 +878,6 @@ GVN::run() { ...@@ -823,7 +878,6 @@ GVN::run() {
gvn_json << "},"; gvn_json << "},";
} }
replace_cc_members(); // don't delete instructions, just replace replace_cc_members(); // don't delete instructions, just replace
// them
} }
dce_->run(); // let dce do that for us dce_->run(); // let dce do that for us
if (dump_json_) if (dump_json_)
...@@ -868,6 +922,14 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) { ...@@ -868,6 +922,14 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
return equiv_as<GEPExpression>(lhs, rhs); return equiv_as<GEPExpression>(lhs, rhs);
case Expression::e_unique: case Expression::e_unique:
return equiv_as<UniqueExpression>(lhs, rhs); return equiv_as<UniqueExpression>(lhs, rhs);
case Expression::e_global:
return equiv_as<GlobalVarExpression>(lhs, rhs);
case Expression::e_argument:
return equiv_as<ArgExpression>(lhs, rhs);
case Expression::e_call:
return equiv_as<CallExpression>(lhs, rhs);
default:
abort();
} }
} }
...@@ -893,7 +955,6 @@ GVN::clone(const partitions &p) { ...@@ -893,7 +955,6 @@ GVN::clone(const partitions &p) {
bool bool
operator==(const GVN::partitions &p1, const GVN::partitions &p2) { operator==(const GVN::partitions &p1, const GVN::partitions &p2) {
// TODO: how to compare partitions? // TODO: how to compare partitions?
// cannot direct compare???
if (p1.size() != p2.size()) if (p1.size() != p2.size())
return false; return false;
auto it1 = p1.begin(); auto it1 = p1.begin();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment