Commit 158a82bc authored by lxq's avatar lxq

loop unroll finished

parent efe0d887
#ifndef BRMERGE_HPP
#define BRMERGE_HPP
#include "Module.h"
#include "PassManager.hpp"
class BrMerge : public Pass {
public:
BrMerge(Module *m) : Pass(m) {}
void run() override;
private:
};
#endif
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
#ifndef LOOPUNROLL_HPP #ifndef LOOPUNROLL_HPP
#define LOOPUNROLL_HPP #define LOOPUNROLL_HPP
#include "BasicBlock.h" #include "BasicBlock.h"
#include "Instruction.h"
#include "Module.h" #include "Module.h"
#include "PassManager.hpp" #include "PassManager.hpp"
#include "Type.h" #include "Type.h"
#include "Value.h"
#include <map> #include <map>
#include <ostream> #include <ostream>
...@@ -35,17 +37,42 @@ struct BackEdgeSearcher { ...@@ -35,17 +37,42 @@ struct BackEdgeSearcher {
}; };
} // namespace Graph } // namespace Graph
namespace Analysis{ namespace Analysis {
struct LoopAnalysis { #define Threshold 100
LoopAnalysis(const Graph::SimpleLoop &); struct CountedLoop {
LoopAnalysis() = delete; typedef union {
enum { INT, FLOAT, UNDEF} Type;
union {
int v; int v;
float fv; float fv;
} initial, delta, threshold; } I_F;
};} CountedLoop(const Graph::SimpleLoop &);
CountedLoop() = delete;
void new_emulate() {
switch (Type) {
case INT:
emulate.v = initial.v;
break;
case FLOAT:
emulate.fv = initial.fv;
break;
case UNDEF:
assert(false);
break;
}
};
// whether cur emulate shuold continue loop
bool judge();
// emulate the delta part
void next();
enum { INT, FLOAT, UNDEF } Type;
I_F initial, stop, emulate;
int count; // determined in construct function
bool reverse; // for delta emulate part
BinaryInst *delta;
Instruction *control;
};
} // namespace Analysis
/* This is a class to unroll simple loops: /* This is a class to unroll simple loops:
* - strict structure: * - strict structure:
...@@ -56,6 +83,7 @@ struct LoopAnalysis { ...@@ -56,6 +83,7 @@ struct LoopAnalysis {
class LoopUnroll : public Pass { class LoopUnroll : public Pass {
public: public:
LoopUnroll(Module *_m) : Pass(_m) { LoopUnroll(Module *_m) : Pass(_m) {
m_->set_print_name();
for (auto &f : m_->get_functions()) for (auto &f : m_->get_functions())
if (f.get_name() == "neg_idx_except") { if (f.get_name() == "neg_idx_except") {
neg_func = &f; neg_func = &f;
...@@ -75,7 +103,28 @@ class LoopUnroll : public Pass { ...@@ -75,7 +103,28 @@ class LoopUnroll : public Pass {
private: private:
Function *neg_func; Function *neg_func;
map<Value *, Value *> old2new;
Graph::BackEdgeList detect_back(Function *); Graph::BackEdgeList detect_back(Function *);
vector<Graph::SimpleLoop> check_sloops(const Graph::BackEdgeList &) const; vector<Graph::SimpleLoop> check_sloops(const Graph::BackEdgeList &) const;
void unroll_loop(Graph::SimpleLoop &);
Value *right_v(Value *v) const {
if (old2new.find(v) != old2new.end()) {
return old2new.at(v);
} else
return v;
}
BasicBlock *copy_instruction(Instruction &instr,
BasicBlock *bb, // old block
BasicBlock *BB, // new block
BasicBlock *pre, // previous of whole loop
BasicBlock *succ, // successor of whole loop
Graph::SimpleLoop &sl, // the whole loop
Analysis::CountedLoop &cl, // analysis info
bool init);
bool is_neg_block(BasicBlock *bb) const {
auto instr = &*bb->get_instructions().begin();
return (instr->is_call() and instr->get_operand(0) == neg_func);
}
}; };
#endif #endif
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "GVN.h" #include "GVN.h"
// #include "LoopInvHoist.hpp" // #include "LoopInvHoist.hpp"
// #include "LoopSearch.hpp" // #include "LoopSearch.hpp"
#include "BrMerge.hpp"
#include "ExceptCallMerge.hpp" #include "ExceptCallMerge.hpp"
#include "LoopUnroll.hpp" #include "LoopUnroll.hpp"
#include "Mem2Reg.hpp" #include "Mem2Reg.hpp"
...@@ -22,7 +23,8 @@ using namespace std::literals::string_literals; ...@@ -22,7 +23,8 @@ using namespace std::literals::string_literals;
void void
print_help(std::string exe_name) { print_help(std::string exe_name) {
std::cout << "Usage: " << exe_name std::cout << "Usage: " << exe_name
<< " [ -h | --help ] [ -o <target-file> ] [ -emit-llvm ] " << " [-h | --help] [-o <target-file>] [-S <assembly-file>] "
"[-emit-llvm] [-loopunroll] "
"[-mem2reg] [-gvn] [-dump-json] <input-file>" "[-mem2reg] [-gvn] [-dump-json] <input-file>"
<< std::endl; << std::endl;
} }
...@@ -62,7 +64,7 @@ main(int argc, char **argv) { ...@@ -62,7 +64,7 @@ main(int argc, char **argv) {
} else if (argv[i] == "-gvn"s) { } else if (argv[i] == "-gvn"s) {
gvn = true; gvn = true;
} else if (argv[i] == "-loopunroll"s) { } else if (argv[i] == "-loopunroll"s) {
loopunroll = true; gvn = loopunroll = true;
} else if (argv[i] == "-dump-json"s) { } else if (argv[i] == "-dump-json"s) {
dump_json = true; dump_json = true;
} else { } else {
...@@ -125,11 +127,15 @@ main(int argc, char **argv) { ...@@ -125,11 +127,15 @@ main(int argc, char **argv) {
if (loopunroll) { if (loopunroll) {
PM.add_pass<NegCallMerge>(false); PM.add_pass<NegCallMerge>(false);
PM.add_pass<LoopUnroll>(false); PM.add_pass<LoopUnroll>(false);
PM.add_pass<DeadCode>(false);
PM.add_pass<GVN>(false, dump_json);
PM.add_pass<DeadCode>(false);
PM.add_pass<BrMerge>(false);
} }
m->set_print_name();
PM.run(); PM.run();
m->set_print_name();
auto IR = m->print(); auto IR = m->print();
if (assembly) { if (assembly) {
......
...@@ -9,7 +9,6 @@ ...@@ -9,7 +9,6 @@
#include "Value.h" #include "Value.h"
#include "ast.hpp" #include "ast.hpp"
#include "regalloc.hpp" #include "regalloc.hpp"
#include "syntax_analyzer.h"
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
......
...@@ -82,6 +82,10 @@ LiveRangeAnalyzer::get_dfs_order(Function *func) { ...@@ -82,6 +82,10 @@ LiveRangeAnalyzer::get_dfs_order(Function *func) {
for (auto succ : bb->get_succ_basic_blocks()) for (auto succ : bb->get_succ_basic_blocks())
Q.push_front(succ); Q.push_front(succ);
} }
cout << "DFS order for function " << func->get_name() << ":\n";
for (auto bb : BB_DFS_Order)
cout << bb->get_name() << " ";
cout << endl;
} }
void void
......
...@@ -11,7 +11,7 @@ using std::for_each; ...@@ -11,7 +11,7 @@ using std::for_each;
using namespace RA; using namespace RA;
#define ASSERT_CMPINST_USED_ONCE(cmpinst) \ #define ASSERT_CMPINST_USED_ONCE(cmpinst) \
(assert(cmpinst->get_use_list().size() == 1)) (assert(cmpinst->get_use_list().size() <= 1))
int int
get_arg_id(Argument *arg) { get_arg_id(Argument *arg) {
...@@ -47,6 +47,8 @@ RegAllocator::no_reg_alloca(Value *v) { ...@@ -47,6 +47,8 @@ RegAllocator::no_reg_alloca(Value *v) {
// %op2 = icmp ne i32 %op1, 0 # <- if judges to here // %op2 = icmp ne i32 %op1, 0 # <- if judges to here
// br i1 %op2, label %label3, label %label5 // br i1 %op2, label %label3, label %label5
ASSERT_CMPINST_USED_ONCE(use_ins); ASSERT_CMPINST_USED_ONCE(use_ins);
if (use_ins->get_use_list().size() == 0)
return false;
auto use2_ins = dynamic_cast<Instruction *>( auto use2_ins = dynamic_cast<Instruction *>(
use_ins->get_use_list().begin()->val_); use_ins->get_use_list().begin()->val_);
alloc = not(use2_ins->is_br()); alloc = not(use2_ins->is_br());
...@@ -128,12 +130,10 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -128,12 +130,10 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
void void
RegAllocator::ExpireOldIntervals(LiveInterval liveint) { RegAllocator::ExpireOldIntervals(LiveInterval liveint) {
auto it = active.begin(); auto it = active.begin();
for (; it != active.end() and it->first.j < liveint.first.i; ++it) for (; it != active.end() and it->first.j < liveint.first.i; ++it)
used[regmap.at(it->second)] = false; used[regmap.at(it->second)] = false;
active.erase(active.begin(), it); active.erase(active.begin(), it);
} }
void void
......
#include "BrMerge.hpp"
#include "BasicBlock.h"
#include "Constant.h"
#include "Instruction.h"
void
BrMerge::run() {
BranchInst *br;
ReturnInst *ret;
for (auto &func : m_->get_functions()) {
bool cont = true;
while (cont) {
cont = false;
for (auto &bb : func.get_basic_blocks()) {
if (&bb == func.get_entry_block())
continue;
auto &instructions = bb.get_instructions();
br = dynamic_cast<BranchInst *>(&*instructions.rbegin());
ret = dynamic_cast<ReturnInst *>(&*instructions.rbegin());
assert((br or ret) && "final Instruction");
if (instructions.size() == 1) {
if (br) { // change: br or cond-br with constant jump
BasicBlock *succ;
if (not br->is_cond_br()) {
assert(bb.get_succ_basic_blocks().size() == 1);
succ = *bb.get_succ_basic_blocks().begin();
} else if (dynamic_cast<Constant *>(
br->get_operand(0))) {
assert(bb.get_succ_basic_blocks().size() == 2);
succ = static_cast<BasicBlock *>(
dynamic_cast<ConstantInt *>(br->get_operand(0))
->get_value()
? br->get_operand(1)
: br->get_operand(2));
} else
continue;
for (auto pre : bb.get_pre_basic_blocks()) {
// change br's op
auto ins = &*pre->get_instructions().rbegin();
bool set = false;
for (int i = 0; i < ins->get_num_operand(); ++i)
if (ins->get_operand(i) == &bb) {
ins->set_operand(i, succ);
set = true;
break;
}
assert(set);
// change pre's succ
pre->remove_succ_basic_block(&bb);
pre->add_succ_basic_block(&bb);
// change succ's pre
succ->add_pre_basic_block(pre);
}
// change succ's pre
succ->remove_pre_basic_block(&bb);
// remove useless block
func.get_basic_blocks().remove(&bb);
cont = true;
} else { // ret: do not change
}
} else {
}
}
}
}
}
...@@ -5,4 +5,5 @@ add_library( ...@@ -5,4 +5,5 @@ add_library(
Dominators.cpp Dominators.cpp
Mem2Reg.cpp Mem2Reg.cpp
GVN.cpp GVN.cpp
BrMerge.cpp
) )
...@@ -46,15 +46,28 @@ NegCallMerge::run(Function *func) { ...@@ -46,15 +46,28 @@ NegCallMerge::run(Function *func) {
idx = {0}; idx = {0};
} }
for (auto i : idx) for (auto i : idx)
if (calls.find(br->get_operand(i)) != calls.end()) if (calls.find(br->get_operand(i)) != calls.end()) {
auto remove = static_cast<BasicBlock *>(br->get_operand(i));
if (remove == reserved)
continue;
br->get_operands()[i] = reserved; br->get_operands()[i] = reserved;
// correct pre/succ blocks
bb.add_succ_basic_block(reserved);
reserved->add_pre_basic_block(&bb);
// remove wrong graph links
bb.remove_succ_basic_block(remove);
remove->remove_pre_basic_block(&bb);
blocks.remove(remove);
}
} }
} }
return;
// remove useless BasicBlocks // remove useless BasicBlocks
for (auto _bb : calls) { for (auto _bb : calls) {
auto bb = static_cast<BasicBlock *>(_bb); auto bb = static_cast<BasicBlock *>(_bb);
if (bb != reserved) { if (bb != reserved) {
cout << "remove block " << bb->get_name() << " in function " << func->get_name() << endl; /* cout << "remove block " << bb->get_name() << " in function "
* << func->get_name() << endl; */
auto it = blocks.begin(); auto it = blocks.begin();
for (; &*it != bb; ++it) for (; &*it != bb; ++it)
; ;
......
...@@ -411,10 +411,10 @@ GVN::detectEquivalences() { ...@@ -411,10 +411,10 @@ GVN::detectEquivalences() {
break; break;
} }
default: { default: {
std::cerr << "In function " << func_->get_name() << ", " /* std::cerr << "In function " << func_->get_name() << ", "
<< " block " << bb->get_name() * << " block " << bb->get_name()
<< " has count of predecessors: " * << " has count of predecessors: "
<< pre_bbs_.size(); * << pre_bbs_.size(); */
continue; continue;
} }
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "Constant.h" #include "Constant.h"
#include "Function.h" #include "Function.h"
#include "Instruction.h" #include "Instruction.h"
#include "syntax_analyzer.h" #include "Type.h"
#include <vector> #include <vector>
...@@ -13,6 +13,10 @@ using std::find; ...@@ -13,6 +13,10 @@ using std::find;
using namespace Graph; using namespace Graph;
using namespace Analysis; using namespace Analysis;
#define CONSTINT(x) ConstantInt::get(x, m_)
#define CONSTFP(x) ConstantFP::get(x, m_)
#define op(instr, i) right_v(instr.get_operand(i))
void void
LoopUnroll::run() { LoopUnroll::run() {
for (auto &_f : m_->get_functions()) { for (auto &_f : m_->get_functions()) {
...@@ -29,56 +33,172 @@ LoopUnroll::run() { ...@@ -29,56 +33,172 @@ LoopUnroll::run() {
cout << p->get_name() << " "; cout << p->get_name() << " ";
cout << "\n"; cout << "\n";
} }
for (auto sl : sloops) {
unroll_loop(sl);
}
} }
} }
void
LoopUnroll::unroll_loop(SimpleLoop &sl) {
CountedLoop cl(sl);
auto func = sl[0]->get_parent();
switch (cl.Type) {
case CountedLoop::UNDEF:
return;
case CountedLoop::INT:
cout << "Get CountedLoop in function " << func->get_name()
<< ":\n\t"
<< "initial value: " << cl.initial.v
<< ", stop value: " << cl.stop.v << "\n\t"
<< "delta: " << cl.delta->print() << "\n\t"
<< "loop times: " << cl.count << endl;
break;
case CountedLoop::FLOAT:
cout << "Get CountedLoop in function " << func->get_name()
<< ":\n\t"
<< "initial value: " << cl.initial.fv
<< ", stop value: " << cl.stop.fv << "\n\t"
<< "delta: " << cl.delta->print() << "\n\t"
<< "loop times: " << cl.count << endl;
break;
}
// get pre block and succ block
auto b = *sl.begin();
auto e = *sl.rbegin();
BasicBlock *pre, *succ;
for (auto p : b->get_pre_basic_blocks()) {
if (p != e) {
pre = p;
break;
}
}
for (auto s : b->get_succ_basic_blocks()) {
if (find(sl.begin(), sl.end(), s) == sl.end()) {
succ = s;
break;
}
}
auto BB = BasicBlock::create(m_, "", func);
auto nb = BB;
LoopAnalysis::LoopAnalysis(const Graph::SimpleLoop &sl) { // unroll loop: copy instructions
bool init = true;
for (cl.new_emulate(); cl.judge(); cl.next()) {
for (auto bb : sl) {
for (auto &instr : bb->get_instructions()) {
BB = copy_instruction(instr, bb, BB, pre, succ, sl, cl, init);
}
}
init = false;
}
for (auto &instr : b->get_instructions())
BB = copy_instruction(instr, b, BB, pre, succ, sl, cl, false);
BranchInst::create_br(succ, BB);
// correct pre's br
auto pre_br =
dynamic_cast<BranchInst *>(&*pre->get_instructions().rbegin());
if (pre_br->is_cond_br()) {
assert(pre_br->get_operand(1) == b or pre_br->get_operand(2) == b);
if (pre_br->get_operand(1) == b)
pre_br->set_operand(1, nb);
else
pre_br->set_operand(2, nb);
} else {
assert(pre_br->get_operand(0) == b);
pre_br->set_operand(0, nb);
}
/* new links */
// nb & pre
nb->add_pre_basic_block(pre);
pre->add_succ_basic_block(nb);
// BB & succ: br maintain the links
/* succ->add_pre_basic_block(BB);
* BB->add_succ_basic_block(succ); */
// old links: remove links from old graph
pre->remove_succ_basic_block(b);
succ->remove_pre_basic_block(b);
// remove blocks in simpleloop from func->get_basic_blocks()
for (auto b : sl)
func->get_basic_blocks().remove(b);
// neg block's pre blocks
// replace use
m_->set_print_name();
for (auto [k, v] : old2new)
cout << "[debug]" << k->get_name() << "-" << v->get_name() << endl;
for (auto &bb : func->get_basic_blocks()) {
for (auto &instr : bb.get_instructions()) {
for (int i = 0; i < instr.get_num_operand(); ++i) {
if (old2new.find(instr.get_operand(i)) != old2new.end()) {
instr.set_operand(i, old2new.at(instr.get_operand(i)));
}
}
}
}
}
CountedLoop::CountedLoop(const Graph::SimpleLoop &sl) : Type(UNDEF), count(0) {
auto b = sl.front(); auto b = sl.front();
auto e = sl.back(); auto e = sl.back();
Value *i, *control_op;
PhiInst *phi;
// In `p`, get stop number(const) // In `p`, get stop number(const)
Value *i;
auto rit = b->get_instructions().rbegin(); auto rit = b->get_instructions().rbegin();
assert(dynamic_cast<BranchInst *>(&*rit) && assert(dynamic_cast<BranchInst *>(&*rit) &&
"The end instruction of a block should be branch"); "The end instruction of a block should be branch");
i = (rit++)->get_operand(0); i = (rit++)->get_operand(0);
assert(i == &*rit && dynamic_cast<CmpInst *>(&*rit) && assert(i == &*rit);
static_cast<CmpInst *>(&*rit)->get_cmp_op() == CmpInst::NE && if (dynamic_cast<CmpInst *>(&*rit)) {
assert(static_cast<CmpInst *>(&*rit)->get_cmp_op() == CmpInst::NE &&
"should be neq 0"); "should be neq 0");
} else if (dynamic_cast<FCmpInst *>(&*rit)) {
assert(static_cast<FCmpInst *>(&*rit)->get_cmp_op() == FCmpInst::NE &&
"should be neq 0");
} else
assert(false && "should not have this case");
control = dynamic_cast<Instruction *>(&*rit); // only `ne` case
i = (rit++)->get_operand(0); i = (rit++)->get_operand(0);
assert(i == &*rit && dynamic_cast<ZextInst *>(&*rit) && "neqz"); if (dynamic_cast<ZextInst *>(&*rit)) {
assert(i == &*rit);
i = (rit++)->get_operand(0); i = (rit++)->get_operand(0);
assert( control = dynamic_cast<Instruction *>(&*rit);
i == &*rit && assert(i == &*rit &&
(dynamic_cast<CmpInst *>(&*rit) or dynamic_cast<FCmpInst *>(&*rit)) && (dynamic_cast<CmpInst *>(control) or
dynamic_cast<FCmpInst *>(control)) &&
"cmp or fcmp"); "cmp or fcmp");
if (dynamic_cast<Constant *>(rit->get_operand(0)) or }
dynamic_cast<Constant *>(rit->get_operand(1))) { if (dynamic_cast<Constant *>(control->get_operand(0)) or
if (dynamic_cast<CmpInst *>(&*rit)) { dynamic_cast<Constant *>(control->get_operand(1))) {
if (dynamic_cast<FCmpInst *>(&*control)) {
Type = FLOAT; Type = FLOAT;
auto constfloat = dynamic_cast<ConstantFP *>(rit->get_operand(1)); auto constfloat =
dynamic_cast<ConstantFP *>(control->get_operand(1));
assert(constfloat && assert(constfloat &&
"the case const at operand(0) not implemented"); "the case const at operand(0) not implemented");
threshold.fv = constfloat->get_value(); stop.fv = constfloat->get_value();
} else { } else {
Type = INT; Type = INT;
auto constint = dynamic_cast<ConstantInt *>(rit->get_operand(1)); auto constint =
dynamic_cast<ConstantInt *>(control->get_operand(1));
assert(constint && "the case const at operand(0) not implemented"); assert(constint && "the case const at operand(0) not implemented");
threshold.v = constint->get_value(); stop.v = constint->get_value();
}
} else {
Type = UNDEF;
return;
} }
} else
goto can_not_count;
// get control value and initial value // get control value and initial value
auto control = rit->get_operand(0); control_op = control->get_operand(0);
auto it = b->get_instructions().begin(); phi = dynamic_cast<PhiInst *>(control_op);
for (; it != b->get_instructions().end(); ++it) { if (phi == nullptr)
auto phi = dynamic_cast<PhiInst *>(&*it); goto can_not_count;
assert(phi && "unexpected structure for while block"); assert(phi->get_parent() == b && "unexpected structure for while block");
if (&*it != control)
continue;
assert(phi->get_operand(3) == e && "the orther case not implemented"); assert(phi->get_operand(3) == e && "the orther case not implemented");
if (dynamic_cast<Constant *>(phi->get_operand(0))) { if (dynamic_cast<Constant *>(phi->get_operand(0))) {
switch (Type) { switch (Type) {
...@@ -87,24 +207,158 @@ LoopAnalysis::LoopAnalysis(const Graph::SimpleLoop &sl) { ...@@ -87,24 +207,158 @@ LoopAnalysis::LoopAnalysis(const Graph::SimpleLoop &sl) {
->get_value(); ->get_value();
break; break;
case FLOAT: case FLOAT:
initial.fv = static_cast<ConstantFP *>(phi->get_operand(0)) initial.fv =
->get_value(); static_cast<ConstantFP *>(phi->get_operand(0))->get_value();
break; break;
case UNDEF: case UNDEF:
assert(false); assert(false);
break; break;
} }
} else { } else
Type = UNDEF; goto can_not_count;
return;
}
}
// get delta // get delta, maybe `control` op `const` or `const` op `control`
delta = dynamic_cast<BinaryInst *>(phi->get_operand(2));
if (delta == nullptr)
goto can_not_count;
if (delta->get_operand(0) != control_op and
delta->get_operand(1) != control_op)
goto can_not_count;
if (delta->get_operand(0) == control_op) {
reverse = false;
if (dynamic_cast<Constant *>(delta->get_operand(1)) == nullptr)
goto can_not_count;
} else { // control at [1]
reverse = true;
if (dynamic_cast<Constant *>(delta->get_operand(0)) == nullptr)
goto can_not_count;
}
// check correctness // check correctness
// count loop // count loop
new_emulate();
for (count = 0; judge(); ++count) {
if (count > Threshold)
goto can_not_count;
next();
}
return;
can_not_count:
Type = UNDEF;
}
bool
CountedLoop::judge() {
bool flag;
// cycle judge
switch (Type) {
case INT: {
switch (static_cast<CmpInst *>(control)->get_cmp_op()) {
case CmpInst::EQ:
flag = emulate.v == stop.v;
break;
case CmpInst::NE:
flag = emulate.v != stop.v;
break;
case CmpInst::GT:
flag = emulate.v > stop.v;
break;
case CmpInst::GE:
flag = emulate.v >= stop.v;
break;
case CmpInst::LT:
flag = emulate.v < stop.v;
break;
case CmpInst::LE:
flag = emulate.v <= stop.v;
break;
}
break;
}
case FLOAT: {
switch (static_cast<FCmpInst *>(control)->get_cmp_op()) {
case FCmpInst::EQ:
flag = emulate.fv == stop.fv;
break;
case FCmpInst::NE:
flag = emulate.fv != stop.fv;
break;
case FCmpInst::GT:
flag = emulate.fv > stop.fv;
break;
case FCmpInst::GE:
flag = emulate.fv >= stop.fv;
break;
case FCmpInst::LT:
flag = emulate.fv < stop.fv;
break;
case FCmpInst::LE:
flag = emulate.fv <= stop.fv;
break;
}
break;
}
case UNDEF:
assert(false);
}
return flag;
}
void
CountedLoop::next() {
switch (Type) {
case INT: {
int op2 =
static_cast<ConstantInt *>(delta->get_operand(reverse ? 0 : 1))
->get_value();
switch (delta->get_instr_type()) {
case Instruction::add:
emulate.v += op2;
break;
case Instruction::sub:
emulate.v = (reverse ? -1 : 1) * (emulate.v - op2);
break;
case Instruction::mul:
emulate.v *= op2;
break;
case Instruction::sdiv:
emulate.v = (reverse ? op2 / emulate.v : emulate.v / op2);
break;
default:
assert(false && "not implemented");
break;
}
break;
}
case FLOAT: {
float op2 =
static_cast<ConstantFP *>(delta->get_operand(reverse ? 0 : 1))
->get_value();
switch (delta->get_instr_type()) {
case Instruction::fadd:
emulate.fv += op2;
break;
case Instruction::fsub:
emulate.fv = (reverse ? -1 : 1) * (emulate.fv - op2);
break;
case Instruction::fmul:
emulate.fv *= op2;
break;
case Instruction::fdiv:
emulate.fv =
(reverse ? (op2 / emulate.fv) : (emulate.fv / op2));
break;
default:
assert(false && "not implemented");
break;
}
break;
}
case UNDEF:
assert(false);
}
} }
vector<SimpleLoop> vector<SimpleLoop>
...@@ -128,9 +382,7 @@ LoopUnroll::check_sloops(const BackEdgeList &belist) const { ...@@ -128,9 +382,7 @@ LoopUnroll::check_sloops(const BackEdgeList &belist) const {
assert(succ_bbs.size() == 2); assert(succ_bbs.size() == 2);
flag = false; flag = false;
for (auto bb : succ_bbs) { for (auto bb : succ_bbs) {
auto instr = &*bb->get_instructions().begin(); if (is_neg_block(bb)) {
if (instr->is_call() and
instr->get_operand(0) == neg_func) {
flag = true; flag = true;
break; break;
} }
...@@ -178,3 +430,201 @@ BackEdgeSearcher::dfsrun(BasicBlock *bb) { ...@@ -178,3 +430,201 @@ BackEdgeSearcher::dfsrun(BasicBlock *bb) {
path.pop_back(); path.pop_back();
} }
BasicBlock *
LoopUnroll::copy_instruction(Instruction &instr,
BasicBlock *bb,
BasicBlock *BB,
BasicBlock *pre,
BasicBlock *succ,
SimpleLoop &sl,
CountedLoop &cl,
bool init) {
Value *n;
auto b = *sl.begin();
auto e = *sl.rbegin();
auto func = b->get_parent();
switch (instr.get_instr_type()) {
case Instruction::ret:
if (instr.get_num_operand() == 0)
ReturnInst::create_void_ret(BB);
else
ReturnInst::create_ret(op(instr, 0), BB);
break;
case Instruction::br:
if (bb == b) { // do nothing
assert(instr.get_operand(1) == succ or
instr.get_operand(2) == succ && "unexpected structure");
} else {
if (instr.get_num_operand() == 3) {
// we know the neg block is always at operand(1)
auto newBB = BasicBlock::create(m_, "", func);
BranchInst::create_cond_br(
op(instr, 0),
static_cast<BasicBlock *>(op(instr, 1)),
newBB,
BB);
BB = newBB;
} else { // do nothing
assert(instr.get_num_operand() == 1 and
instr.get_operand(0) == b);
}
}
break;
case Instruction::add:
n = BinaryInst::create_add(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::sub:
n = BinaryInst::create_sub(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::mul:
n = BinaryInst::create_mul(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::sdiv:
n = BinaryInst::create_sdiv(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::fadd:
n = BinaryInst::create_fadd(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::fsub:
n = BinaryInst::create_fsub(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::fmul:
n = BinaryInst::create_fmul(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::fdiv:
n = BinaryInst::create_fdiv(op(instr, 0), op(instr, 1), BB, m_);
break;
case Instruction::alloca:
n = AllocaInst::create_alloca(
static_cast<AllocaInst *>(&instr)->get_alloca_type(), BB);
break;
case Instruction::load:
n = LoadInst::create_load(
static_cast<LoadInst *>(&instr)->get_load_type(),
op(instr, 0),
BB);
break;
case Instruction::store:
StoreInst::create_store(op(instr, 0), op(instr, 1), BB);
break;
case Instruction::cmp:
n = CmpInst::create_cmp(
static_cast<CmpInst *>(&instr)->get_cmp_op(),
op(instr, 0),
op(instr, 1),
BB,
m_);
break;
case Instruction::fcmp:
n = FCmpInst::create_fcmp(
static_cast<FCmpInst *>(&instr)->get_cmp_op(),
op(instr, 0),
op(instr, 1),
BB,
m_);
break;
case Instruction::phi: { // make it a copy
assert(bb == b);
Value *v, *zero;
// zero
switch (cl.Type) {
case CountedLoop::INT:
zero = CONSTINT(0);
break;
case CountedLoop::FLOAT:
zero = CONSTFP(0);
break;
case CountedLoop::UNDEF:
assert(false);
}
// v
if (&instr == cl.control) {
switch (cl.Type) {
case CountedLoop::INT:
v = CONSTINT(cl.emulate.v);
break;
case CountedLoop::FLOAT:
v = CONSTFP(cl.emulate.fv);
break;
case CountedLoop::UNDEF:
assert(false);
}
} else if (init) {
assert(instr.get_operand(1) == pre);
v = op(instr, 0);
} else {
assert(instr.get_operand(3) == e);
v = op(instr, 2);
}
switch (cl.Type) {
case CountedLoop::INT:
n = BinaryInst::create_add(zero, v, BB, m_);
break;
case CountedLoop::FLOAT:
n = BinaryInst::create_fadd(zero, v, BB, m_);
break;
case CountedLoop::UNDEF:
assert(false);
}
break;
}
case Instruction::call: {
vector<Value *> args;
for (int i = 1; i < instr.get_num_operand(); ++i)
args.push_back(op(instr, i));
n = CallInst::create(
static_cast<Function *>(op(instr, 0)), args, BB);
break;
}
case Instruction::getelementptr: {
vector<Value *> idxs;
for (int i = 1; i < instr.get_num_operand(); ++i)
idxs.push_back(op(instr, i));
n = GetElementPtrInst::create_gep(op(instr, 0), idxs, BB);
break;
}
case Instruction::zext:
n = ZextInst::create_zext(
op(instr, 0),
static_cast<ZextInst *>(&instr)->get_dest_type(),
BB);
break;
case Instruction::fptosi:
n = FpToSiInst::create_fptosi(
op(instr, 0),
static_cast<FpToSiInst *>(&instr)->get_dest_type(),
BB);
break;
case Instruction::sitofp:
n = SiToFpInst::create_sitofp(
op(instr, 0),
static_cast<SiToFpInst *>(&instr)->get_dest_type(),
BB);
break;
}
if (not instr.is_void())
old2new[&instr] = n;
return BB;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment