Commit 93702b62 authored by Yang's avatar Yang

fix lab4

parent 80b0a14d
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "Value.hpp" #include "Value.hpp"
#include <list> #include <list>
#include <set>
#include <string> #include <string>
class Function; class Function;
...@@ -64,8 +65,12 @@ class BasicBlock : public Value { ...@@ -64,8 +65,12 @@ class BasicBlock : public Value {
// 从 BasicBlock 移除 Instruction,并 delete 这个 Instruction // 从 BasicBlock 移除 Instruction,并 delete 这个 Instruction
void erase_instr(Instruction* instr) { instr_list_.remove(instr); delete instr; } void erase_instr(Instruction* instr) { instr_list_.remove(instr); delete instr; }
// 从 BasicBlock 移除集合中的 Instruction,并 delete 这些 Instruction
void erase_instrs(const std::set<Instruction*>& instr);
// 从 BasicBlock 移除 Instruction,你需要自己 delete 它 // 从 BasicBlock 移除 Instruction,你需要自己 delete 它
void remove_instr(Instruction *instr) { instr_list_.remove(instr); } void remove_instr(Instruction *instr) { instr_list_.remove(instr); }
// 从 BasicBlock 移除集合中的 Instruction,你需要自己 delete 它们
void remove_instrs(const std::set<Instruction*>& instr);
// 移除的 Instruction 需要自己 delete // 移除的 Instruction 需要自己 delete
std::list<Instruction*> &get_instructions() { return instr_list_; } std::list<Instruction*> &get_instructions() { return instr_list_; }
......
...@@ -342,16 +342,7 @@ class PhiInst : public Instruction { ...@@ -342,16 +342,7 @@ class PhiInst : public Instruction {
this->add_operand(val); this->add_operand(val);
this->add_operand(pre_bb); this->add_operand(pre_bb);
} }
std::vector<std::pair<Value *, BasicBlock *>> get_phi_pairs() const std::vector<std::pair<Value*, BasicBlock*>> get_phi_pairs() const;
{
std::vector<std::pair<Value *, BasicBlock *>> res;
int ops = static_cast<int>(get_num_operand());
for (int i = 0; i < ops; i += 2) {
res.emplace_back(this->get_operand(i),
this->get_operand(i + 1)->as<BasicBlock>());
}
return res;
}
std::string print() override; std::string print() override;
}; };
......
#pragma once #pragma once
#include "FuncInfo.hpp" #include <deque>
#include "PassManager.hpp"
#include <unordered_set> #include "PassManager.hpp"
class FuncInfo;
/** /**
* 死代码消除:参见 * 死代码消除:假设所有指令都可以去掉,然后只保留具有副作用的指令和它们所影响的指令。去掉不可达的基本块。
*https://www.clear.rice.edu/comp512/Lectures/10Dead-Clean-SCCP.pdf *
* 在初始化时指定一个 bool 参数 remove_unreachable_bb, 代表是否去除函数不可达基本块
*
* 参见 https://www.clear.rice.edu/comp512/Lectures/10Dead-Clean-SCCP.pdf
**/ **/
class DeadCode : public Pass { class DeadCode : public TransformPass {
public: public:
DeadCode(Module *m) : Pass(m), func_info(std::make_shared<FuncInfo>(m)) {}
void run(); /**
*
* @param m 所属 Module
* @param remove_unreachable_bb 是否需要删除不可达的 BasicBlocks
*/
DeadCode(Module *m, bool remove_unreachable_bb) : TransformPass(m), remove_bb_(remove_unreachable_bb), func_info(nullptr) {}
void run() override;
private: private:
std::shared_ptr<FuncInfo> func_info; bool remove_bb_;
int ins_count{0}; // 用以衡量死代码消除的性能 FuncInfo* func_info;
std::deque<Instruction *> work_list{};
std::unordered_map<Instruction *, bool> marked{}; std::unordered_map<Instruction *, bool> marked{};
std::deque<Instruction*> work_list{};
// 标记函数中不可删除指令
void mark(Function *func); void mark(Function *func);
void mark(Instruction *ins); // 标记某不可删除的指令依赖的指令
void mark(const Instruction *ins);
// 删除函数中无用指令
bool sweep(Function *func); bool sweep(Function *func);
bool clear_basic_blocks(Function *func); // 从 entry 开始对基本块进行搜索,删除不可达基本块
bool is_critical(Instruction *ins); static bool clear_basic_blocks(Function *func);
void sweep_globally(); // 指令是否有副作用
bool is_critical(Instruction *ins) const;
// 删除无用函数和全局变量
void sweep_globally() const;
}; };
#pragma once #pragma once
#include <map>
#include <set>
#include "BasicBlock.hpp" #include "BasicBlock.hpp"
#include "PassManager.hpp" #include "PassManager.hpp"
#include <map>
#include <set>
class Dominators : public Pass { /**
* 分析 Pass, 获得某函数的支配树信息
*
* 由于它是针对某一特定 Function 的分析 Pass, 你无法通过 m_ 获取 Module, 但可以通过 f_ 获取 Function
*/
class Dominators : public FunctionAnalysisPass {
public: public:
using BBSet = std::set<BasicBlock *>;
explicit Dominators(Module *m) : Pass(m) {} explicit Dominators(Function* f) : FunctionAnalysisPass(f) { assert(!f->is_declaration() && "Dominators can not apply to function declaration."); }
~Dominators() = default; ~Dominators() override = default;
void run() override; void run() override;
void run_on_func(Function *f);
// functions for getting information // 获取基本块的直接支配节点
BasicBlock *get_idom(BasicBlock *bb) { return idom_.at(bb); } BasicBlock *get_idom(const BasicBlock *bb) const { return idom_.at(bb); }
const BBSet &get_dominance_frontier(BasicBlock *bb) { const std::set<BasicBlock*> &get_dominance_frontier(const BasicBlock *bb) {
return dom_frontier_.at(bb); return dom_frontier_.at(bb);
} }
const BBSet &get_dom_tree_succ_blocks(BasicBlock *bb) { const std::set<BasicBlock*> &get_dom_tree_succ_blocks(const BasicBlock *bb) {
return dom_tree_succ_blocks_.at(bb); return dom_tree_succ_blocks_.at(bb);
} }
// print cfg or dominance tree // print cfg or dominance tree
void dump_cfg(Function *f); void dump_cfg() const;
void dump_dominator_tree(Function *f); void dump_dominator_tree();
// functions for dominance tree // functions for dominance tree
const bool is_dominate(BasicBlock *bb1, BasicBlock *bb2) { bool is_dominate(const BasicBlock *bb1, const BasicBlock *bb2) const {
return dom_tree_L_.at(bb1) <= dom_tree_L_.at(bb2) && return dom_tree_L_.at(bb1) <= dom_tree_L_.at(bb2) &&
dom_tree_R_.at(bb1) >= dom_tree_L_.at(bb2); dom_tree_R_.at(bb1) >= dom_tree_L_.at(bb2);
} }
...@@ -45,42 +49,38 @@ class Dominators : public Pass { ...@@ -45,42 +49,38 @@ class Dominators : public Pass {
private: private:
void dfs(BasicBlock *bb, std::set<BasicBlock *> &visited); void dfs(BasicBlock *bb, std::set<BasicBlock *> &visited);
void create_idom(Function *f); void create_idom();
void create_dominance_frontier(Function *f); void create_dominance_frontier();
void create_dom_tree_succ(Function *f); void create_dom_tree_succ();
void create_dom_dfs_order(Function *f); void create_dom_dfs_order();
BasicBlock * intersect(BasicBlock *b1, BasicBlock *b2); BasicBlock * intersect(BasicBlock *b1, const BasicBlock *b2) const;
void create_reverse_post_order(Function *f); void create_reverse_post_order();
void set_idom(BasicBlock *bb, BasicBlock *idom) { idom_[bb] = idom; } void set_idom(const BasicBlock *bb, BasicBlock *idom) { idom_[bb] = idom; }
void set_dominance_frontier(BasicBlock *bb, BBSet &df) { void set_dominance_frontier(const BasicBlock *bb, std::set<BasicBlock*>&df) {
dom_frontier_[bb].clear(); dom_frontier_[bb].clear();
dom_frontier_[bb].insert(df.begin(), df.end()); dom_frontier_[bb].insert(df.begin(), df.end());
} }
void add_dom_tree_succ_block(BasicBlock *bb, BasicBlock *dom_tree_succ_bb) { void add_dom_tree_succ_block(const BasicBlock *bb, BasicBlock *dom_tree_succ_bb) {
dom_tree_succ_blocks_[bb].insert(dom_tree_succ_bb); dom_tree_succ_blocks_[bb].insert(dom_tree_succ_bb);
} }
unsigned int get_post_order(BasicBlock *bb) { unsigned int get_post_order(const BasicBlock *bb) const {
return post_order_.at(bb); return post_order_.at(bb);
} }
// for debug // for debug
void print_idom(Function *f); void print_idom() const;
void print_dominance_frontier(Function *f); void print_dominance_frontier();
// TODO 补充需要的函数
std::list<BasicBlock *> reverse_post_order_{};
std::map<BasicBlock *, int> post_order_id_{}; // the root has highest ID
std::vector<BasicBlock *> post_order_vec_{}; // 逆后序 std::vector<BasicBlock *> post_order_vec_{}; // 逆后序
std::map<BasicBlock *, unsigned int> post_order_{}; // 逆后序 std::map<const BasicBlock *, unsigned int> post_order_{}; // 逆后序
std::map<BasicBlock *, BasicBlock *> idom_{}; // 直接支配 std::map<const BasicBlock *, BasicBlock *> idom_{}; // 直接支配
std::map<BasicBlock *, BBSet> dom_frontier_{}; // 支配边界集合 std::map<const BasicBlock *, std::set<BasicBlock*>> dom_frontier_{}; // 支配边界集合
std::map<BasicBlock *, BBSet> dom_tree_succ_blocks_{}; // 支配树中的后继节点 std::map<const BasicBlock *, std::set<BasicBlock*>> dom_tree_succ_blocks_{}; // 支配树中的后继节点
// 支配树上的dfs序L,R // 支配树上的dfs序L,R
std::map<BasicBlock *, unsigned int> dom_tree_L_; std::map<const BasicBlock *, unsigned int> dom_tree_L_;
std::map<BasicBlock *, unsigned int> dom_tree_R_; std::map<const BasicBlock *, unsigned int> dom_tree_R_;
std::vector<BasicBlock *> dom_dfs_order_; std::vector<BasicBlock *> dom_dfs_order_;
std::vector<BasicBlock *> dom_post_order_; std::vector<BasicBlock *> dom_post_order_;
......
#pragma once #pragma once
#include "PassManager.hpp"
#include "logging.hpp"
#include <deque>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "PassManager.hpp"
/** /**
* 计算哪些函数是纯函数 * 分析函数的信息,包括哪些函数是纯函数,每个函数存储的变量
* WARN:
* 假定所有函数都是纯函数,除非他写入了全局变量、修改了传入的数组、或者直接间接调用了非纯函数
*/ */
class FuncInfo : public Pass { class FuncInfo : public ModuleAnalysisPass {
// 非纯函数的 load / store 信息
struct UseMessage
{
// 影响的全局变量(注意此处不包含常全局变量,目前的文法也不支持常全局变量))
std::unordered_set<GlobalVariable*> globals_;
// 影响的参数(第一个参数序号为 0)
std::unordered_set<Argument*> arguments_;
void add(Value* val);
bool have(Value* val) const;
bool empty() const;
};
public: public:
FuncInfo(Module *m) : Pass(m) {} FuncInfo(Module *m) : ModuleAnalysisPass(m) {}
void run(); void run() override;
bool is_pure_function(Function *func) const { return is_pure.at(func); } // 函数是否是纯函数
bool is_pure(Function *func) { return !func->is_declaration() && !use_libs.count(func) && loads[func].empty() && stores[func].empty(); }
// 返回 StoreInst 存入的变量(全局/局部变量或函数参数)
static Value* store_ptr(const StoreInst* st);
// 返回 LoadInst 加载的变量(全局/局部变量或函数参数)
static Value* load_ptr(const LoadInst* ld);
// 返回 CallInst 代表的函数调用间接存入的变量(全局/局部变量或函数参数)
std::unordered_set<Value*> get_stores(const CallInst* call);
private: private:
std::deque<Function *> worklist; // 函数存储的值
std::unordered_map<Function *, bool> is_pure; std::unordered_map<Function*, UseMessage> stores;
// 函数加载的值
void trivial_mark(Function *func); std::unordered_map<Function*, UseMessage> loads;
void process(Function *func); // 函数是否因为调用库函数而变得非纯函数
Value *get_first_addr(Value *val); std::unordered_map<Function*, bool> use_libs;
bool is_side_effect_inst(Instruction *inst); // 将所有由变量 var 计算出的指针的来源都设置为变量 var, 并记录在函数内直接对 var 的 load/store
bool is_local_load(LoadInst *inst); void cal_val_2_var(Value* var, std::unordered_map<Value*, Value*>& val_2_var);
bool is_local_store(StoreInst *inst); static Value* trace_ptr(Value* val);
void log(); void log() const;
}; };
#pragma once
#include "FuncInfo.hpp" #include "FuncInfo.hpp"
#include "LoopDetection.hpp" #include "LoopDetection.hpp"
#include "PassManager.hpp" #include "PassManager.hpp"
#include <memory>
#include <unordered_map>
class LoopInvariantCodeMotion : public Pass { class LoopInvariantCodeMotion : public TransformPass {
public: public:
LoopInvariantCodeMotion(Module *m) : Pass(m) {} LoopInvariantCodeMotion(Module *m) : TransformPass(m), loop_detection_(nullptr), func_info_(nullptr) {}
~LoopInvariantCodeMotion() = default; ~LoopInvariantCodeMotion() override = default;
void run() override; void run() override;
private: private:
std::unordered_map<std::shared_ptr<Loop>, bool> is_loop_done_; LoopDetection* loop_detection_;
std::unique_ptr<LoopDetection> loop_detection_; FuncInfo* func_info_;
std::unique_ptr<FuncInfo> func_info_; std::unordered_set<Value*> collect_loop_store_vars(Loop* loop);
void traverse_loop(std::shared_ptr<Loop> loop); std::vector<Instruction*> collect_insts(Loop* loop);
void run_on_loop(std::shared_ptr<Loop> loop); void traverse_loop(Loop* loop);
void collect_loop_info(std::shared_ptr<Loop> loop, void run_on_loop(Loop* loop);
void collect_loop_info(Loop* loop,
std::set<Value *> &loop_instructions, std::set<Value *> &loop_instructions,
std::set<Value *> &updated_global, std::set<Value *> &updated_global,
bool &contains_impure_call); bool &contains_impure_call);
......
#pragma once #pragma once
#include "Dominators.hpp"
#include "PassManager.hpp"
#include <memory>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "Dominators.hpp"
#include "PassManager.hpp"
class BasicBlock; class BasicBlock;
class Dominators; class Dominators;
class Function; class Function;
class Module; class Module;
using BBset = std::set<BasicBlock *>;
using BBvec = std::vector<BasicBlock *>;
class Loop { class Loop {
private: private:
// attribute: // attribute:
...@@ -20,9 +19,9 @@ class Loop { ...@@ -20,9 +19,9 @@ class Loop {
BasicBlock *preheader_ = nullptr; BasicBlock *preheader_ = nullptr;
BasicBlock *header_; BasicBlock *header_;
std::shared_ptr<Loop> parent_ = nullptr; Loop* parent_ = nullptr;
BBvec blocks_; std::vector<BasicBlock*> blocks_;
std::vector<std::shared_ptr<Loop>> sub_loops_; std::vector<Loop*> sub_loops_;
std::unordered_set<BasicBlock *> latches_; std::unordered_set<BasicBlock *> latches_;
public: public:
...@@ -31,34 +30,31 @@ class Loop { ...@@ -31,34 +30,31 @@ class Loop {
} }
~Loop() = default; ~Loop() = default;
void add_block(BasicBlock *bb) { blocks_.push_back(bb); } void add_block(BasicBlock *bb) { blocks_.push_back(bb); }
BasicBlock *get_header() { return header_; } BasicBlock *get_header() const { return header_; }
BasicBlock *get_preheader() { return preheader_; } BasicBlock *get_preheader() const { return preheader_; }
std::shared_ptr<Loop> get_parent() { return parent_; } Loop* get_parent() const { return parent_; }
void set_parent(std::shared_ptr<Loop> parent) { parent_ = parent; } void set_parent(Loop* parent) { parent_ = parent; }
void set_preheader(BasicBlock *bb) { preheader_ = bb; } void set_preheader(BasicBlock *bb) { preheader_ = bb; }
void add_sub_loop(std::shared_ptr<Loop> loop) { sub_loops_.push_back(loop); } void add_sub_loop(Loop* loop) { sub_loops_.push_back(loop); }
const BBvec& get_blocks() { return blocks_; } const std::vector<BasicBlock*>& get_blocks() { return blocks_; }
const std::vector<std::shared_ptr<Loop>>& get_sub_loops() { return sub_loops_; } const std::vector<Loop*>& get_sub_loops() { return sub_loops_; }
const std::unordered_set<BasicBlock *>& get_latches() { return latches_; } const std::unordered_set<BasicBlock *>& get_latches() { return latches_; }
void add_latch(BasicBlock *bb) { latches_.insert(bb); } void add_latch(BasicBlock *bb) { latches_.insert(bb); }
}; };
class LoopDetection : public Pass { class LoopDetection : public FunctionAnalysisPass {
private: Dominators* dominators_;
Function *func_; std::vector<Loop*> loops_;
std::unique_ptr<Dominators> dominators_;
std::vector<std::shared_ptr<Loop>> loops_;
// map from header to loop // map from header to loop
std::unordered_map<BasicBlock *, std::shared_ptr<Loop>> bb_to_loop_; std::unordered_map<BasicBlock *, Loop*> bb_to_loop_;
void discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches, void discover_loop_and_sub_loops(BasicBlock *bb, std::set<BasicBlock*>&latches,
std::shared_ptr<Loop> loop); Loop* loop);
public: public:
LoopDetection(Module *m) : Pass(m) {} LoopDetection(Function *f) : FunctionAnalysisPass(f), dominators_(nullptr) { assert(!f->is_declaration() && "LoopDetection can not apply to function declaration." ); }
~LoopDetection() = default; ~LoopDetection() override;
void run() override; void run() override;
void run_on_func(Function *f); void print() const;
void print() ; std::vector<Loop*> &get_loops() { return loops_; }
std::vector<std::shared_ptr<Loop>> &get_loops() { return loops_; }
}; };
#pragma once #pragma once
#include <map>
#include "Dominators.hpp" #include "Dominators.hpp"
#include "Instruction.hpp" #include "Instruction.hpp"
#include "Value.hpp" #include "Value.hpp"
#include <map> class Mem2Reg : public TransformPass {
#include <memory>
class Mem2Reg : public Pass {
private: private:
// 当前函数
Function *func_; Function *func_;
std::unique_ptr<Dominators> dominators_; // 当前函数对应的支配树
std::map<Value *, Value *> phi_map; Dominators* dominators_;
// TODO 添加需要的变量 // TODO 添加需要的变量
// 所有需要处理的变量
std::list<AllocaInst*> allocas_;
// 变量定值栈 // 变量定值栈
std::map<Value *, std::vector<Value *>> var_val_stack; std::map<AllocaInst*, std::vector<Value *>> var_val_stack;
// phi指令对应的左值(地址) // Phi 对应的局部变量
std::map<PhiInst *, Value *> phi_lval; std::map<PhiInst *, AllocaInst*> phi_to_alloca_;
// 在某个基本块的 Phi
std::map<BasicBlock*, std::list<PhiInst*>> bb_to_phi_;
public: public:
Mem2Reg(Module *m) : Pass(m) {} Mem2Reg(Module *m) : TransformPass(m), func_(nullptr), dominators_(nullptr) {}
~Mem2Reg() = default; ~Mem2Reg() override = default;
void run() override; void run() override;
void generate_phi(); void generate_phi();
void rename(BasicBlock *bb); void rename(BasicBlock *bb);
static inline bool is_global_variable(Value *l_val) {
return dynamic_cast<GlobalVariable *>(l_val) != nullptr;
}
static inline bool is_gep_instr(Value *l_val) {
return dynamic_cast<GetElementPtrInst *>(l_val) != nullptr;
}
static inline bool is_valid_ptr(Value *l_val) {
return not is_global_variable(l_val) and not is_gep_instr(l_val);
}
}; };
#pragma once #pragma once
#include "Module.hpp"
#include <memory> #include <memory>
#include <vector> #include <vector>
class Pass { #include "Module.hpp"
// 转换 Pass, 例如 mem2reg, licm, deadcode
class TransformPass {
public:
TransformPass(Module* m) : m_(m) {}
virtual ~TransformPass();
virtual void run() = 0;
protected:
Module* m_;
};
// 依赖于整个 Module 进行分析的分析 Pass, 例如 funcinfo
class ModuleAnalysisPass {
public: public:
Pass(Module *m) : m_(m) {} ModuleAnalysisPass(Module *m) : m_(m) {}
virtual ~Pass() = default; virtual ~ModuleAnalysisPass();
virtual void run() = 0; virtual void run() = 0;
protected: protected:
Module *m_; Module *m_;
}; };
// 依赖于单个 Function 进行分析的分析 Pass, 例如 dominators, loopdetection
class FunctionAnalysisPass {
public:
FunctionAnalysisPass(Function* f) : f_(f) {}
virtual ~FunctionAnalysisPass();
virtual void run() = 0;
protected:
Function* f_;
};
class PassManager { class PassManager {
public: public:
PassManager(Module *m) : m_(m) {} PassManager(Module *m) : m_(m) {}
// 添加一个 Transform Pass, 添加的 Pass 被顺序运行
template <typename PassType, typename... Args> template <typename PassType, typename... Args>
void add_pass(Args &&...args) { void add_pass(Args &&...args) {
static_assert(std::is_base_of_v<TransformPass, PassType>, "Pass must derive from TransformPass");
passes_.emplace_back(new PassType(m_, std::forward<Args>(args)...)); passes_.emplace_back(new PassType(m_, std::forward<Args>(args)...));
} }
void run() { void run() {
for (auto &pass : passes_) { for (auto& pass : passes_) {
pass->run(); pass->run();
delete pass;
pass = nullptr;
} }
} }
private: private:
std::vector<std::unique_ptr<Pass>> passes_; // 它们会被顺序运行
std::vector<TransformPass*> passes_;
Module *m_; Module *m_;
}; };
...@@ -64,13 +64,12 @@ int main(int argc, char **argv) { ...@@ -64,13 +64,12 @@ int main(int argc, char **argv) {
PassManager PM(m); PassManager PM(m);
// optimization // optimization
if(config.mem2reg) { if(config.mem2reg) {
PM.add_pass<DeadCode>();
PM.add_pass<Mem2Reg>(); PM.add_pass<Mem2Reg>();
PM.add_pass<DeadCode>(); PM.add_pass<DeadCode>(false);
} }
if(config.licm) { if(config.licm) {
PM.add_pass<LoopInvariantCodeMotion>(); PM.add_pass<LoopInvariantCodeMotion>();
PM.add_pass<DeadCode>(); PM.add_pass<DeadCode>(false);
} }
PM.run(); PM.run();
......
...@@ -77,6 +77,27 @@ void BasicBlock::add_instr_before_terminator(Instruction* instr) { ...@@ -77,6 +77,27 @@ void BasicBlock::add_instr_before_terminator(Instruction* instr) {
else instr_list_.insert(std::prev(instr_list_.end()), instr); else instr_list_.insert(std::prev(instr_list_.end()), instr);
} }
void BasicBlock::erase_instrs(const std::set<Instruction*>& instr)
{
std::list<Instruction*> ok;
for (auto i : instr_list_)
{
if (!instr.count(i)) ok.emplace_back(i);
}
instr_list_ = std::move(ok);
for (auto i : instr) delete i;
}
void BasicBlock::remove_instrs(const std::set<Instruction*>& instr)
{
std::list<Instruction*> ok;
for (auto i : instr_list_)
{
if (!instr.count(i)) ok.emplace_back(i);
}
instr_list_ = std::move(ok);
}
std::string BasicBlock::print() { std::string BasicBlock::print() {
std::string bb_ir; std::string bb_ir;
bb_ir += this->get_name(); bb_ir += this->get_name();
......
...@@ -121,7 +121,7 @@ std::string Function::print() { ...@@ -121,7 +121,7 @@ std::string Function::print() {
} }
else { else {
for (auto arg : get_args()) { for (auto arg : get_args()) {
if (&arg != &*get_args().begin()) if (arg != get_args().front())
func_ir += ", "; func_ir += ", ";
func_ir += arg->print(); func_ir += arg->print();
} }
......
...@@ -368,4 +368,15 @@ PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb, ...@@ -368,4 +368,15 @@ PhiInst *PhiInst::create_phi(Type *ty, BasicBlock *bb,
return new PhiInst(ty, vals, val_bbs, bb, name); return new PhiInst(ty, vals, val_bbs, bb, name);
} }
std::vector<std::pair<Value*, BasicBlock*>> PhiInst::get_phi_pairs() const
{
std::vector<std::pair<Value*, BasicBlock*>> res;
int ops = static_cast<int>(get_num_operand());
for (int i = 0; i < ops; i += 2) {
auto bb = dynamic_cast<BasicBlock*>(this->get_operand(i + 1));
res.emplace_back(this->get_operand(i), bb);
}
return res;
}
Names GLOBAL_INSTRUCTION_NAMES_{"op", "_" }; Names GLOBAL_INSTRUCTION_NAMES_{"op", "_" };
\ No newline at end of file
...@@ -16,7 +16,7 @@ void User::set_operand(unsigned i, Value *v) { ...@@ -16,7 +16,7 @@ void User::set_operand(unsigned i, Value *v) {
} }
void User::add_operand(Value *v) { void User::add_operand(Value *v) {
assert(v != nullptr && "bad use: add_operand(nullptr)"); if (v == nullptr) return;
v->add_use(this, static_cast<unsigned>(operands_.size())); v->add_use(this, static_cast<unsigned>(operands_.size()));
operands_.push_back(v); operands_.push_back(v);
} }
...@@ -34,10 +34,14 @@ void User::remove_operand(unsigned idx) { ...@@ -34,10 +34,14 @@ void User::remove_operand(unsigned idx) {
assert(idx < operands_.size() && "remove_operand out of index"); assert(idx < operands_.size() && "remove_operand out of index");
// influence on other operands // influence on other operands
for (unsigned i = idx + 1; i < operands_.size(); ++i) { for (unsigned i = idx + 1; i < operands_.size(); ++i) {
if (operands_[i])
{
operands_[i]->remove_use(this, i); operands_[i]->remove_use(this, i);
operands_[i]->add_use(this, i - 1); operands_[i]->add_use(this, i - 1);
} }
}
// remove the designated operand // remove the designated operand
if (operands_[idx])
operands_[idx]->remove_use(this, idx); operands_[idx]->remove_use(this, idx);
operands_.erase(operands_.begin() + idx); operands_.erase(operands_.begin() + idx);
} }
...@@ -14,10 +14,12 @@ bool Value::set_name(const std::string& name) { ...@@ -14,10 +14,12 @@ bool Value::set_name(const std::string& name) {
} }
void Value::add_use(User *user, unsigned arg_no) { void Value::add_use(User *user, unsigned arg_no) {
if (user == nullptr) return;
use_list_.emplace_back(user, arg_no); use_list_.emplace_back(user, arg_no);
}; };
void Value::remove_use(User *user, unsigned arg_no) { void Value::remove_use(User *user, unsigned arg_no) {
if (user == nullptr) return;
auto target_use = Use(user, arg_no); auto target_use = Use(user, arg_no);
use_list_.remove_if([&](const Use &use) { return use == target_use; }); use_list_.remove_if([&](const Use &use) { return use == target_use; });
} }
...@@ -28,7 +30,8 @@ void Value::replace_all_use_with(Value *new_val) const ...@@ -28,7 +30,8 @@ void Value::replace_all_use_with(Value *new_val) const
return; return;
while (!use_list_.empty()) { while (!use_list_.empty()) {
auto use = use_list_.begin(); auto use = use_list_.begin();
use->val_->set_operand(use->arg_no_, new_val); auto val = use->val_;
if (val != nullptr) val->set_operand(use->arg_no_, new_val);
} }
} }
...@@ -38,6 +41,7 @@ void Value::replace_use_with_if(Value *new_val, ...@@ -38,6 +41,7 @@ void Value::replace_use_with_if(Value *new_val,
return; return;
for (auto iter = use_list_.begin(); iter != use_list_.end();) { for (auto iter = use_list_.begin(); iter != use_list_.end();) {
auto &use = *iter++; auto &use = *iter++;
if (use.val_ == nullptr) continue;
if (not should_replace(&use)) if (not should_replace(&use))
continue; continue;
use.val_->set_operand(use.arg_no_, new_val); use.val_->set_operand(use.arg_no_, new_val);
......
...@@ -6,4 +6,4 @@ add_library( ...@@ -6,4 +6,4 @@ add_library(
LoopDetection.cpp LoopDetection.cpp
LICM.cpp LICM.cpp
Mem2Reg.cpp Mem2Reg.cpp
) PassManager.cpp)
\ No newline at end of file \ No newline at end of file
#include "DeadCode.hpp" #include "DeadCode.hpp"
#include "logging.hpp"
#include <queue>
#include <unordered_set>
#include <vector> #include <vector>
#include "FuncInfo.hpp"
#include "logging.hpp"
// 处理流程:两趟处理,mark 标记有用变量,sweep 删除无用指令 // 处理流程:两趟处理,mark 标记有用变量,sweep 删除无用指令
void DeadCode::run() { void DeadCode::run() {
bool changed{}; bool changed;
func_info = new FuncInfo(m_);
func_info->run(); func_info->run();
do { do {
changed = false; changed = false;
for (auto func : m_->get_functions()) { for (auto func : m_->get_functions()) {
changed |= clear_basic_blocks(func); if (remove_bb_) changed |= clear_basic_blocks(func);
mark(func); mark(func);
changed |= sweep(func); changed |= sweep(func);
} }
} while (changed); } while (changed);
LOG_INFO << "dead code pass erased " << ins_count << " instructions"; delete func_info;
func_info = nullptr;
}
static void remove_phi_operand_if_in(PhiInst* inst, const std::unordered_set<BasicBlock*>& in)
{
int opc = static_cast<int>(inst->get_num_operand());
for (int i = opc - 1; i >= 0; i -= 2)
{
auto bb = dynamic_cast<BasicBlock*>(inst->get_operand(i));
if (in.count(bb))
{
inst->remove_operand(i);
inst->remove_operand(i - 1);
}
}
} }
bool DeadCode::clear_basic_blocks(Function *func) { bool DeadCode::clear_basic_blocks(Function *func) {
bool changed = 0; // 已经访问的基本块
std::vector<BasicBlock *> to_erase; std::unordered_set<BasicBlock*> visited;
for (auto bb : func->get_basic_blocks()) { // 还未访问的基本块
if(bb->get_pre_basic_blocks().empty() && bb != func->get_entry_block()) { std::queue<BasicBlock*> toVisit;
to_erase.push_back(bb); toVisit.emplace(func->get_entry_block());
changed = 1; visited.emplace(func->get_entry_block());
while (!toVisit.empty())
{
auto bb = toVisit.front();
toVisit.pop();
for (auto suc : bb->get_succ_basic_blocks())
{
if (!visited.count(suc))
{
visited.emplace(suc);
toVisit.emplace(suc);
} }
} }
for (auto bb : to_erase) { }
// 遍历后剩余的基本块不可达
std::unordered_set<BasicBlock*> erase_set;
std::list<BasicBlock*> erase_list;
for (auto bb : func->get_basic_blocks())
{
if (!visited.count(bb))
{
erase_set.emplace(bb);
erase_list.emplace_back(bb);
}
}
// 删除可达基本块中对不可达基本块中变量的 phi 引用
// 例如 A -> B, B 中具有 phi [val , A], 则要将这一对删除
// 当一个 phi 因为这个原因只剩下一对操作数时, 它可以被消除, 不过这里不管它
for (auto i : func->get_basic_blocks())
{
if (erase_set.count(i)) continue;
for (auto j : i->get_instructions())
{
auto phi = dynamic_cast<PhiInst*>(j);
if (phi == nullptr) break; // 假定所有 phi 都在基本块指令的最前面,见 https://ustc-compiler-2025.github.io/homepage/exp_platform_intro/TA/#%E4%BD%BF%E7%94%A8%E4%B8%A4%E6%AE%B5%E5%8C%96-instruction-list
remove_phi_operand_if_in(phi, erase_set);
}
}
for (auto bb : erase_list) {
bb->erase_from_parent(); bb->erase_from_parent();
delete bb; delete bb;
} }
return changed; return !erase_list.empty();
} }
void DeadCode::mark(Function *func) { void DeadCode::mark(Function *func) {
...@@ -54,7 +113,7 @@ void DeadCode::mark(Function *func) { ...@@ -54,7 +113,7 @@ void DeadCode::mark(Function *func) {
} }
} }
void DeadCode::mark(Instruction *ins) { void DeadCode::mark(const Instruction *ins) {
for (auto op : ins->get_operands()) { for (auto op : ins->get_operands()) {
auto def = dynamic_cast<Instruction *>(op); auto def = dynamic_cast<Instruction *>(op);
if (def == nullptr) if (def == nullptr)
...@@ -69,33 +128,26 @@ void DeadCode::mark(Instruction *ins) { ...@@ -69,33 +128,26 @@ void DeadCode::mark(Instruction *ins) {
} }
bool DeadCode::sweep(Function *func) { bool DeadCode::sweep(Function *func) {
std::unordered_set<Instruction *> wait_del{}; bool rm = false; // changed
std::unordered_set<Instruction *> wait_del;
for (auto bb : func->get_basic_blocks()) { for (auto bb : func->get_basic_blocks()) {
for (auto it = bb->get_instructions().begin(); for (auto inst : bb->get_instructions()) {
it != bb->get_instructions().end();) { if (marked[inst]) continue;
if (marked[*it]) { wait_del.emplace(inst);
++it;
continue;
} else {
wait_del.insert(*it);
it++;
}
} }
bb->get_instructions().remove_if([&wait_del](Instruction* i) -> bool {return wait_del.count(i); });
if (!wait_del.empty()) rm = true;
wait_del.clear();
} }
for (auto inst : wait_del) return rm;
inst->remove_all_operands();
for (auto inst : wait_del)
inst->get_parent()->erase_instr(inst);
ins_count += wait_del.size();
return not wait_del.empty(); // changed
} }
bool DeadCode::is_critical(Instruction *ins) { bool DeadCode::is_critical(Instruction *ins) const {
// 对纯函数的无用调用也可以在删除之列 // 对纯函数的无用调用也可以在删除之列
if (ins->is_call()) { if (ins->is_call()) {
auto call_inst = dynamic_cast<CallInst *>(ins); auto call_inst = dynamic_cast<CallInst *>(ins);
auto callee = dynamic_cast<Function *>(call_inst->get_operand(0)); auto callee = dynamic_cast<Function *>(call_inst->get_operand(0));
if (func_info->is_pure_function(callee)) if (func_info->is_pure(callee))
return false; return false;
return true; return true;
} }
...@@ -106,15 +158,15 @@ bool DeadCode::is_critical(Instruction *ins) { ...@@ -106,15 +158,15 @@ bool DeadCode::is_critical(Instruction *ins) {
return false; return false;
} }
void DeadCode::sweep_globally() { void DeadCode::sweep_globally() const {
std::vector<Function *> unused_funcs; std::vector<Function *> unused_funcs;
std::vector<GlobalVariable *> unused_globals; std::vector<GlobalVariable *> unused_globals;
for (auto f_r : m_->get_functions()) { for (auto f_r : m_->get_functions()) {
if (f_r->get_use_list().size() == 0 and f_r->get_name() != "main") if (f_r->get_use_list().empty() and f_r->get_name() != "main")
unused_funcs.push_back(f_r); unused_funcs.push_back(f_r);
} }
for (auto glob_var_r : m_->get_global_variable()) { for (auto glob_var_r : m_->get_global_variable()) {
if (glob_var_r->get_use_list().size() == 0) if (glob_var_r->get_use_list().empty())
unused_globals.push_back(glob_var_r); unused_globals.push_back(glob_var_r);
} }
// changed |= unused_funcs.size() or unused_globals.size(); // changed |= unused_funcs.size() or unused_globals.size();
......
This diff is collapsed.
This diff is collapsed.
#include "LICM.hpp"
#include <memory>
#include <vector>
#include "BasicBlock.hpp" #include "BasicBlock.hpp"
#include "Constant.hpp"
#include "Function.hpp" #include "Function.hpp"
#include "GlobalVariable.hpp"
#include "Instruction.hpp" #include "Instruction.hpp"
#include "LICM.hpp"
#include "PassManager.hpp" #include "PassManager.hpp"
#include <cstddef>
#include <memory>
#include <vector>
/** /**
* @brief 循环不变式外提Pass的主入口函数 * @brief 循环不变式外提Pass的主入口函数
* *
*/ */
void LoopInvariantCodeMotion::run() { void LoopInvariantCodeMotion::run()
{
loop_detection_ = std::make_unique<LoopDetection>(m_); func_info_ = new FuncInfo(m_);
loop_detection_->run();
func_info_ = std::make_unique<FuncInfo>(m_);
func_info_->run(); func_info_->run();
for (auto &loop : loop_detection_->get_loops()) { for (auto func : m_->get_functions())
is_loop_done_[loop] = false; {
if (func->is_declaration()) continue;
loop_detection_ = new LoopDetection(func);
loop_detection_->run();
for (auto loop : loop_detection_->get_loops())
{
// 遍历处理顶层循环
if (loop->get_parent() == nullptr) traverse_loop(loop);
} }
delete loop_detection_;
for (auto &loop : loop_detection_->get_loops()) { loop_detection_ = nullptr;
traverse_loop(loop);
} }
delete func_info_;
func_info_ = nullptr;
} }
/** /**
...@@ -33,102 +38,147 @@ void LoopInvariantCodeMotion::run() { ...@@ -33,102 +38,147 @@ void LoopInvariantCodeMotion::run() {
* @param loop 当前要处理的循环 * @param loop 当前要处理的循环
* *
*/ */
void LoopInvariantCodeMotion::traverse_loop(std::shared_ptr<Loop> loop) { void LoopInvariantCodeMotion::traverse_loop(Loop* loop)
if (is_loop_done_[loop]) { {
return; // 先外层再内层,这样不用在插入 preheader 后更改循环
} run_on_loop(loop);
is_loop_done_[loop] = true; for (auto sub_loop : loop->get_sub_loops())
for (auto &sub_loop : loop->get_sub_loops()) { {
traverse_loop(sub_loop); traverse_loop(sub_loop);
} }
run_on_loop(loop);
} }
// TODO: 收集并返回循环 store 过的变量
// 例如
// %a = alloca ...
// %b = getelementptr %a ...
// store ... %b
// 则你应该返回 %a 而非 %b
std::unordered_set<Value*> LoopInvariantCodeMotion::collect_loop_store_vars(Loop* loop)
{
// 可能用到
// FuncInfo::store_ptr, FuncInfo::get_stores
throw std::runtime_error("Lab4: 你有一个TODO需要完成!");
}
// TODO: 收集并返回循环中的所有指令
std::vector<Instruction*> LoopInvariantCodeMotion::collect_insts(Loop* loop)
{
throw std::runtime_error("Lab4: 你有一个TODO需要完成!");
}
// TODO: 实现collect_loop_info函数 // TODO: 实现collect_loop_info函数
// 1. 遍历当前循环及其子循环的所有指令 // 1. 遍历当前循环及其子循环的所有指令
// 2. 收集所有指令到loop_instructions中 // 2. 收集所有指令到loop_instructions中
// 3. 检查store指令是否修改了全局变量,如果是则添加到updated_global中 // 3. 检查store指令是否修改了全局变量,如果是则添加到updated_global中
// 4. 检查是否包含非纯函数调用,如果有则设置contains_impure_call为true // 4. 检查是否包含非纯函数调用,如果有则设置contains_impure_call为true
void LoopInvariantCodeMotion::collect_loop_info( void LoopInvariantCodeMotion::collect_loop_info(
std::shared_ptr<Loop> loop, Loop* loop,
std::set<Value *> &loop_instructions, std::set<Value*>& loop_instructions,
std::set<Value *> &updated_global, std::set<Value*>& updated_global,
bool &contains_impure_call) { bool& contains_impure_call)
{
throw std::runtime_error("Lab4: 你有一个TODO需要完成!"); throw std::runtime_error("Lab4: 你有一个TODO需要完成!");
} }
enum InstructionType: std::uint8_t
{
UNKNOWN, VARIANT, INVARIANT
};
/** /**
* @brief 对单个循环执行不变式外提优化 * @brief 对单个循环执行不变式外提优化
* @param loop 要优化的循环 * @param loop 要优化的循环
* *
*/ */
void LoopInvariantCodeMotion::run_on_loop(std::shared_ptr<Loop> loop) { void LoopInvariantCodeMotion::run_on_loop(Loop* loop)
std::set<Value *> loop_instructions; {
std::set<Value *> updated_global; // 循环 store 过的变量
bool contains_impure_call = false; std::unordered_set<Value*> loop_stores_var = collect_loop_store_vars(loop);
collect_loop_info(loop, loop_instructions, updated_global, contains_impure_call); // 循环中的所有指令
std::vector<Instruction*> instructions = collect_insts(loop);
std::vector<Value *> loop_invariant; int insts_count = static_cast<int>(instructions.size());
// 循环的所有基本块
std::unordered_set<BasicBlock*> bbs;
for (auto i : loop->get_blocks()) bbs.emplace(i);
// val 是否在循环内定义,可以当成函数进行调用
auto is_val_in_loop = [&bbs](Value* val)->bool
{
auto inst = dynamic_cast<Instruction*>(val);
if (inst == nullptr) return true;
return bbs.count(inst->get_parent());
};
// inst_type[i] 代表 instructions[i] 是循环变量(每次循环都会变)/ 循环不变量 还是 不知道
std::vector<InstructionType> inst_type;
inst_type.resize(insts_count);
// 遍历后是不是还有指令不知道 InstructionType
bool have_inst_can_not_decide;
// 是否存在 invariant
bool have_invariant = false;
do
{
have_inst_can_not_decide = false;
for (int i = 0; i < insts_count; i++)
{
Instruction* inst = instructions[i];
InstructionType type = inst_type[i];
if (type != UNKNOWN) continue;
// 可能有用的函数
// FuncInfo::load_ptr
// TODO: 识别循环不变式指令 // TODO: 识别循环不变式指令
// // - 将 store、ret、br、phi 等指令与非纯函数调用标记为 VARIANT
// - 如果指令已被标记为不变式则跳过 // - 如果 load 指令加载的变量是循环 store 过的变量,标记为 VARIANT
// - 跳过 store、ret、br、phi 等指令与非纯函数调用 // - 如果指令有 VARIANT 操作数,标记为 VARIANT
// - 特殊处理全局变量的 load 指令 // - 如果指令所有操作数都是 INVARIANT (或者不在循环内),标记为 INVARIANT, 设置 have_invariant
// - 检查指令的所有操作数是否都是循环不变的 // - 否则设置 have_inst_can_not_decide
// - 如果有新的不变式被添加则注意更新 changed 标志,继续迭代
bool changed;
do {
changed = false;
throw std::runtime_error("Lab4: 你有一个TODO需要完成!"); throw std::runtime_error("Lab4: 你有一个TODO需要完成!");
} while (changed);
if (loop->get_preheader() == nullptr) {
loop->set_preheader(
BasicBlock::create(m_, "", loop->get_header()->get_parent()));
} }
}
while (have_inst_can_not_decide);
if (loop_invariant.empty()) if (!have_invariant) return;
return;
// insert preheader auto header = loop->get_header();
auto preheader = loop->get_preheader();
// TODO: 更新 phi 指令 if (header->get_pre_basic_blocks().size() > 1 || header->get_pre_basic_blocks().front()->get_succ_basic_blocks().size() > 1)
for (auto phi_inst_ : loop->get_header()->get_instructions()) { {
if (phi_inst_->get_instr_type() != Instruction::phi) // 插入 preheader
break; auto bb = BasicBlock::create(m_, "", loop->get_header()->get_parent());
loop->set_preheader(bb);
for (auto phi : loop->get_header()->get_instructions())
{
if (phi->get_instr_type() != Instruction::phi) break;
// TODO: 分裂 phi 指令
throw std::runtime_error("Lab4: 你有一个TODO需要完成!"); throw std::runtime_error("Lab4: 你有一个TODO需要完成!");
} }
// TODO: 用跳转指令重构控制流图 // TODO: 维护 bb, header, 与 header 前驱块的基本块关系
// 将所有非 latch 的 header 前驱块的跳转指向 preheader
// 并将 preheader 的跳转指向 header
// 注意这里需要更新前驱块的后继和后继的前驱
std::vector<BasicBlock *> pred_to_remove;
for (auto pred : loop->get_header()->get_pre_basic_blocks()) {
throw std::runtime_error("Lab4: 你有一个TODO需要完成!"); throw std::runtime_error("Lab4: 你有一个TODO需要完成!");
}
for (auto pred : pred_to_remove) { bb->add_instruction(BranchInst::create_br(header, bb));
loop->get_header()->remove_pre_basic_block(pred);
// 若你想维护 LoopDetection 在 LICM 后保持正确
// auto loop2 = loop->get_parent();
// while (loop2 != nullptr)
// {
// loop2->get_parent()->add_block(bb);
// loop2 = loop2->get_parent();
// }
} }
else loop->set_preheader(header->get_pre_basic_blocks().front());
// insert preheader
auto preheader = loop->get_preheader();
auto terminator = preheader->get_instructions().back();
preheader->get_instructions().pop_back();
// TODO: 外提循环不变指令 // TODO: 外提循环不变指令
throw std::runtime_error("Lab4: 你有一个TODO需要完成!"); throw std::runtime_error("Lab4: 你有一个TODO需要完成!");
// insert preheader br to header preheader->add_instruction(terminator);
BranchInst::create_br(loop->get_header(), preheader);
// insert preheader to parent loop
if (loop->get_parent() != nullptr) {
loop->get_parent()->add_block(preheader);
}
} }
#include "LoopDetection.hpp" #include "LoopDetection.hpp"
#include "Dominators.hpp" #include "Dominators.hpp"
#include <memory>
/** using std::set;
* @brief 循环检测Pass的主入口函数 using std::vector;
using std::map;
LoopDetection::~LoopDetection()
{
for (auto loop : loops_) delete loop;
}
/**
* @brief 对单个函数执行循环检测
* *
* 该函数执行以下步骤 * 该函数通过以下步骤检测循环
* 1. 创建支配树分析实例 * 1. 创建支配树分析实例
* 2. 遍历模块中的所有函数 * 2. 运行支配树分析
* 3. 对每个非声明函数执行循环检测 * 3. 按支配树后序遍历所有基本块
* 4. 最后打印检测结果 * 4. 对每个块,检查其前驱是否存在回边
* 5. 如果存在回边,创建新的循环并:
* - 设置循环header
* - 添加latch节点
* - 发现循环体和子循环
* 6. 最后打印检测结果
*/ */
void LoopDetection::run() { void LoopDetection::run() {
dominators_ = std::make_unique<Dominators>(m_); dominators_ = new Dominators(f_);
for (auto f : m_->get_functions()) { dominators_->run();
if (f->is_declaration()) for (auto bb : dominators_->get_dom_post_order()) {
std::set<BasicBlock*> latches;
for (auto pred : bb->get_pre_basic_blocks()) {
if (dominators_->is_dominate(bb, pred)) {
// pred is a back edge
// pred -> bb , pred is the latch node
latches.insert(pred);
}
}
if (latches.empty()) {
continue; continue;
func_ = f; }
run_on_func(f); // create loop
auto loop = new Loop(bb);
bb_to_loop_[bb] = loop;
// add latch nodes
for (auto latch : latches) {
loop->add_latch(latch);
}
loops_.push_back(loop);
discover_loop_and_sub_loops(bb, latches, loop);
} }
print(); print();
delete dominators_;
} }
/** /**
...@@ -28,8 +60,7 @@ void LoopDetection::run() { ...@@ -28,8 +60,7 @@ void LoopDetection::run() {
* @param latches 循环的回边终点(latch)集合 * @param latches 循环的回边终点(latch)集合
* @param loop 当前正在处理的循环对象 * @param loop 当前正在处理的循环对象
*/ */
void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches, void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, std::set<BasicBlock*>&latches, Loop* loop) {
std::shared_ptr<Loop> loop) {
// TODO List: // TODO List:
// 1. 初始化工作表,将所有latch块加入 // 1. 初始化工作表,将所有latch块加入
// 2. 实现主循环逻辑 // 2. 实现主循环逻辑
...@@ -37,14 +68,14 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches, ...@@ -37,14 +68,14 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches,
// 4. 处理已属于其他循环的节点 // 4. 处理已属于其他循环的节点
// 5. 建立正确的循环嵌套关系 // 5. 建立正确的循环嵌套关系
BBvec work_list = {latches.begin(), latches.end()}; // 初始化工作表 std::vector<BasicBlock*> work_list = {latches.begin(), latches.end()}; // 初始化工作表
while (!work_list.empty()) { // 当工作表非空时继续处理 while (!work_list.empty()) { // 当工作表非空时继续处理
auto bb = work_list.back(); auto bb2 = work_list.back();
work_list.pop_back(); work_list.pop_back();
// TODO-1: 处理未分配给任何循环的节点 // TODO-1: 处理未分配给任何循环的节点
if (bb_to_loop_.find(bb) == bb_to_loop_.end()) { if (bb_to_loop_.find(bb2) == bb_to_loop_.end()) {
/* 在此添加代码: /* 在此添加代码:
* 1. 使用loop->add_block将bb加入当前循环 * 1. 使用loop->add_block将bb加入当前循环
* 2. 更新bb_to_loop_映射 * 2. 更新bb_to_loop_映射
...@@ -54,7 +85,7 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches, ...@@ -54,7 +85,7 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches,
} }
// TODO-2: 处理已属于其他循环的节点 // TODO-2: 处理已属于其他循环的节点
else if (bb_to_loop_[bb] != loop) { if (bb_to_loop_[bb2] != loop) {
/* 在此添加代码: /* 在此添加代码:
* 1. 获取bb当前所属的循环sub_loop * 1. 获取bb当前所属的循环sub_loop
* 2. 找到sub_loop的最顶层父循环 * 2. 找到sub_loop的最顶层父循环
...@@ -62,7 +93,7 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches, ...@@ -62,7 +93,7 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches,
* 4. 建立循环嵌套关系: * 4. 建立循环嵌套关系:
* - 设置父循环 * - 设置父循环
* - 添加子循环 * - 添加子循环
* 5. 将子循环的所有基本加入到父循环中 * 5. 将子循环的所有基本加入到父循环中
* 6. 将子循环header的前驱加入工作表 * 6. 将子循环header的前驱加入工作表
*/ */
...@@ -72,44 +103,6 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches, ...@@ -72,44 +103,6 @@ void LoopDetection::discover_loop_and_sub_loops(BasicBlock *bb, BBset &latches,
} }
} }
/**
* @brief 对单个函数执行循环检测
* @param f 要分析的函数
*
* 该函数通过以下步骤检测循环:
* 1. 运行支配树分析
* 2. 按支配树后序遍历所有基本块
* 3. 对每个块,检查其前驱是否存在回边
* 4. 如果存在回边,创建新的循环并:
* - 设置循环header
* - 添加latch节点
* - 发现循环体和子循环
*/
void LoopDetection::run_on_func(Function *f) {
dominators_->run_on_func(f);
for (auto bb : dominators_->get_dom_post_order()) {
BBset latches;
for (auto pred : bb->get_pre_basic_blocks()) {
if (dominators_->is_dominate(bb, pred)) {
// pred is a back edge
// pred -> bb , pred is the latch node
latches.insert(pred);
}
}
if (latches.empty()) {
continue;
}
// create loop
auto loop = std::make_shared<Loop>(bb);
bb_to_loop_[bb] = loop;
// add latch nodes
for (auto latch : latches) {
loop->add_latch(latch);
}
loops_.push_back(loop);
discover_loop_and_sub_loops(bb, latches, loop);
}
}
/** /**
* @brief 打印循环检测的结果 * @brief 打印循环检测的结果
...@@ -119,21 +112,21 @@ void LoopDetection::run_on_func(Function *f) { ...@@ -119,21 +112,21 @@ void LoopDetection::run_on_func(Function *f) {
* 2. 循环包含的所有基本块 * 2. 循环包含的所有基本块
* 3. 循环的所有子循环 * 3. 循环的所有子循环
*/ */
void LoopDetection::print() { void LoopDetection::print() const {
m_->set_print_name(); f_->get_parent()->set_print_name();
std::cerr << "Loop Detection Result:" << std::endl; std::cerr << "Loop Detection Result:\n";
for (auto &loop : loops_) { for (auto &loop : loops_) {
std::cerr << "Loop header: " << loop->get_header()->get_name() std::cerr << "Loop header: " << loop->get_header()->get_name()
<< std::endl; << '\n';
std::cerr << "Loop blocks: "; std::cerr << "Loop blocks: ";
for (auto bb : loop->get_blocks()) { for (auto bb : loop->get_blocks()) {
std::cerr << bb->get_name() << " "; std::cerr << bb->get_name() << " ";
} }
std::cerr << std::endl; std::cerr << '\n';
std::cerr << "Sub loops: "; std::cerr << "Sub loops: ";
for (auto &sub_loop : loop->get_sub_loops()) { for (auto &sub_loop : loop->get_sub_loops()) {
std::cerr << sub_loop->get_header()->get_name() << " "; std::cerr << sub_loop->get_header()->get_name() << " ";
} }
std::cerr << std::endl; std::cerr << '\n';
} }
} }
\ No newline at end of file
#include "Mem2Reg.hpp" #include "Mem2Reg.hpp"
#include "IRBuilder.hpp" #include "IRBuilder.hpp"
#include "Value.hpp" #include "Value.hpp"
#include <memory>
// l_val 是否是非数组 alloca 变量
static bool is_not_array_alloca(Value* l_val)
{
auto alloca = dynamic_cast<AllocaInst*>(l_val);
return alloca != nullptr && !alloca->get_alloca_type()->is_array_type();
}
/** /**
* @brief Mem2Reg Pass的主入口函数 * @brief Mem2Reg Pass的主入口函数
...@@ -18,23 +25,30 @@ ...@@ -18,23 +25,30 @@
* 注意:函数执行后,冗余的局部变量分配指令将由后续的死代码删除Pass处理 * 注意:函数执行后,冗余的局部变量分配指令将由后续的死代码删除Pass处理
*/ */
void Mem2Reg::run() { void Mem2Reg::run() {
// 创建支配树分析 Pass 的实例
dominators_ = std::make_unique<Dominators>(m_);
// 建立支配树
dominators_->run();
// 以函数为单元遍历实现 Mem2Reg 算法 // 以函数为单元遍历实现 Mem2Reg 算法
for (auto f : m_->get_functions()) { for (auto f : m_->get_functions()) {
if (f->is_declaration()) if (f->is_declaration())
continue; continue;
func_ = f; func_ = f;
// 创建 func_ 支配树
dominators_ = new Dominators(func_);
// 建立支配树
dominators_->run();
allocas_.clear();
var_val_stack.clear(); var_val_stack.clear();
phi_lval.clear(); phi_to_alloca_.clear();
if (func_->get_basic_blocks().size() >= 1) { bb_to_phi_.clear();
if (!func_->get_basic_blocks().empty()) {
// 对应伪代码中 phi 指令插入的阶段 // 对应伪代码中 phi 指令插入的阶段
generate_phi(); generate_phi();
// 确保每个局部变量的栈都有初始值
for (auto var : allocas_)
var_val_stack[var].emplace_back(ConstantZero::get(var->get_alloca_type(), m_));
// 对应伪代码中重命名阶段 // 对应伪代码中重命名阶段
rename(func_->get_entry_block()); rename(func_->get_entry_block());
} }
delete dominators_;
dominators_ = nullptr;
// 后续 DeadCode 将移除冗余的局部变量的分配空间 // 后续 DeadCode 将移除冗余的局部变量的分配空间
} }
} }
...@@ -55,34 +69,41 @@ void Mem2Reg::run() { ...@@ -55,34 +69,41 @@ void Mem2Reg::run() {
* phi指令的插入遵循最小化原则,只在必要的位置插入phi节点 * phi指令的插入遵循最小化原则,只在必要的位置插入phi节点
*/ */
void Mem2Reg::generate_phi() { void Mem2Reg::generate_phi() {
// global_live_var_name 是全局名字集合,以 alloca 出的局部变量来统计。 // 步骤一:找到活跃在多个 block 的名字集合,以及它们所属的 bb 块
// 步骤一:找到活跃在多个 block 的全局名字集合,以及它们所属的 bb 块
std::set<Value *> global_live_var_name; // global_live_var_name 包括函数中所有非数组 alloca 变量
std::map<Value *, std::set<BasicBlock *>> live_var_2blocks; std::set<AllocaInst *> not_array_allocas;
// 每个 alloca 在什么基本块被 store (可能重复)
std::map<AllocaInst*, std::list<BasicBlock *>> allocas_stored_bbs;
for (auto bb : func_->get_basic_blocks()) { for (auto bb : func_->get_basic_blocks()) {
std::set<Value *> var_is_killed;
for (auto instr : bb->get_instructions()) { for (auto instr : bb->get_instructions()) {
if (instr->is_store()) { if (instr->is_store()) {
// store i32 a, i32 *b // store i32 a, i32 *b
// a is r_val, b is l_val // a is r_val, b is l_val
auto l_val = static_cast<StoreInst *>(instr)->get_lval(); auto l_val = dynamic_cast<StoreInst *>(instr)->get_lval();
if (is_valid_ptr(l_val)) { if (is_not_array_alloca(l_val)) {
global_live_var_name.insert(l_val); auto lalloca = dynamic_cast<AllocaInst*>(instr);
live_var_2blocks[l_val].insert(bb); not_array_allocas.insert(lalloca);
allocas_.emplace_back(lalloca);
allocas_stored_bbs[lalloca].emplace_back(bb);
} }
} }
} }
} }
// 步骤二:从支配树获取支配边界信息,并在对应位置插入 phi 指令 // 步骤二:从支配树获取支配边界信息,并在对应位置插入 phi 指令
std::map<std::pair<BasicBlock *, Value *>, bool>
bb_has_var_phi; // bb has phi for var // 基本块是否已经有了对特定 alloca 变量的 phi
for (auto var : global_live_var_name) { std::set<std::pair<BasicBlock *, AllocaInst *>> bb_has_var_phi;
for (auto var : not_array_allocas) {
std::vector<BasicBlock *> work_list; std::vector<BasicBlock *> work_list;
work_list.assign(live_var_2blocks[var].begin(), std::set<BasicBlock*> already_handled;
live_var_2blocks[var].end()); work_list.assign(allocas_stored_bbs[var].begin(), allocas_stored_bbs[var].end());
for (unsigned i = 0; i < work_list.size(); i++) { for (unsigned i = 0; i < work_list.size(); i++) {
auto bb = work_list[i]; auto bb = work_list[i];
// 防止在同一基本块重复运行
if (already_handled.count(bb)) continue;
already_handled.emplace(bb);
for (auto bb_dominance_frontier_bb : for (auto bb_dominance_frontier_bb :
dominators_->get_dominance_frontier(bb)) { dominators_->get_dominance_frontier(bb)) {
if (bb_has_var_phi.find({bb_dominance_frontier_bb, var}) == if (bb_has_var_phi.find({bb_dominance_frontier_bb, var}) ==
...@@ -92,10 +113,11 @@ void Mem2Reg::generate_phi() { ...@@ -92,10 +113,11 @@ void Mem2Reg::generate_phi() {
auto phi = PhiInst::create_phi( auto phi = PhiInst::create_phi(
var->get_type()->get_pointer_element_type(), var->get_type()->get_pointer_element_type(),
bb_dominance_frontier_bb); bb_dominance_frontier_bb);
phi_lval.emplace(phi, var); phi_to_alloca_.emplace(phi, var);
bb_to_phi_[bb_dominance_frontier_bb].emplace_back(phi);
bb_dominance_frontier_bb->add_instr_begin(phi); bb_dominance_frontier_bb->add_instr_begin(phi);
work_list.push_back(bb_dominance_frontier_bb); work_list.push_back(bb_dominance_frontier_bb);
bb_has_var_phi[{bb_dominance_frontier_bb, var}] = true; bb_has_var_phi.emplace(bb_dominance_frontier_bb, var);
} }
} }
} }
...@@ -103,17 +125,23 @@ void Mem2Reg::generate_phi() { ...@@ -103,17 +125,23 @@ void Mem2Reg::generate_phi() {
} }
void Mem2Reg::rename(BasicBlock *bb) { void Mem2Reg::rename(BasicBlock *bb) {
std::vector<Instruction *> wait_delete; // 可能用到的数据结构
// TODO // list<AllocaInst*> allocas_ 所有 Mem2Reg 需要消除的局部变量,用于遍历
// 步骤一:将 phi 指令作为 lval 的最新定值,lval 即是为局部变量 alloca 出的地址空间 // map<AllocaInst*,vector<Value *>> var_val_stack 每个局部变量的存储值栈,还未进行任何操作时已经存进去了 0,不会为空
// 步骤二:用 lval 最新的定值替代对应的load指令 // map<PhiInst *, AllocaInst*>; Phi 对应的局部变量
// 步骤三:将 store 指令的 rval,也即被存入内存的值,作为 lval 的最新定值 // map<BasicBlock*, list<PhiInst*>>; 在某个基本块的 Phi
// 步骤四:为 lval 对应的 phi 指令参数补充完整 // 可能用到的函数
// 步骤五:对 bb 在支配树上的所有后继节点,递归执行 re_name 操作 // Value::replace_all_use_with(Value* a) 将所有用到 this 的指令对应 this 的操作数都替换为 a
// 步骤六:pop出 lval 的最新定值 // BasicBlock::erase_instrs(set<Instruction*>) 移除并 delete 列表中的指令
// 步骤七:清除冗余的指令 // TODO
for (auto instr : wait_delete) { // 步骤一:对每个 alloca 非数组变量(局部变量), 在其存储值栈存入其当前的最新值(也就是目前的栈顶值)
bb->erase_instr(instr); // 步骤二:遍历基本块所有指令,执行操作并记录需要删除的 load/store/alloca 指令
} // - 步骤三: 将 store 指令存储的值,作为其对应局部变量的最新值(更新栈顶)
// - 步骤四: 将 load 指令的所有使用替换为其读取的局部变量的最新值
// 步骤五:为所有后继块的 phi 添加参数
// 步骤六:对 bb 在支配树上的所有后继节点,递归执行 rename 操作
// 步骤七:pop 出所有局部变量的最新值
// 步骤八:删除需要删除的冗余指令
} }
#include "PassManager.hpp"
TransformPass::~TransformPass() = default;
ModuleAnalysisPass::~ModuleAnalysisPass() = default;
FunctionAnalysisPass::~FunctionAnalysisPass() = default;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment