Commit 158a82bc authored by lxq's avatar lxq

loop unroll finished

parent efe0d887
#ifndef BRMERGE_HPP
#define BRMERGE_HPP
#include "Module.h"
#include "PassManager.hpp"
class BrMerge : public Pass {
public:
BrMerge(Module *m) : Pass(m) {}
void run() override;
private:
};
#endif
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
#ifndef LOOPUNROLL_HPP #ifndef LOOPUNROLL_HPP
#define LOOPUNROLL_HPP #define LOOPUNROLL_HPP
#include "BasicBlock.h" #include "BasicBlock.h"
#include "Instruction.h"
#include "Module.h" #include "Module.h"
#include "PassManager.hpp" #include "PassManager.hpp"
#include "Type.h" #include "Type.h"
#include "Value.h"
#include <map> #include <map>
#include <ostream> #include <ostream>
...@@ -35,17 +37,42 @@ struct BackEdgeSearcher { ...@@ -35,17 +37,42 @@ struct BackEdgeSearcher {
}; };
} // namespace Graph } // namespace Graph
namespace Analysis{ namespace Analysis {
struct LoopAnalysis { #define Threshold 100
LoopAnalysis(const Graph::SimpleLoop &); struct CountedLoop {
LoopAnalysis() = delete; typedef union {
enum { INT, FLOAT, UNDEF} Type;
union {
int v; int v;
float fv; float fv;
} initial, delta, threshold; } I_F;
};} CountedLoop(const Graph::SimpleLoop &);
CountedLoop() = delete;
void new_emulate() {
switch (Type) {
case INT:
emulate.v = initial.v;
break;
case FLOAT:
emulate.fv = initial.fv;
break;
case UNDEF:
assert(false);
break;
}
};
// whether cur emulate shuold continue loop
bool judge();
// emulate the delta part
void next();
enum { INT, FLOAT, UNDEF } Type;
I_F initial, stop, emulate;
int count; // determined in construct function
bool reverse; // for delta emulate part
BinaryInst *delta;
Instruction *control;
};
} // namespace Analysis
/* This is a class to unroll simple loops: /* This is a class to unroll simple loops:
* - strict structure: * - strict structure:
...@@ -56,6 +83,7 @@ struct LoopAnalysis { ...@@ -56,6 +83,7 @@ struct LoopAnalysis {
class LoopUnroll : public Pass { class LoopUnroll : public Pass {
public: public:
LoopUnroll(Module *_m) : Pass(_m) { LoopUnroll(Module *_m) : Pass(_m) {
m_->set_print_name();
for (auto &f : m_->get_functions()) for (auto &f : m_->get_functions())
if (f.get_name() == "neg_idx_except") { if (f.get_name() == "neg_idx_except") {
neg_func = &f; neg_func = &f;
...@@ -75,7 +103,28 @@ class LoopUnroll : public Pass { ...@@ -75,7 +103,28 @@ class LoopUnroll : public Pass {
private: private:
Function *neg_func; Function *neg_func;
map<Value *, Value *> old2new;
Graph::BackEdgeList detect_back(Function *); Graph::BackEdgeList detect_back(Function *);
vector<Graph::SimpleLoop> check_sloops(const Graph::BackEdgeList &) const; vector<Graph::SimpleLoop> check_sloops(const Graph::BackEdgeList &) const;
void unroll_loop(Graph::SimpleLoop &);
Value *right_v(Value *v) const {
if (old2new.find(v) != old2new.end()) {
return old2new.at(v);
} else
return v;
}
BasicBlock *copy_instruction(Instruction &instr,
BasicBlock *bb, // old block
BasicBlock *BB, // new block
BasicBlock *pre, // previous of whole loop
BasicBlock *succ, // successor of whole loop
Graph::SimpleLoop &sl, // the whole loop
Analysis::CountedLoop &cl, // analysis info
bool init);
bool is_neg_block(BasicBlock *bb) const {
auto instr = &*bb->get_instructions().begin();
return (instr->is_call() and instr->get_operand(0) == neg_func);
}
}; };
#endif #endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "GVN.h" #include "GVN.h"
// #include "LoopInvHoist.hpp" // #include "LoopInvHoist.hpp"
// #include "LoopSearch.hpp" // #include "LoopSearch.hpp"
#include "BrMerge.hpp"
#include "ExceptCallMerge.hpp" #include "ExceptCallMerge.hpp"
#include "LoopUnroll.hpp" #include "LoopUnroll.hpp"
#include "Mem2Reg.hpp" #include "Mem2Reg.hpp"
...@@ -22,7 +23,8 @@ using namespace std::literals::string_literals; ...@@ -22,7 +23,8 @@ using namespace std::literals::string_literals;
void void
print_help(std::string exe_name) { print_help(std::string exe_name) {
std::cout << "Usage: " << exe_name std::cout << "Usage: " << exe_name
<< " [ -h | --help ] [ -o <target-file> ] [ -emit-llvm ] " << " [-h | --help] [-o <target-file>] [-S <assembly-file>] "
"[-emit-llvm] [-loopunroll] "
"[-mem2reg] [-gvn] [-dump-json] <input-file>" "[-mem2reg] [-gvn] [-dump-json] <input-file>"
<< std::endl; << std::endl;
} }
...@@ -62,7 +64,7 @@ main(int argc, char **argv) { ...@@ -62,7 +64,7 @@ main(int argc, char **argv) {
} else if (argv[i] == "-gvn"s) { } else if (argv[i] == "-gvn"s) {
gvn = true; gvn = true;
} else if (argv[i] == "-loopunroll"s) { } else if (argv[i] == "-loopunroll"s) {
loopunroll = true; gvn = loopunroll = true;
} else if (argv[i] == "-dump-json"s) { } else if (argv[i] == "-dump-json"s) {
dump_json = true; dump_json = true;
} else { } else {
...@@ -125,11 +127,15 @@ main(int argc, char **argv) { ...@@ -125,11 +127,15 @@ main(int argc, char **argv) {
if (loopunroll) { if (loopunroll) {
PM.add_pass<NegCallMerge>(false); PM.add_pass<NegCallMerge>(false);
PM.add_pass<LoopUnroll>(false); PM.add_pass<LoopUnroll>(false);
PM.add_pass<DeadCode>(false);
PM.add_pass<GVN>(false, dump_json);
PM.add_pass<DeadCode>(false);
PM.add_pass<BrMerge>(false);
} }
m->set_print_name();
PM.run(); PM.run();
m->set_print_name();
auto IR = m->print(); auto IR = m->print();
if (assembly) { if (assembly) {
......
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include "Value.h" #include "Value.h"
#include "ast.hpp" #include "ast.hpp"
#include "regalloc.hpp" #include "regalloc.hpp"
#include "syntax_analyzer.h"
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
......
...@@ -82,6 +82,10 @@ LiveRangeAnalyzer::get_dfs_order(Function *func) { ...@@ -82,6 +82,10 @@ LiveRangeAnalyzer::get_dfs_order(Function *func) {
for (auto succ : bb->get_succ_basic_blocks()) for (auto succ : bb->get_succ_basic_blocks())
Q.push_front(succ); Q.push_front(succ);
} }
cout << "DFS order for function " << func->get_name() << ":\n";
for (auto bb : BB_DFS_Order)
cout << bb->get_name() << " ";
cout << endl;
} }
void void
......
...@@ -11,7 +11,7 @@ using std::for_each; ...@@ -11,7 +11,7 @@ using std::for_each;
using namespace RA; using namespace RA;
#define ASSERT_CMPINST_USED_ONCE(cmpinst) \ #define ASSERT_CMPINST_USED_ONCE(cmpinst) \
(assert(cmpinst->get_use_list().size() == 1)) (assert(cmpinst->get_use_list().size() <= 1))
int int
get_arg_id(Argument *arg) { get_arg_id(Argument *arg) {
...@@ -47,6 +47,8 @@ RegAllocator::no_reg_alloca(Value *v) { ...@@ -47,6 +47,8 @@ RegAllocator::no_reg_alloca(Value *v) {
// %op2 = icmp ne i32 %op1, 0 # <- if judges to here // %op2 = icmp ne i32 %op1, 0 # <- if judges to here
// br i1 %op2, label %label3, label %label5 // br i1 %op2, label %label3, label %label5
ASSERT_CMPINST_USED_ONCE(use_ins); ASSERT_CMPINST_USED_ONCE(use_ins);
if (use_ins->get_use_list().size() == 0)
return false;
auto use2_ins = dynamic_cast<Instruction *>( auto use2_ins = dynamic_cast<Instruction *>(
use_ins->get_use_list().begin()->val_); use_ins->get_use_list().begin()->val_);
alloc = not(use2_ins->is_br()); alloc = not(use2_ins->is_br());
...@@ -128,12 +130,10 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -128,12 +130,10 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
void void
RegAllocator::ExpireOldIntervals(LiveInterval liveint) { RegAllocator::ExpireOldIntervals(LiveInterval liveint) {
auto it = active.begin(); auto it = active.begin();
for (; it != active.end() and it->first.j < liveint.first.i; ++it) for (; it != active.end() and it->first.j < liveint.first.i; ++it)
used[regmap.at(it->second)] = false; used[regmap.at(it->second)] = false;
active.erase(active.begin(), it); active.erase(active.begin(), it);
} }
void void
......
#include "BrMerge.hpp"
#include "BasicBlock.h"
#include "Constant.h"
#include "Instruction.h"
void
BrMerge::run() {
BranchInst *br;
ReturnInst *ret;
for (auto &func : m_->get_functions()) {
bool cont = true;
while (cont) {
cont = false;
for (auto &bb : func.get_basic_blocks()) {
if (&bb == func.get_entry_block())
continue;
auto &instructions = bb.get_instructions();
br = dynamic_cast<BranchInst *>(&*instructions.rbegin());
ret = dynamic_cast<ReturnInst *>(&*instructions.rbegin());
assert((br or ret) && "final Instruction");
if (instructions.size() == 1) {
if (br) { // change: br or cond-br with constant jump
BasicBlock *succ;
if (not br->is_cond_br()) {
assert(bb.get_succ_basic_blocks().size() == 1);
succ = *bb.get_succ_basic_blocks().begin();
} else if (dynamic_cast<Constant *>(
br->get_operand(0))) {
assert(bb.get_succ_basic_blocks().size() == 2);
succ = static_cast<BasicBlock *>(
dynamic_cast<ConstantInt *>(br->get_operand(0))
->get_value()
? br->get_operand(1)
: br->get_operand(2));
} else
continue;
for (auto pre : bb.get_pre_basic_blocks()) {
// change br's op
auto ins = &*pre->get_instructions().rbegin();
bool set = false;
for (int i = 0; i < ins->get_num_operand(); ++i)
if (ins->get_operand(i) == &bb) {
ins->set_operand(i, succ);
set = true;
break;
}
assert(set);
// change pre's succ
pre->remove_succ_basic_block(&bb);
pre->add_succ_basic_block(&bb);
// change succ's pre
succ->add_pre_basic_block(pre);
}
// change succ's pre
succ->remove_pre_basic_block(&bb);
// remove useless block
func.get_basic_blocks().remove(&bb);
cont = true;
} else { // ret: do not change
}
} else {
}
}
}
}
}
...@@ -5,4 +5,5 @@ add_library( ...@@ -5,4 +5,5 @@ add_library(
Dominators.cpp Dominators.cpp
Mem2Reg.cpp Mem2Reg.cpp
GVN.cpp GVN.cpp
BrMerge.cpp
) )
...@@ -46,15 +46,28 @@ NegCallMerge::run(Function *func) { ...@@ -46,15 +46,28 @@ NegCallMerge::run(Function *func) {
idx = {0}; idx = {0};
} }
for (auto i : idx) for (auto i : idx)
if (calls.find(br->get_operand(i)) != calls.end()) if (calls.find(br->get_operand(i)) != calls.end()) {
auto remove = static_cast<BasicBlock *>(br->get_operand(i));
if (remove == reserved)
continue;
br->get_operands()[i] = reserved; br->get_operands()[i] = reserved;
// correct pre/succ blocks
bb.add_succ_basic_block(reserved);
reserved->add_pre_basic_block(&bb);
// remove wrong graph links
bb.remove_succ_basic_block(remove);
remove->remove_pre_basic_block(&bb);
blocks.remove(remove);
}
} }
} }
return;
// remove useless BasicBlocks // remove useless BasicBlocks
for (auto _bb : calls) { for (auto _bb : calls) {
auto bb = static_cast<BasicBlock *>(_bb); auto bb = static_cast<BasicBlock *>(_bb);
if (bb != reserved) { if (bb != reserved) {
cout << "remove block " << bb->get_name() << " in function " << func->get_name() << endl; /* cout << "remove block " << bb->get_name() << " in function "
* << func->get_name() << endl; */
auto it = blocks.begin(); auto it = blocks.begin();
for (; &*it != bb; ++it) for (; &*it != bb; ++it)
; ;
......
...@@ -411,10 +411,10 @@ GVN::detectEquivalences() { ...@@ -411,10 +411,10 @@ GVN::detectEquivalences() {
break; break;
} }
default: { default: {
std::cerr << "In function " << func_->get_name() << ", " /* std::cerr << "In function " << func_->get_name() << ", "
<< " block " << bb->get_name() * << " block " << bb->get_name()
<< " has count of predecessors: " * << " has count of predecessors: "
<< pre_bbs_.size(); * << pre_bbs_.size(); */
continue; continue;
} }
} }
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment