Commit 56317aae authored by lxq's avatar lxq

add LoopUnroll code

parent 3e52250f
#ifndef EXCEPTCALLMERGE_HPP
#define EXCEPTCALLMERGE_HPP
#include "BasicBlock.h"
#include "Function.h"
#include "Module.h"
#include "PassManager.hpp"
#include <ostream>
#include <set>
#include <string>
#include <vector>
using std::cout;
using std::endl;
using std::set;
using std::string;
using std::to_string;
using std::vector;
// goal: each function share one block calling neg_idx_except
class NegCallMerge : public Pass {
public:
NegCallMerge(Module *_m);
NegCallMerge() = delete;
void run() override {
for (auto &f : m_->get_functions()) {
run(&f);
}
}
private:
Function *neg_func;
void run(Function *f);
};
#endif
#ifndef LOOPUNROLL_HPP
#define LOOPUNROLL_HPP
#include "BasicBlock.h"
#include "Module.h"
#include "PassManager.hpp"
#include <ostream>
#include <string>
#include <vector>
using std::cout;
using std::endl;
using std::string;
using std::to_string;
using std::vector;
namespace Graph {
using Edge = std::pair<BasicBlock *, BasicBlock *>;
using BackEdgeList = vector<Edge>;
using SimpleLoop = vector<BasicBlock *>;
}
/* This is a class to unroll simple loops:
* - strict structure:
* --a->b->c---
* ^-----+
* - if the loop has constant times, unroll it.
*/
class LoopUnroll : public Pass {
public:
LoopUnroll(Module *_m) : Pass(_m) {
for (auto &f : m_->get_functions())
if (f.get_name() == "neg_idx_except") {
neg_func = &f;
break;
}
if (neg_func == nullptr)
assert(false && "find function neg_idx_except first!");
}
LoopUnroll() = delete;
void run() override;
static string str(const Graph::Edge &edge) {
return "(" + edge.first->get_name() + ", " + edge.second->get_name() +
")";
}
private:
Function *neg_func;
Graph::BackEdgeList detect_back(Function *);
vector<Graph::SimpleLoop> check_sloops(const Graph::BackEdgeList &) const;
};
#endif
......@@ -5,6 +5,8 @@
#include "GVN.h"
// #include "LoopInvHoist.hpp"
// #include "LoopSearch.hpp"
#include "ExceptCallMerge.hpp"
#include "LoopUnroll.hpp"
#include "Mem2Reg.hpp"
#include "PassManager.hpp"
#include "cminusf_builder.hpp"
......@@ -35,6 +37,7 @@ main(int argc, char **argv) {
bool dump_json = false;
bool emit = false;
bool assembly = false;
bool loopunroll = false;
for (int i = 1; i < argc; ++i) {
if (argv[i] == "-h"s || argv[i] == "--help"s) {
......@@ -58,6 +61,8 @@ main(int argc, char **argv) {
mem2reg = true;
} else if (argv[i] == "-gvn"s) {
gvn = true;
} else if (argv[i] == "-loopunroll"s) {
loopunroll = true;
} else if (argv[i] == "-dump-json"s) {
dump_json = true;
} else {
......@@ -109,21 +114,25 @@ main(int argc, char **argv) {
PassManager PM(m.get());
if (mem2reg) {
PM.add_pass<Mem2Reg>(emit);
PM.add_pass<Mem2Reg>(false);
}
if (gvn) {
PM.add_pass<DeadCode>(false); // remove some undef
PM.add_pass<GVN>(emit, dump_json);
PM.add_pass<GVN>(false, dump_json);
PM.add_pass<DeadCode>(
false); // delete unused instructions created by GVN
}
if (loopunroll) {
PM.add_pass<NegCallMerge>(false);
PM.add_pass<LoopUnroll>(false);
}
m->set_print_name();
PM.run();
auto IR = m->print();
if (assembly) {
CodeGen codegen(m.get());
codegen.run();
std::ofstream target_file(target_path + ".s");
......
......@@ -78,7 +78,6 @@ CodeGen::getPhiMap() {
void
CodeGen::run() {
// 以下内容生成 int main() { return 0; } 的汇编代码
getPhiMap();
output.push_back(".text");
// global variables
......@@ -583,11 +582,11 @@ void
CodeGen::IR2assem(CallInst *instr) {
auto func = static_cast<Function *>(instr->get_operand(0));
auto func_argN = func_arg_N.at(func);
// analyze the registers that need to be stored
//
int cur_i = LRA.get_instr_id().at(instr);
auto regmap_int = RA_int.get();
auto regmap_float = RA_float.get();
//
// analyze the registers that need to be stored
int storeN = 0;
vector<std::tuple<Value *, string, int>> store_record;
for (auto [op, interval] : LRA.get_interval_map()) {
......
......@@ -27,8 +27,11 @@ LiveRangeAnalyzer::joinFor(BasicBlock *bb) {
for (auto succ : bb->get_succ_basic_blocks()) {
auto &irs = succ->get_instructions();
auto it = irs.begin();
cout << succ->get_name() << endl;
while (it != irs.end() and it->is_phi())
++it;
/* if (it == irs.end())
* cout << succ->print() << endl; */
assert(it != irs.end() && "need to find first_ir from copy-stmt");
union_ip(out, IN[instr_id.at(&(*it))]);
// cout << "# " + it->print() << endl;
......@@ -82,6 +85,10 @@ LiveRangeAnalyzer::get_dfs_order(Function *func) {
for (auto succ : bb->get_succ_basic_blocks())
Q.push_front(succ);
}
cout << func->get_name() << "'s dfs order:\n\t";
for (auto bb : BB_DFS_Order)
cout << bb->get_name() << " ";
cout << endl;
}
void
......
......@@ -3,4 +3,6 @@ add_library(
Dominators.cpp
Mem2Reg.cpp
GVN.cpp
LoopUnroll.cpp
ExceptCallMerge.cpp
)
#include "ExceptCallMerge.hpp"
#include "Function.h"
#include "Instruction.h"
#include <cstddef>
using std::find;
NegCallMerge::NegCallMerge(Module *_m) : Pass(_m), neg_func(nullptr) {
for (auto &f : m_->get_functions())
if (f.get_name() == "neg_idx_except") {
neg_func = &f;
break;
}
if (neg_func == nullptr)
assert(false && "find function neg_idx_except first!");
}
void
NegCallMerge::run(Function *func) {
BasicBlock *reserved = nullptr;
set<Value *> calls;
auto &blocks = func->get_basic_blocks();
// check bb
for (auto &bb : blocks) {
if (bb.get_instructions().size() != 2)
continue;
auto instr = &*bb.get_instructions().begin();
if (instr->is_call() and instr->get_operand(0) == neg_func) {
calls.insert(&bb);
if (reserved == nullptr)
reserved = &bb;
}
}
// BasicBlock redirect
for (auto &bb : blocks) {
for (auto &instr : bb.get_instructions()) {
if (not instr.is_br())
continue;
auto br = static_cast<BranchInst *>(&instr);
vector<int> idx;
if (br->is_cond_br()) {
idx = {1, 2};
} else {
idx = {0};
}
for (auto i : idx)
if (calls.find(br->get_operand(i)) != calls.end())
br->get_operands()[i] = reserved;
}
}
// remove useless BasicBlocks
cout << "remove blocks for function " << func->get_name() << endl;
for (auto _bb : calls) {
auto bb = static_cast<BasicBlock *>(_bb);
if (bb != reserved) {
cout << "remove block " << bb->get_name() << endl;
auto it = blocks.begin();
for (; &*it != bb; ++it)
;
blocks.erase(it);
assert(bb->get_pre_basic_blocks().size() == 1);
(*bb->get_pre_basic_blocks().begin())
->get_succ_basic_blocks()
.remove(bb);
}
}
}
#include "LoopUnroll.hpp"
#include "BasicBlock.h"
#include "Function.h"
#include "Instruction.h"
#include "syntax_analyzer.h"
#include <map>
#include <vector>
using std::find;
using std::map;
using namespace Graph;
struct BackEdgeSearcher {
BackEdgeSearcher(BasicBlock *entry) { dfsrun(entry); }
void dfsrun(BasicBlock *bb) {
vis[bb] = true;
path.push_back(bb);
for (auto succ : bb->get_succ_basic_blocks()) {
if (vis[succ]) {
string type;
Edge edge(bb, succ);
if (find(path.rbegin(), path.rend(), succ) == path.rend()) {
type = "cross-edge";
} else {
type = "back-edge";
edges.push_back(edge);
}
cout << "find " << type << ": " << LoopUnroll::str(edge)
<< "\n";
} else
dfsrun(succ);
}
path.pop_back();
}
vector<BasicBlock *> path;
map<BasicBlock *, bool> vis;
BackEdgeList edges;
};
BackEdgeList
LoopUnroll::detect_back(Function *func) {
BackEdgeSearcher search(func->get_entry_block());
return search.edges;
}
vector<SimpleLoop>
LoopUnroll::check_sloops(const BackEdgeList &belist) const {
vector<SimpleLoop> sloops;
for (auto [e, b] : belist) {
SimpleLoop sl;
if (b->get_succ_basic_blocks().size() != 2 or
b->get_pre_basic_blocks().size() != 2)
continue;
bool flag = true;
// the start node should have 2*2 degree, and others have 1*1 degree
// one exception: br to neg_idx_except
for (auto p = e; p != b; p = *p->get_pre_basic_blocks().begin()) {
if (p->get_pre_basic_blocks().size() != 1) {
flag = false;
break;
}
auto succ_bbs = p->get_succ_basic_blocks();
if (succ_bbs.size() != 1) {
assert(succ_bbs.size() == 2);
flag = false;
for (auto bb : succ_bbs) {
auto instr = &*bb->get_instructions().begin();
if (instr->is_call() and
instr->get_operand(0) == neg_func) {
flag = true;
break;
}
}
}
if (not flag)
break;
sl.insert(sl.begin(), p);
}
if (not flag)
continue;
sl.insert(sl.begin(), b);
if (flag)
sloops.push_back(sl);
}
return sloops;
}
void
LoopUnroll::run() {
for (auto &_f : m_->get_functions()) {
if (_f.is_declaration())
continue;
auto func = &_f;
// cout << func->get_name() << endl;
auto belist = detect_back(func);
auto sloops = check_sloops(belist);
cout << "get simple loops for function " << func->get_name() << ":\n";
for (auto sl : sloops) {
cout << "\t";
for (auto p : sl)
cout << p->get_name() << " ";
cout << "\n";
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment