diff --git a/include/optimization/LoopUnroll.hpp b/include/optimization/LoopUnroll.hpp index 21910992c223660feca9ca9327b1d7c104d8ef45..6942eabfa2e8e5064a37b2e1bd0f04041dc5ce30 100644 --- a/include/optimization/LoopUnroll.hpp +++ b/include/optimization/LoopUnroll.hpp @@ -4,13 +4,16 @@ #include "BasicBlock.h" #include "Module.h" #include "PassManager.hpp" +#include "Type.h" +#include #include #include #include using std::cout; using std::endl; +using std::map; using std::string; using std::to_string; using std::vector; @@ -20,8 +23,29 @@ namespace Graph { using Edge = std::pair; using BackEdgeList = vector; using SimpleLoop = vector; -} +struct BackEdgeSearcher { + BackEdgeSearcher(BasicBlock *entry) { dfsrun(entry); } + + void dfsrun(BasicBlock *bb); + + vector path; + map vis; + BackEdgeList edges; +}; +} // namespace Graph + +namespace Analysis{ +struct LoopAnalysis { + LoopAnalysis(const Graph::SimpleLoop &); + LoopAnalysis() = delete; + + enum { INT, FLOAT, UNDEF} Type; + union { + int v; + float fv; + } initial, delta, threshold; +};} /* This is a class to unroll simple loops: * - strict structure: diff --git a/src/optimization/LoopUnroll.cpp b/src/optimization/LoopUnroll.cpp index e4351d881e413b4fc5519c2d41e5944fa7d1dce6..7b06ce3e1fce1b1999d0ac89f908b9c01edfe7e7 100644 --- a/src/optimization/LoopUnroll.cpp +++ b/src/optimization/LoopUnroll.cpp @@ -1,52 +1,110 @@ #include "LoopUnroll.hpp" #include "BasicBlock.h" +#include "Constant.h" #include "Function.h" #include "Instruction.h" #include "syntax_analyzer.h" -#include #include using std::find; -using std::map; using namespace Graph; +using namespace Analysis; -struct BackEdgeSearcher { - BackEdgeSearcher(BasicBlock *entry) { dfsrun(entry); } - - void dfsrun(BasicBlock *bb) { - vis[bb] = true; - path.push_back(bb); - for (auto succ : bb->get_succ_basic_blocks()) { - if (vis[succ]) { - string type; - Edge edge(bb, succ); - if (find(path.rbegin(), path.rend(), succ) == path.rend()) { - type = "cross-edge"; - } else { - type = "back-edge"; - edges.push_back(edge); - } - /* cout << "find " << type << ": " << LoopUnroll::str(edge) - * << "\n"; */ - } else - dfsrun(succ); +void +LoopUnroll::run() { + for (auto &_f : m_->get_functions()) { + if (_f.is_declaration()) + continue; + auto func = &_f; + // cout << func->get_name() << endl; + auto belist = detect_back(func); + auto sloops = check_sloops(belist); + cout << "get simple loops for function " << func->get_name() << ":\n"; + for (auto sl : sloops) { + cout << "\t"; + for (auto p : sl) + cout << p->get_name() << " "; + cout << "\n"; } + } +} - path.pop_back(); +LoopAnalysis::LoopAnalysis(const Graph::SimpleLoop &sl) { + auto b = sl.front(); + auto e = sl.back(); + + // In `p`, get stop number(const) + Value *i; + auto rit = b->get_instructions().rbegin(); + assert(dynamic_cast(&*rit) && + "The end instruction of a block should be branch"); + i = (rit++)->get_operand(0); + assert(i == &*rit && dynamic_cast(&*rit) && + static_cast(&*rit)->get_cmp_op() == CmpInst::NE && + "should be neq 0"); + i = (rit++)->get_operand(0); + assert(i == &*rit && dynamic_cast(&*rit) && "neqz"); + i = (rit++)->get_operand(0); + assert( + i == &*rit && + (dynamic_cast(&*rit) or dynamic_cast(&*rit)) && + "cmp or fcmp"); + if (dynamic_cast(rit->get_operand(0)) or + dynamic_cast(rit->get_operand(1))) { + if (dynamic_cast(&*rit)) { + Type = FLOAT; + auto constfloat = dynamic_cast(rit->get_operand(1)); + assert(constfloat && + "the case const at operand(0) not implemented"); + threshold.fv = constfloat->get_value(); + } else { + Type = INT; + auto constint = dynamic_cast(rit->get_operand(1)); + assert(constint && "the case const at operand(0) not implemented"); + threshold.v = constint->get_value(); + } + } else { + Type = UNDEF; + return; } - vector path; - map vis; - BackEdgeList edges; -}; + // get control value and initial value + auto control = rit->get_operand(0); + auto it = b->get_instructions().begin(); + for (; it != b->get_instructions().end(); ++it) { + auto phi = dynamic_cast(&*it); + assert(phi && "unexpected structure for while block"); + if (&*it != control) + continue; + assert(phi->get_operand(3) == e && "the orther case not implemented"); + if (dynamic_cast(phi->get_operand(0))) { + switch (Type) { + case INT: + initial.v = static_cast(phi->get_operand(0)) + ->get_value(); + break; + case FLOAT: + initial.fv = static_cast(phi->get_operand(0)) + ->get_value(); + break; + case UNDEF: + assert(false); + break; + } + } else { + Type = UNDEF; + return; + } + } -BackEdgeList -LoopUnroll::detect_back(Function *func) { - BackEdgeSearcher search(func->get_entry_block()); - return search.edges; + // get delta + + // check correctness + + // count loop } vector @@ -58,8 +116,8 @@ LoopUnroll::check_sloops(const BackEdgeList &belist) const { b->get_pre_basic_blocks().size() != 2) continue; bool flag = true; - // the start node should have 2*2 degree, and others have 1*1 degree - // one exception: br to neg_idx_except + // the start node should have 2*2 degree, and others have 1*1 degree + // one exception: br to neg_idx_except for (auto p = e; p != b; p = *p->get_pre_basic_blocks().begin()) { if (p->get_pre_basic_blocks().size() != 1) { flag = false; @@ -91,21 +149,32 @@ LoopUnroll::check_sloops(const BackEdgeList &belist) const { } return sloops; } + +BackEdgeList +LoopUnroll::detect_back(Function *func) { + BackEdgeSearcher search(func->get_entry_block()); + return search.edges; +} + void -LoopUnroll::run() { - for (auto &_f : m_->get_functions()) { - if (_f.is_declaration()) - continue; - auto func = &_f; - // cout << func->get_name() << endl; - auto belist = detect_back(func); - auto sloops = check_sloops(belist); - cout << "get simple loops for function " << func->get_name() << ":\n"; - for (auto sl : sloops) { - cout << "\t"; - for (auto p : sl) - cout << p->get_name() << " "; - cout << "\n"; - } +BackEdgeSearcher::dfsrun(BasicBlock *bb) { + vis[bb] = true; + path.push_back(bb); + for (auto succ : bb->get_succ_basic_blocks()) { + if (vis[succ]) { + string type; + Edge edge(bb, succ); + if (find(path.rbegin(), path.rend(), succ) == path.rend()) { + type = "cross-edge"; + } else { + type = "back-edge"; + edges.push_back(edge); + } + /* cout << "find " << type << ": " << LoopUnroll::str(edge) + * << "\n"; */ + } else + dfsrun(succ); } + + path.pop_back(); }