#include "LoopUnroll.hpp"
#include "BasicBlock.h"
#include "Function.h"
#include "Instruction.h"
#include "syntax_analyzer.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using std::find;
using std::map;
using namespace Graph;

namespace {

/// Depth-first search over the CFG reachable from a function's entry block
/// that collects back edges: edges whose target is still on the current DFS
/// path (an ancestor of the source, or the source itself for a self-loop).
///
/// Every non-tree edge is reported on stdout as it is discovered.
/// The search is recursive, so its depth equals the longest DFS path.
struct BackEdgeSearcher {
    explicit BackEdgeSearcher(BasicBlock *entry) { dfsrun(entry); }

    /// DFS visitation state of a block.
    enum class Color {
        White, ///< not discovered yet (value-initialized default of map::operator[])
        Gray,  ///< discovered and still on the current DFS path
        Black, ///< finished: all of its successors have been explored
    };

    void dfsrun(BasicBlock *bb) {
        color[bb] = Color::Gray;
        for (auto succ : bb->get_succ_basic_blocks()) {
            const Color state = color[succ];
            if (state == Color::White) {
                dfsrun(succ); // tree edge
                continue;
            }
            Edge edge(bb, succ);
            const char *type;
            if (state == Color::Gray) {
                // Target is on the DFS path: this edge closes a cycle.
                type = "back-edge";
                edges.push_back(edge);
            } else {
                // NOTE: a finished target may be reached by a forward edge as
                // well as a cross edge; both are reported under this label.
                type = "cross-edge";
            }
            std::cout << "find " << type << ": " << LoopUnroll::str(edge) << "\n";
        }
        color[bb] = Color::Black;
    }

    map<BasicBlock *, Color> color;
    BackEdgeList edges;
};

} // namespace

/// Returns the back edges of @p func's CFG, in DFS discovery order, found by
/// a depth-first search starting at the function's entry block.
BackEdgeList LoopUnroll::detect_back(Function *func) {
    BackEdgeSearcher searcher(func->get_entry_block());
    // `edges` is a member of a local object, so NRVO does not apply; move it.
    return std::move(searcher.edges);
}

/// Filters back edges (latch -> header) down to "simple loops".
///
/// A loop is accepted when:
///  - the header has exactly 2 predecessors and exactly 2 successors;
///  - every block on the path from the latch back to (excluding) the header
///    has exactly one predecessor, and either exactly one successor or two
///    successors where one of them starts with a call to `neg_func` (the
///    neg_idx_except branch).
///
/// @param belist back edges as produced by detect_back()
/// @return one SimpleLoop per accepted edge, listing the header first and
///         then the body blocks in path order, ending with the latch.
std::vector<SimpleLoop> LoopUnroll::check_sloops(const BackEdgeList &belist) const {
    std::vector<SimpleLoop> sloops;
    for (const auto &[latch, header] : belist) {
        if (header->get_succ_basic_blocks().size() != 2 or
            header->get_pre_basic_blocks().size() != 2)
            continue;

        // True when the first instruction of `bb` is a call to neg_func.
        // Guards against empty instruction lists before dereferencing begin().
        auto starts_with_neg_call = [&](BasicBlock *bb) {
            auto &&instrs = bb->get_instructions();
            if (instrs.empty())
                return false;
            auto instr = &*instrs.begin();
            return instr->is_call() and instr->get_operand(0) == neg_func;
        };

        SimpleLoop sl;
        bool simple = true;
        // Walk backwards from the latch through the unique predecessor of
        // each block until the header is reached.
        for (auto p = latch; p != header; p = *p->get_pre_basic_blocks().begin()) {
            if (p->get_pre_basic_blocks().size() != 1) {
                simple = false;
                break;
            }
            auto &&succ_bbs = p->get_succ_basic_blocks();
            if (succ_bbs.size() != 1) {
                assert(succ_bbs.size() == 2);
                simple = std::any_of(succ_bbs.begin(), succ_bbs.end(), starts_with_neg_call);
            }
            if (not simple)
                break;
            sl.insert(sl.begin(), p);
        }
        if (not simple)
            continue;

        sl.insert(sl.begin(), header);
        sloops.push_back(std::move(sl));
    }
    return sloops;
}

/// Runs back-edge detection and simple-loop filtering on every function
/// with a body in the module, printing the simple loops found per function.
void LoopUnroll::run() {
    for (auto &f : m_->get_functions()) {
        if (f.is_declaration())
            continue;
        auto func = &f;
        auto belist = detect_back(func);
        auto sloops = check_sloops(belist);

        std::cout << "get simple loops for function " << func->get_name() << ":\n";
        for (const auto &sl : sloops) {
            std::cout << "\t";
            for (auto bb : sl)
                std::cout << bb->get_name() << " ";
            std::cout << "\n";
        }
    }
}