From 56317aae8d10a6624df1ca6528113217e7617d37 Mon Sep 17 00:00:00 2001 From: lxq <877250099@qq.com> Date: Thu, 2 Mar 2023 15:26:16 +0800 Subject: [PATCH] add LoopUnroll code --- include/optimization/ExceptCallMerge.hpp | 36 ++++++++ include/optimization/LoopUnroll.hpp | 57 ++++++++++++ src/cminusfc/cminusfc.cpp | 15 ++- src/codegen/codegen.cpp | 5 +- src/codegen/liverange.cpp | 7 ++ src/optimization/CMakeLists.txt | 2 + src/optimization/ExceptCallMerge.cpp | 69 ++++++++++++++ src/optimization/LoopUnroll.cpp | 111 +++++++++++++++++++++++ 8 files changed, 296 insertions(+), 6 deletions(-) create mode 100644 include/optimization/ExceptCallMerge.hpp create mode 100644 include/optimization/LoopUnroll.hpp create mode 100644 src/optimization/ExceptCallMerge.cpp create mode 100644 src/optimization/LoopUnroll.cpp diff --git a/include/optimization/ExceptCallMerge.hpp b/include/optimization/ExceptCallMerge.hpp new file mode 100644 index 0000000..1deef3d --- /dev/null +++ b/include/optimization/ExceptCallMerge.hpp @@ -0,0 +1,36 @@ +#ifndef EXCEPTCALLMERGE_HPP +#define EXCEPTCALLMERGE_HPP +#include "BasicBlock.h" +#include "Function.h" +#include "Module.h" +#include "PassManager.hpp" + +#include +#include +#include +#include + +using std::cout; +using std::endl; +using std::set; +using std::string; +using std::to_string; +using std::vector; + +// goal: each function share one block calling neg_idx_except +class NegCallMerge : public Pass { + public: + NegCallMerge(Module *_m); + NegCallMerge() = delete; + + void run() override { + for (auto &f : m_->get_functions()) { + run(&f); + } + } + + private: + Function *neg_func; + void run(Function *f); +}; +#endif diff --git a/include/optimization/LoopUnroll.hpp b/include/optimization/LoopUnroll.hpp new file mode 100644 index 0000000..2191099 --- /dev/null +++ b/include/optimization/LoopUnroll.hpp @@ -0,0 +1,57 @@ + +#ifndef LOOPUNROLL_HPP +#define LOOPUNROLL_HPP +#include "BasicBlock.h" +#include "Module.h" +#include "PassManager.hpp" + +#include +#include +#include + +using std::cout; +using std::endl; +using std::string; +using std::to_string; +using std::vector; + +namespace Graph { + +using Edge = std::pair; +using BackEdgeList = vector; +using SimpleLoop = vector; +} + + +/* This is a class to unroll simple loops: + * - strict structure: + * --a->b->c--- + * ^-----+ + * - if the loop has constant times, unroll it. + */ +class LoopUnroll : public Pass { + public: + LoopUnroll(Module *_m) : Pass(_m) { + for (auto &f : m_->get_functions()) + if (f.get_name() == "neg_idx_except") { + neg_func = &f; + break; + } + if (neg_func == nullptr) + assert(false && "find function neg_idx_except first!"); + } + + LoopUnroll() = delete; + + void run() override; + static string str(const Graph::Edge &edge) { + return "(" + edge.first->get_name() + ", " + edge.second->get_name() + + ")"; + } + + private: + Function *neg_func; + Graph::BackEdgeList detect_back(Function *); + vector check_sloops(const Graph::BackEdgeList &) const; +}; +#endif diff --git a/src/cminusfc/cminusfc.cpp b/src/cminusfc/cminusfc.cpp index bb3df2b..92c0fc6 100644 --- a/src/cminusfc/cminusfc.cpp +++ b/src/cminusfc/cminusfc.cpp @@ -5,6 +5,8 @@ #include "GVN.h" // #include "LoopInvHoist.hpp" // #include "LoopSearch.hpp" +#include "ExceptCallMerge.hpp" +#include "LoopUnroll.hpp" #include "Mem2Reg.hpp" #include "PassManager.hpp" #include "cminusf_builder.hpp" @@ -35,6 +37,7 @@ main(int argc, char **argv) { bool dump_json = false; bool emit = false; bool assembly = false; + bool loopunroll = false; for (int i = 1; i < argc; ++i) { if (argv[i] == "-h"s || argv[i] == "--help"s) { @@ -58,6 +61,8 @@ main(int argc, char **argv) { mem2reg = true; } else if (argv[i] == "-gvn"s) { gvn = true; + } else if (argv[i] == "-loopunroll"s) { + loopunroll = true; } else if (argv[i] == "-dump-json"s) { dump_json = true; } else { @@ -109,21 +114,25 @@ main(int argc, char **argv) { PassManager PM(m.get()); if (mem2reg) { - PM.add_pass(emit); + PM.add_pass(false); } if (gvn) { PM.add_pass(false); // remove some undef - PM.add_pass(emit, dump_json); + PM.add_pass(false, dump_json); PM.add_pass( false); // delete unused instructions created by GVN } + if (loopunroll) { + PM.add_pass(false); + PM.add_pass(false); + } + m->set_print_name(); PM.run(); auto IR = m->print(); if (assembly) { - CodeGen codegen(m.get()); codegen.run(); std::ofstream target_file(target_path + ".s"); diff --git a/src/codegen/codegen.cpp b/src/codegen/codegen.cpp index dae3ed3..a4b2c40 100644 --- a/src/codegen/codegen.cpp +++ b/src/codegen/codegen.cpp @@ -78,7 +78,6 @@ CodeGen::getPhiMap() { void CodeGen::run() { - // 以下内容生成 int main() { return 0; } 的汇编代码 getPhiMap(); output.push_back(".text"); // global variables @@ -583,11 +582,11 @@ void CodeGen::IR2assem(CallInst *instr) { auto func = static_cast(instr->get_operand(0)); auto func_argN = func_arg_N.at(func); - // analyze the registers that need to be stored + // int cur_i = LRA.get_instr_id().at(instr); auto regmap_int = RA_int.get(); auto regmap_float = RA_float.get(); - // + // analyze the registers that need to be stored int storeN = 0; vector> store_record; for (auto [op, interval] : LRA.get_interval_map()) { diff --git a/src/codegen/liverange.cpp b/src/codegen/liverange.cpp index 015af88..fc9ef91 100644 --- a/src/codegen/liverange.cpp +++ b/src/codegen/liverange.cpp @@ -27,8 +27,11 @@ LiveRangeAnalyzer::joinFor(BasicBlock *bb) { for (auto succ : bb->get_succ_basic_blocks()) { auto &irs = succ->get_instructions(); auto it = irs.begin(); + cout << succ->get_name() << endl; while (it != irs.end() and it->is_phi()) ++it; + /* if (it == irs.end()) + * cout << succ->print() << endl; */ assert(it != irs.end() && "need to find first_ir from copy-stmt"); union_ip(out, IN[instr_id.at(&(*it))]); // cout << "# " + it->print() << endl; @@ -82,6 +85,10 @@ LiveRangeAnalyzer::get_dfs_order(Function *func) { for (auto succ : bb->get_succ_basic_blocks()) Q.push_front(succ); } + cout << func->get_name() << "'s dfs order:\n\t"; + for (auto bb : BB_DFS_Order) + cout << bb->get_name() << " "; + cout << endl; } void diff --git a/src/optimization/CMakeLists.txt b/src/optimization/CMakeLists.txt index 8e969b3..73ea2d6 100644 --- a/src/optimization/CMakeLists.txt +++ b/src/optimization/CMakeLists.txt @@ -3,4 +3,6 @@ add_library( Dominators.cpp Mem2Reg.cpp GVN.cpp + LoopUnroll.cpp + ExceptCallMerge.cpp ) diff --git a/src/optimization/ExceptCallMerge.cpp b/src/optimization/ExceptCallMerge.cpp new file mode 100644 index 0000000..3427b69 --- /dev/null +++ b/src/optimization/ExceptCallMerge.cpp @@ -0,0 +1,69 @@ +#include "ExceptCallMerge.hpp" + +#include "Function.h" +#include "Instruction.h" + +#include + +using std::find; + +NegCallMerge::NegCallMerge(Module *_m) : Pass(_m), neg_func(nullptr) { + for (auto &f : m_->get_functions()) + if (f.get_name() == "neg_idx_except") { + neg_func = &f; + break; + } + if (neg_func == nullptr) + assert(false && "find function neg_idx_except first!"); +} + +void +NegCallMerge::run(Function *func) { + BasicBlock *reserved = nullptr; + set calls; + auto &blocks = func->get_basic_blocks(); + // check bb + for (auto &bb : blocks) { + if (bb.get_instructions().size() != 2) + continue; + auto instr = &*bb.get_instructions().begin(); + if (instr->is_call() and instr->get_operand(0) == neg_func) { + calls.insert(&bb); + if (reserved == nullptr) + reserved = &bb; + } + } + // BasicBlock redirect + for (auto &bb : blocks) { + for (auto &instr : bb.get_instructions()) { + if (not instr.is_br()) + continue; + auto br = static_cast(&instr); + vector idx; + if (br->is_cond_br()) { + idx = {1, 2}; + } else { + idx = {0}; + } + for (auto i : idx) + if (calls.find(br->get_operand(i)) != calls.end()) + br->get_operands()[i] = reserved; + } + } + // remove useless BasicBlocks + cout << "remove blocks for function " << func->get_name() << endl; + for (auto _bb : calls) { + auto bb = static_cast(_bb); + if (bb != reserved) { + cout << "remove block " << bb->get_name() << endl; + auto it = blocks.begin(); + for (; &*it != bb; ++it) + ; + blocks.erase(it); + assert(bb->get_pre_basic_blocks().size() == 1); + (*bb->get_pre_basic_blocks().begin()) + ->get_succ_basic_blocks() + .remove(bb); + } + } +} diff --git a/src/optimization/LoopUnroll.cpp b/src/optimization/LoopUnroll.cpp new file mode 100644 index 0000000..6fe3f0b --- /dev/null +++ b/src/optimization/LoopUnroll.cpp @@ -0,0 +1,111 @@ +#include "LoopUnroll.hpp" + +#include "BasicBlock.h" +#include "Function.h" +#include "Instruction.h" +#include "syntax_analyzer.h" + +#include +#include + +using std::find; +using std::map; + +using namespace Graph; + +struct BackEdgeSearcher { + BackEdgeSearcher(BasicBlock *entry) { dfsrun(entry); } + + void dfsrun(BasicBlock *bb) { + vis[bb] = true; + path.push_back(bb); + for (auto succ : bb->get_succ_basic_blocks()) { + if (vis[succ]) { + string type; + Edge edge(bb, succ); + if (find(path.rbegin(), path.rend(), succ) == path.rend()) { + type = "cross-edge"; + } else { + type = "back-edge"; + edges.push_back(edge); + } + cout << "find " << type << ": " << LoopUnroll::str(edge) + << "\n"; + } else + dfsrun(succ); + } + + path.pop_back(); + } + + vector path; + map vis; + BackEdgeList edges; +}; + +BackEdgeList +LoopUnroll::detect_back(Function *func) { + BackEdgeSearcher search(func->get_entry_block()); + return search.edges; +} + +vector +LoopUnroll::check_sloops(const BackEdgeList &belist) const { + vector sloops; + for (auto [e, b] : belist) { + SimpleLoop sl; + if (b->get_succ_basic_blocks().size() != 2 or + b->get_pre_basic_blocks().size() != 2) + continue; + bool flag = true; + // the start node should have 2*2 degree, and others have 1*1 degree + // one exception: br to neg_idx_except + for (auto p = e; p != b; p = *p->get_pre_basic_blocks().begin()) { + if (p->get_pre_basic_blocks().size() != 1) { + flag = false; + break; + } + auto succ_bbs = p->get_succ_basic_blocks(); + if (succ_bbs.size() != 1) { + assert(succ_bbs.size() == 2); + flag = false; + for (auto bb : succ_bbs) { + auto instr = &*bb->get_instructions().begin(); + if (instr->is_call() and + instr->get_operand(0) == neg_func) { + flag = true; + break; + } + } + } + if (not flag) + break; + sl.insert(sl.begin(), p); + } + + if (not flag) + continue; + sl.insert(sl.begin(), b); + if (flag) + sloops.push_back(sl); + } + return sloops; +} +void +LoopUnroll::run() { + for (auto &_f : m_->get_functions()) { + if (_f.is_declaration()) + continue; + auto func = &_f; + // cout << func->get_name() << endl; + auto belist = detect_back(func); + auto sloops = check_sloops(belist); + cout << "get simple loops for function " << func->get_name() << ":\n"; + for (auto sl : sloops) { + cout << "\t"; + for (auto p : sl) + cout << p->get_name() << " "; + cout << "\n"; + } + } +} -- GitLab