Commit bd7b44fb authored by 李晓奇's avatar 李晓奇

can pass functional

parent f268905c
......@@ -42,13 +42,23 @@ class ConstFolder {
class Expression {
public:
// TODO: you need to extend expression types according to testcases
enum gvn_expr_t { e_constant, e_bin, e_phi, e_cast, e_gep, e_unique };
enum gvn_expr_t {
e_constant,
e_bin,
e_phi,
e_cast,
e_gep,
e_unique,
e_global,
e_argument,
e_call
};
Expression(gvn_expr_t t) : expr_type(t) {}
virtual ~Expression() = default;
virtual std::string print() = 0;
gvn_expr_t get_expr_type() const { return expr_type; }
private:
protected:
gvn_expr_t expr_type;
};
......@@ -198,7 +208,6 @@ class GEPExpression : public Expression {
std::vector<std::shared_ptr<Expression>> idxs_;
};
// unique expression: not equal to any one else
class UniqueExpression : public Expression {
public:
static std::shared_ptr<UniqueExpression> create(Instruction *instr,
......@@ -213,13 +222,88 @@ class UniqueExpression : public Expression {
return instr_ == other->instr_;
}
UniqueExpression(Instruction *instr, size_t index)
: Expression(e_unique), instr_(instr), index_(index) {}
UniqueExpression(Instruction *instr, size_t index, gvn_expr_t tp = e_unique)
: Expression(tp), instr_(instr), index_(index) {}
private:
protected:
Instruction *instr_;
size_t index_;
};
// global variables
class GlobalVarExpression : public Expression {
public:
static std::shared_ptr<GlobalVarExpression> create(GlobalVariable *g) {
return std::make_shared<GlobalVarExpression>(g);
}
virtual std::string print() { return "<" + g_->get_name() + ">"; }
bool equiv(const GlobalVarExpression *other) const {
return g_ == other->g_;
}
GlobalVarExpression(GlobalVariable *g) : Expression(e_global), g_(g) {}
private:
GlobalVariable *g_;
};
// this is the arguments passed to the current function, but not the arguments
// given to a function call.
class ArgExpression : public Expression {
public:
static std::shared_ptr<ArgExpression> create(Argument *a) {
return std::make_shared<ArgExpression>(a);
}
virtual std::string print() { return "[" + a_->get_name() + "]"; }
bool equiv(const ArgExpression *other) const { return a_ == other->a_; }
ArgExpression(Argument *a) : Expression(e_argument), a_(a) {}
private:
Argument *a_;
};
// function call
class CallExpression : public UniqueExpression {
public:
static std::shared_ptr<CallExpression> create(
Instruction *instr,
size_t index,
bool pure,
std::vector<std::shared_ptr<Expression>> &args) {
return std::make_shared<CallExpression>(instr, index, pure, args);
}
virtual std::string print() {
std::string ret = "{ " + f_->get_name();
for (auto arg : args_)
ret += " " + arg->print();
return ret + " }";
}
bool equiv(const CallExpression *other) const {
// single instruction, should be same
if (instr_ == other->instr_)
return true;
if (not(pure_ and f_ == other->f_))
return false;
for (int i = 0; i < args_.size(); i++)
if (not(*args_[i] == *other->args_[i]))
return false;
return true;
}
CallExpression(Instruction *instr,
size_t index,
bool pure,
std::vector<std::shared_ptr<Expression>> &args)
: UniqueExpression(instr, index, e_call)
, pure_(pure)
, f_(static_cast<Function *>(instr->get_operand(0)))
, args_(args) {}
private:
bool pure_;
Function *f_;
std::vector<std::shared_ptr<Expression>> args_;
};
} // namespace GVNExpression
/**
......@@ -304,7 +388,7 @@ class GVN : public Pass {
private:
bool dump_json_;
std::uint64_t next_value_number_ = 1;
std::uint64_t next_value_number_;
Function *func_;
std::map<BasicBlock *, partitions> pin_, pout_;
std::unique_ptr<FuncInfo> func_info_;
......@@ -312,30 +396,38 @@ class GVN : public Pass {
std::unique_ptr<DeadCode> dce_;
// self add member
std::map<BasicBlock *, bool> _TOP;
partitions join_helper(BasicBlock *pre1, BasicBlock *pre2);
// std::uint64_t start_number_;
BasicBlock *curr_bb;
std::map<BasicBlock *, bool> _TOP;
std::map<std::pair<Instruction *, Value *>, size_t> start_idx_;
std::map<GlobalVariable *,
std::shared_ptr<GVNExpression::GlobalVarExpression>>
global_map_;
std::map<Argument *, std::shared_ptr<GVNExpression::Expression>> arg_map_;
//
// self add function
//
void add_map_();
std::uint64_t new_number() { return next_value_number_++; }
void deal_with_entry(BasicBlock *Entry);
static int pretend_copy_stmt(Instruction *inst, BasicBlock *bb);
std::shared_ptr<GVNExpression::Expression> search_ve(
Value *v,
const partitions &part);
partitions join_helper(BasicBlock *pre1, BasicBlock *pre2);
std::vector<std::shared_ptr<GVNExpression::Expression>> valueExpr_core_(
Instruction *instr,
const partitions &part,
const size_t count,
bool fold_ = true);
bool fold_ = false,
int start_index_ = 0);
Constant *constFold_core(const size_t count,
const partitions &part,
Instruction *instr,
const std::vector<Value *> &operands);
void assign_start_idx_();
void dump_tmp(Function &);
void reset_number() { next_value_number_ = 1; }
};
bool operator==(const GVN::partitions &p1, const GVN::partitions &p2);
......@@ -199,8 +199,6 @@ dump_bb2partition(const std::map<BasicBlock *, GVN::partitions> &map) {
static void
print_partitions(const GVN::partitions &p) {
if (p.empty()) {
// to del
// std::cout << "empty partitions\n";
LOG_DEBUG << "empty partitions\n";
return;
}
......@@ -208,8 +206,6 @@ print_partitions(const GVN::partitions &p) {
for (auto &cc : p)
log += print_congruence_class(*cc);
LOG_DEBUG << log; // please don't use std::cout
// to del
// std::cout << log;
}
} // namespace utils
......@@ -258,7 +254,6 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
*ci->value_phi_ == *cj->value_phi_)
c->value_phi_ = ci->value_phi_;
// ??
// What if the ve is nullptr?
if (c->members_.size()) // not empty
{
......@@ -280,7 +275,9 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
c->index_ = exact_idx;
c->value_expr_ = c->value_phi_ =
PhiExpression::create(ci->value_expr_, cj->value_expr_);
} else if (ci->value_expr_->get_expr_type() == Expression::e_call) {
}
// ??
c->leader_ = *c->members_.begin();
}
......@@ -288,21 +285,36 @@ GVN::intersect(shared_ptr<CongruenceClass> ci, shared_ptr<CongruenceClass> cj) {
return c;
}
// assign start index for each instruction, including copy statement
// ther logic here:
// - use the same traver order as main run
// This should be called before dealing with any function.
// It will assign start index for each instruction, including copy statement.
// the logic here:
// - use the same traverse order as main run
// - assign an incremental number for each instruction
void
GVN::assign_start_idx_() {
int res;
// this should be the travers order
//
reset_number();
// first global_map_
next_value_number_ += global_map_.size();
// then arguments
next_value_number_ += func_->get_num_of_args();
// to del
/* auto arg_it = func_->arg_begin();
* for (; arg_it != func_->arg_end(); ++arg_it) {
* start_idx_[std::make_pair(nullptr, *arg_it)] = new_number();
* } */
// then instructions
for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions()) {
if (not instr.is_phi() and not instr.is_void())
// real instructions
start_idx_[std::make_pair(&instr, nullptr)] = new_number();
}
// and the phi instruction in all the successors
for (auto succ : bb.get_succ_basic_blocks()) {
for (auto &instr : succ->get_instructions()) {
// and the phi instruction in all the successors
if (instr.is_phi()) {
if ((res = pretend_copy_stmt(&instr, &bb)) == -1)
continue;
......@@ -312,13 +324,36 @@ GVN::assign_start_idx_() {
}
}
}
next_value_number_ = 1;
// this is necessary, cause the initialization for Entry block will use it!
reset_number();
}
void
GVN::deal_with_entry(BasicBlock *Entry) {
// add congruence class for global variables
for (auto item : global_map_) {
auto c = createCongruenceClass(new_number());
c->leader_ = item.first;
c->members_.insert(item.first);
c->value_expr_ = item.second;
pin_[Entry].insert(c);
}
for (auto arg : func_->get_args()) {
auto c = createCongruenceClass(new_number());
c->leader_ = arg;
c->members_.insert(arg);
c->value_expr_ = arg_map_[arg];
pin_[Entry].insert(c);
}
// order matters!
_TOP[Entry] = false;
pout_[Entry] = transferFunction(Entry);
}
void
GVN::detectEquivalences() {
int times = 0;
bool changed;
// DEBUG part
std::cout << "all the instruction address:" << std::endl;
for (auto &bb : func_->get_basic_blocks()) {
for (auto &instr : bb.get_instructions())
......@@ -331,10 +366,7 @@ GVN::detectEquivalences() {
}
auto Entry = func_->get_entry_block();
_TOP[Entry] = false;
pin_[Entry].clear();
pout_[Entry] = transferFunction(Entry);
deal_with_entry(Entry);
// iterate until converge
do {
......@@ -376,7 +408,8 @@ GVN::detectEquivalences() {
/* std::cout << "//-------\n"
* << "//after transferFunction(basic-block="
* << bb->get_name() << "), all pout:" << std::endl; */
* << bb->get_name() << "), all pout:" << std::endl;
*/
// dump_tmp(*func_);
}
// reset value number
......@@ -387,6 +420,10 @@ GVN::detectEquivalences() {
// or return nullptr
shared_ptr<Expression>
GVN::search_ve(Value *v, const GVN::partitions &part) {
if (dynamic_cast<GlobalVariable *>(v))
return global_map_.at(dynamic_cast<GlobalVariable *>(v));
if (dynamic_cast<Argument *>(v))
return arg_map_.at(dynamic_cast<Argument *>(v));
for (auto c : part) {
if (std::find(c->members_.begin(), c->members_.end(), v) !=
c->members_.end()) {
......@@ -441,7 +478,8 @@ std::vector<shared_ptr<Expression>>
GVN::valueExpr_core_(Instruction *instr,
const partitions &part,
const size_t count,
bool fold_) {
bool fold_,
int start_idx) {
assert(not(fold_ and count > 2));
Value *v;
......@@ -462,7 +500,7 @@ GVN::valueExpr_core_(Instruction *instr,
// - try to find expression that already exists
// - take care of constant
for (int i = 0; i != count; i++) {
v = operands[i];
v = operands[i + start_idx];
v_const = dynamic_cast<Constant *>(v);
if (v_const) {
ret.push_back(ConstantExpression::create(v_const));
......@@ -478,7 +516,6 @@ GVN::valueExpr_core_(Instruction *instr,
shared_ptr<Expression>
GVN::valueExpr(Instruction *instr, const partitions &part) {
// TODO
// ?? should use part?
std::string err{"Undefined"};
std::vector<shared_ptr<Expression>> res;
......@@ -488,7 +525,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
return tmp;
if (instr->isBinary() or instr->is_cmp() or instr->is_fcmp()) {
res = valueExpr_core_(instr, part, 2);
res = valueExpr_core_(instr, part, 2, true);
if (res.size() == 1) // constant fold
return res[0];
else
......@@ -497,7 +534,7 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
} else if (instr->is_phi()) {
err = "phi";
} else if (instr->is_fp2si() or instr->is_si2fp() or instr->is_zext()) {
res = valueExpr_core_(instr, part, 1);
res = valueExpr_core_(instr, part, 1, true);
if (res[0]->get_expr_type() == Expression::e_constant)
return res[0];
Type *dest_type;
......@@ -522,11 +559,18 @@ GVN::valueExpr(Instruction *instr, const partitions &part) {
instr->get_instr_type(), res[0], dest_type);
}
} else if (instr->is_gep()) {
res = valueExpr_core_(instr, part, instr->get_operands().size(), false);
res = valueExpr_core_(instr, part, instr->get_operands().size());
auto ptr = res[0];
res.erase(res.begin());
return GEPExpression::create(ptr, res);
} else if (instr->is_load() or instr->is_alloca() or instr->is_call()) {
} else if (instr->is_call()) {
// 0 is the function*, the arguments start from 1
auto f = static_cast<Function *>(
static_cast<CallInst *>(instr)->get_operand(0));
res = valueExpr_core_(instr, part, f->get_num_of_args(), false, 1);
return CallExpression::create(
instr, next_value_number_, func_info_->is_pure_function(f), res);
} else if (instr->is_load() or instr->is_alloca()) {
auto ret = search_ve(instr, part);
if (ret)
return ret;
......@@ -551,7 +595,6 @@ GVN::transferFunction(Instruction *instr, Value *e, partitions pin) {
partitions pout = clone(pin);
// TODO: deal with copy-stmt case
// ?? deal with copy statement
auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) &&
......@@ -618,7 +661,6 @@ GVN::transferFunction(BasicBlock *bb) {
// iterate through all instructions in the block
for (auto &instr : bb->get_instructions()) {
// ?? what about orther instructions? Are they all ok?
if (not instr.is_phi() and not instr.is_void())
part = transferFunction(&instr, nullptr, part);
}
......@@ -745,7 +787,7 @@ void
GVN::replace_cc_members() {
for (auto &[_bb, part] : pout_) {
auto bb = _bb; // workaround: structured bindings can't be captured
// in C++17
// in C++17
for (auto &cc : part) {
if (cc->index_ == 0)
continue;
......@@ -780,6 +822,18 @@ GVN::replace_cc_members() {
return;
}
// should be called only once and before the detectEquivalences() starts
void
GVN::add_map_() {
// add map for global variables
for (auto &glb : m_->get_global_variable())
global_map_[&glb] = GlobalVarExpression::create(&glb);
// add map for function parameters
for (auto &f : m_->get_functions())
for (auto arg : f.get_args())
arg_map_[arg] = ArgExpression::create(arg);
}
// top-level function, done for you
void
GVN::run() {
......@@ -795,6 +849,7 @@ GVN::run() {
dce_ = std::make_unique<DeadCode>(m_);
dce_->run(); // let dce take care of some dead phis with undef
add_map_();
for (auto &f : m_->get_functions()) {
if (f.get_basic_blocks().empty())
continue;
......@@ -823,7 +878,6 @@ GVN::run() {
gvn_json << "},";
}
replace_cc_members(); // don't delete instructions, just replace
// them
}
dce_->run(); // let dce do that for us
if (dump_json_)
......@@ -868,6 +922,14 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
return equiv_as<GEPExpression>(lhs, rhs);
case Expression::e_unique:
return equiv_as<UniqueExpression>(lhs, rhs);
case Expression::e_global:
return equiv_as<GlobalVarExpression>(lhs, rhs);
case Expression::e_argument:
return equiv_as<ArgExpression>(lhs, rhs);
case Expression::e_call:
return equiv_as<CallExpression>(lhs, rhs);
default:
abort();
}
}
......@@ -893,7 +955,6 @@ GVN::clone(const partitions &p) {
bool
operator==(const GVN::partitions &p1, const GVN::partitions &p2) {
// TODO: how to compare partitions?
// cannot direct compare???
if (p1.size() != p2.size())
return false;
auto it1 = p1.begin();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment