#ifndef LOOPUNROLL_HPP
#define LOOPUNROLL_HPP

#include "BasicBlock.h"
#include "Instruction.h"
#include "Module.h"
#include "PassManager.hpp"
#include "Type.h"
#include "Value.h"
// NOTE(review): the operands of the next four #include directives are missing
// in this copy of the file. The code below uses std::map, std::vector,
// std::string, std::to_string, std::cout, std::endl, std::pair and assert, so
// these were presumably standard headers such as <map>, <vector>, <string>
// and <iostream> — restore them (and make sure <cassert> is reachable).
#include
#include
#include
#include

// These using-declarations sit at namespace scope in a header, so they are
// visible in every translation unit that includes this file.
using std::cout;
using std::endl;
using std::map;
using std::string;
using std::to_string;
using std::vector;

// Control-flow-graph helper types used by the loop-unrolling pass.
namespace Graph {
// NOTE(review): the template argument lists of std::pair / vector / map in
// this namespace appear to be missing in this copy. LoopUnroll::str() calls
// edge.first->get_name() and edge.second->get_name(), so Edge is presumably a
// pair of block pointers (source, target) — confirm and restore.
using Edge = std::pair;
// A list of back edges (the type of BackEdgeSearcher::edges).
using BackEdgeList = vector;
// Presumably the blocks that make up one simple loop.
using SimpleLoop = vector;

// Searches for back edges starting from `entry`. The search runs eagerly in
// the constructor via dfsrun(entry); its results are left in `edges`.
struct BackEdgeSearcher {
  BackEdgeSearcher(BasicBlock *entry) { dfsrun(entry); }
  // Presumably a depth-first walk from `bb` that fills `edges`; defined out
  // of line.
  void dfsrun(BasicBlock *bb);
  // Presumably the blocks on the current DFS path.
  vector path;
  // Presumably per-block visit state for the DFS.
  map vis;
  // Back edges found by the search.
  BackEdgeList edges;
};
} // namespace Graph

// Loop analysis used to decide whether (and how often) a loop iterates.
namespace Analysis {
// NOTE(review): this is a preprocessor macro, so despite its placement it is
// NOT scoped to namespace Analysis — every later `Threshold` token in any file
// including this header becomes 100. Presumably an upper limit used by the
// unroller; its use is not visible here — confirm.
#define Threshold 100
// Emulation state for a loop's control value: an initial value, a stop value
// and the currently emulated value, each stored as int or float depending on
// `Type`.
struct CountedLoop {
  // Holds a loop value as either an int (`v`) or a float (`fv`); which
  // member is meaningful is selected by `Type`.
  typedef union {
    int v;
    float fv;
  } I_F;
  // Analyses the given simple loop; defined out of line.
  CountedLoop(const Graph::SimpleLoop &);
  CountedLoop() = delete;
  // Restarts emulation: copies `initial` into `emulate` through the union
  // member matching `Type`. Asserts when `Type` is UNDEF; with NDEBUG that
  // case leaves `emulate` unchanged.
  void new_emulate() {
    switch (Type) {
    case INT:
      emulate.v = initial.v;
      break;
    case FLOAT:
      emulate.fv = initial.fv;
      break;
    case UNDEF:
      assert(false);
      break;
    }
  };
  // Whether the current emulated value should continue the loop.
  bool judge();
  // Emulates the delta part (presumably the per-iteration update).
  void next();
  // Selects which I_F member (v or fv) holds the loop values.
  enum { INT, FLOAT, UNDEF } Type;
  I_F initial, stop, emulate;
  int count; // determined in the constructor
  bool reverse;
  // Used when emulating the delta part.
  BinaryInst *delta;
  Instruction *control;
};
} // namespace Analysis

/* This is a class to unroll simple loops:
 * - strict structure:
 *     --a->b->c---
 *       ^-----+
 * - if the loop runs a constant number of times, unroll it.
*/ class LoopUnroll : public Pass { public: LoopUnroll(Module *_m) : Pass(_m) { m_->set_print_name(); for (auto &f : m_->get_functions()) if (f.get_name() == "neg_idx_except") { neg_func = &f; break; } if (neg_func == nullptr) assert(false && "find function neg_idx_except first!"); } LoopUnroll() = delete; void run() override; static string str(const Graph::Edge &edge) { return "(" + edge.first->get_name() + ", " + edge.second->get_name() + ")"; } private: Function *neg_func; map old2new; Graph::BackEdgeList detect_back(Function *); vector check_sloops(const Graph::BackEdgeList &) const; void unroll_loop(Graph::SimpleLoop &); Value *right_v(Value *v) const { if (old2new.find(v) != old2new.end()) { return old2new.at(v); } else return v; } BasicBlock *copy_instruction(Instruction &instr, BasicBlock *bb, // old block BasicBlock *BB, // new block BasicBlock *pre, // previous of whole loop BasicBlock *succ, // successor of whole loop Graph::SimpleLoop &sl, // the whole loop Analysis::CountedLoop &cl, // analysis info bool init); bool is_neg_block(BasicBlock *bb) const { auto instr = &*bb->get_instructions().begin(); return (instr->is_call() and instr->get_operand(0) == neg_func); } }; #endif