#include "LoopUnroll.hpp"

#include "BasicBlock.h"
#include "Constant.h"
#include "Function.h"
#include "Instruction.h"
#include "Type.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>

using std::find;
using namespace Graph;
using namespace Analysis;

#define CONSTINT(x) ConstantInt::get(x, m_)
#define CONSTFP(x) ConstantFP::get(x, m_)
// Operand `i` of `instr`, passed through right_v (presumably translates an
// original loop value into its most recent copy via old2new -- confirm).
#define op(instr, i) right_v(instr.get_operand(i))

namespace {
// Two's-complement wrapping arithmetic used when emulating an integer
// induction variable. Plain signed + - * would be undefined behaviour inside
// the compiler on overflow; the wrapped result is what fixed-width machine
// arithmetic yields (assumption: IR integers wrap on the target).
template <typename T> T wrap_add(T a, T b) {
    using U = std::make_unsigned_t<T>;
    return static_cast<T>(static_cast<U>(a) + static_cast<U>(b));
}
template <typename T> T wrap_sub(T a, T b) {
    using U = std::make_unsigned_t<T>;
    return static_cast<T>(static_cast<U>(a) - static_cast<U>(b));
}
template <typename T> T wrap_mul(T a, T b) {
    using U = std::make_unsigned_t<T>;
    return static_cast<T>(static_cast<U>(a) * static_cast<U>(b));
}
} // namespace

/// Pass entry point: for every function with a body, find its back edges,
/// keep those that form simple loops (see check_sloops) and try to fully
/// unroll each one (unroll_loop gives up on loops it cannot count).
void LoopUnroll::run() {
    for (auto &_f : m_->get_functions()) {
        if (_f.is_declaration())
            continue;
        auto func = &_f;
        auto belist = detect_back(func);
        auto sloops = check_sloops(belist);
        cout << "get simple loops for function " << func->get_name() << ":\n";
        for (const auto &sl : sloops) {
            cout << "\t";
            for (auto p : sl)
                cout << p->get_name() << " ";
            cout << "\n";
        }
        // unroll_loop re-derives preheader/exit from the current CFG, so an
        // earlier unroll that rewired a neighbouring loop is picked up.
        for (auto &sl : sloops)
            unroll_loop(sl);
    }
}

/// Fully unroll `sl` (header first, latch last) if CountedLoop can compute
/// its trip count; otherwise leave the function untouched.
///
/// The loop blocks are replaced by a straight-line chain of new blocks: one
/// copy of every loop block per trip, then one final copy of the header (the
/// evaluation that leaves the loop) which branches to the exit block. Uses of
/// the old instructions in the rest of the function are redirected to their
/// last copies through old2new.
void LoopUnroll::unroll_loop(SimpleLoop &sl) {
    CountedLoop cl(sl);
    auto func = sl[0]->get_parent();
    switch (cl.Type) {
    case CountedLoop::UNDEF:
        return;
    case CountedLoop::INT:
        cout << "Get CountedLoop in function " << func->get_name() << ":\n\t"
             << "initial value: " << cl.initial.v
             << ", stop value: " << cl.stop.v << "\n\t"
             << "delta: " << cl.delta->print() << "\n\t"
             << "loop times: " << cl.count << endl;
        break;
    case CountedLoop::FLOAT:
        cout << "Get CountedLoop in function " << func->get_name() << ":\n\t"
             << "initial value: " << cl.initial.fv
             << ", stop value: " << cl.stop.fv << "\n\t"
             << "delta: " << cl.delta->print() << "\n\t"
             << "loop times: " << cl.count << endl;
        break;
    }

    // The header `b` has two predecessors (checked by check_sloops): the
    // latch `e` and the block entering the loop, `pre`. Its successor that is
    // not a loop block is the exit block `succ`.
    auto b = sl.front();
    auto e = sl.back();
    BasicBlock *pre = nullptr;
    BasicBlock *succ = nullptr;
    for (auto p : b->get_pre_basic_blocks()) {
        if (p != e) {
            pre = p;
            break;
        }
    }
    for (auto s : b->get_succ_basic_blocks()) {
        if (find(sl.begin(), sl.end(), s) == sl.end()) {
            succ = s;
            break;
        }
    }
    assert(pre != nullptr && succ != nullptr &&
           "simple loop without preheader or exit block");
    if (pre == nullptr or succ == nullptr)
        return;

    // Entries from a previously unrolled loop map instructions that are no
    // longer in the function; they must not leak into this loop's rewrite.
    old2new.clear();

    auto BB = BasicBlock::create(m_, "", func);
    auto nb = BB; // first block of the unrolled chain

    // Copy one loop block for the current trip. Header phis are evaluated in
    // parallel, as phi semantics require: each phi copy is published in
    // old2new only after every phi of the block has read its inputs, so
    // patterns like `a = phi [.., b]; b = phi [.., t]` read the previous
    // trip's `b` regardless of phi order.
    auto copy_block = [&](BasicBlock *bb, bool first_trip) {
        std::vector<std::pair<Value *, Value *>> fresh_phis;
        auto publish = [&] {
            for (auto &entry : fresh_phis)
                old2new[entry.first] = entry.second;
            fresh_phis.clear();
        };
        for (auto &instr : bb->get_instructions()) {
            if (instr.get_instr_type() != Instruction::phi) {
                publish();
                BB = copy_instruction(instr, bb, BB, pre, succ, sl, cl,
                                      first_trip);
                continue;
            }
            auto before = old2new.find(&instr);
            const bool had_prev = before != old2new.end();
            Value *prev = had_prev ? before->second : nullptr;
            BB = copy_instruction(instr, bb, BB, pre, succ, sl, cl, first_trip);
            auto fresh = old2new.find(&instr);
            if (fresh == old2new.end())
                continue;
            fresh_phis.emplace_back(&instr, fresh->second);
            // Hide this trip's copy until all sibling phis are done.
            if (had_prev)
                fresh->second = prev;
            else
                old2new.erase(fresh);
        }
        publish();
    };

    bool first_trip = true;
    for (cl.new_emulate(); cl.judge(); cl.next()) {
        for (auto bb : sl)
            copy_block(bb, first_trip);
        first_trip = false;
    }
    // Final header evaluation (the one that exits). If the body never ran,
    // its phis must still take the preheader values, hence `first_trip`.
    copy_block(b, first_trip);
    BranchInst::create_br(succ, BB);

    // Exit-block phis named `b` as the incoming block; control now arrives
    // from the last block of the chain. (Their incoming *values* are header
    // instructions and are remapped by the replace-use pass below.)
    for (auto &instr : succ->get_instructions()) {
        if (instr.get_instr_type() != Instruction::phi)
            continue;
        for (int k = 1; k < instr.get_num_operand(); k += 2) {
            if (instr.get_operand(k) == b)
                instr.set_operand(k, BB);
        }
    }

    // Retarget pre's branch from the old header to the new chain.
    auto pre_br = dynamic_cast<BranchInst *>(&*pre->get_instructions().rbegin());
    assert(pre_br != nullptr && "preheader must end with a branch");
    if (pre_br->is_cond_br()) {
        assert(pre_br->get_operand(1) == b or pre_br->get_operand(2) == b);
        if (pre_br->get_operand(1) == b)
            pre_br->set_operand(1, nb);
        else
            pre_br->set_operand(2, nb);
    } else {
        assert(pre_br->get_operand(0) == b);
        pre_br->set_operand(0, nb);
    }

    /* new links */
    // nb & pre (set_operand does not maintain CFG links)
    nb->add_pre_basic_block(pre);
    pre->add_succ_basic_block(nb);
    // BB & succ, and new blocks & neg blocks: the created branches maintain
    // those links.

    /* old links */
    pre->remove_succ_basic_block(b);
    // Every block outside the loop that a removed block branched to (the exit
    // block and any neg_idx_except block) forgets that predecessor.
    for (auto lb : sl) {
        auto targets = lb->get_succ_basic_blocks();
        for (auto s : targets) {
            if (find(sl.begin(), sl.end(), s) == sl.end())
                s->remove_pre_basic_block(lb);
        }
    }

    // Drop the original loop blocks from the function.
    for (auto lb : sl)
        func->get_basic_blocks().remove(lb);

    // Replace remaining uses of original loop values with their last copies.
    for (auto &bb : func->get_basic_blocks()) {
        for (auto &instr : bb.get_instructions()) {
            for (int i = 0; i < instr.get_num_operand(); ++i) {
                auto it = old2new.find(instr.get_operand(i));
                if (it != old2new.end())
                    instr.set_operand(i, it->second);
            }
        }
    }
}

/// Analyse simple loop `sl` (header first, latch last) and, when it has the
/// shape below, compute its trip count into `count`:
///
///     header: %iv   = phi [const init, pre], [%step, latch]   (either order)
///             %c    = cmp/fcmp <op> %iv, const stop
///             (%z   = zext %c ; %ne = icmp ne %z, 0)          (optional)
///             br %ne-or-%c, body, exit
///     %step = %iv <add|sub|mul|sdiv> const   (or const <op> %iv)
///
/// Any mismatch, or a trip count above Threshold, leaves Type == UNDEF.
CountedLoop::CountedLoop(const Graph::SimpleLoop &sl)
    : Type(UNDEF), count(0) {
    auto b = sl.front();
    auto e = sl.back();
    // Every function-scope local is declared here so that the gotos below
    // never jump over an initialisation.
    Value *i = nullptr, *control_op = nullptr;
    Value *init_v = nullptr, *latch_v = nullptr;
    PhiInst *phi = nullptr;
    auto rit = b->get_instructions().rbegin();
    const auto rend = b->get_instructions().rend();

    // 1. The exit test: `br cond`, with cond computed right before the branch.
    assert(dynamic_cast<BranchInst *>(&*rit) &&
           "The end instruction of a block should be branch");
    i = (rit++)->get_operand(0);
    if (rit == rend or i != &*rit)
        goto can_not_count;
    if (dynamic_cast<CmpInst *>(&*rit) == nullptr and
        dynamic_cast<FCmpInst *>(&*rit) == nullptr)
        goto can_not_count;
    control = &*rit;

    // 2. Peel `icmp ne (zext c), 0` so that `control` is the real comparison.
    // The outer compare must really be `ne ..., 0`; otherwise the exit sense
    // would be inverted.
    i = (rit++)->get_operand(0);
    if (rit != rend and i == &*rit and dynamic_cast<ZextInst *>(&*rit)) {
        auto outer = dynamic_cast<CmpInst *>(control);
        auto zero = outer ? dynamic_cast<ConstantInt *>(outer->get_operand(1))
                          : nullptr;
        if (outer == nullptr or outer->get_cmp_op() != CmpInst::NE or
            zero == nullptr or zero->get_value() != 0)
            goto can_not_count;
        i = (rit++)->get_operand(0);
        if (rit == rend or i != &*rit)
            goto can_not_count;
        if (dynamic_cast<CmpInst *>(&*rit) == nullptr and
            dynamic_cast<FCmpInst *>(&*rit) == nullptr)
            goto can_not_count;
        control = &*rit;
    }

    // 3. Stop value: a constant of the compared type as the right operand
    // (`const <op> iv` is not implemented).
    if (dynamic_cast<FCmpInst *>(control)) {
        auto constfloat = dynamic_cast<ConstantFP *>(control->get_operand(1));
        if (constfloat == nullptr)
            goto can_not_count;
        Type = FLOAT;
        stop.fv = constfloat->get_value();
    } else {
        auto constint = dynamic_cast<ConstantInt *>(control->get_operand(1));
        if (constint == nullptr)
            goto can_not_count;
        Type = INT;
        stop.v = constint->get_value();
    }

    // 4. Induction variable: a two-entry header phi, one entry from the latch.
    control_op = control->get_operand(0);
    phi = dynamic_cast<PhiInst *>(control_op);
    if (phi == nullptr or phi->get_parent() != b or phi->get_num_operand() != 4)
        goto can_not_count;
    if (phi->get_operand(3) == e) {
        init_v = phi->get_operand(0);
        latch_v = phi->get_operand(2);
    } else if (phi->get_operand(1) == e) {
        init_v = phi->get_operand(2);
        latch_v = phi->get_operand(0);
    } else
        goto can_not_count;

    // Initial value must be a constant of the loop's type.
    if (Type == INT) {
        auto c = dynamic_cast<ConstantInt *>(init_v);
        if (c == nullptr)
            goto can_not_count;
        initial.v = c->get_value();
    } else {
        auto c = dynamic_cast<ConstantFP *>(init_v);
        if (c == nullptr)
            goto can_not_count;
        initial.fv = c->get_value();
    }

    // 5. Step: `iv op const` or `const op iv` with an opcode next() emulates.
    // next() static_casts the constant operand, so its kind is checked here.
    delta = dynamic_cast<BinaryInst *>(latch_v);
    if (delta == nullptr)
        goto can_not_count;
    if (delta->get_operand(0) == control_op)
        reverse = false;
    else if (delta->get_operand(1) == control_op)
        reverse = true;
    else
        goto can_not_count;
    if (Type == INT) {
        if (dynamic_cast<ConstantInt *>(
                delta->get_operand(reverse ? 0 : 1)) == nullptr)
            goto can_not_count;
        switch (delta->get_instr_type()) {
        case Instruction::add:
        case Instruction::sub:
        case Instruction::mul:
        case Instruction::sdiv:
            break;
        default:
            goto can_not_count;
        }
    } else {
        if (dynamic_cast<ConstantFP *>(
                delta->get_operand(reverse ? 0 : 1)) == nullptr)
            goto can_not_count;
        switch (delta->get_instr_type()) {
        case Instruction::fadd:
        case Instruction::fsub:
        case Instruction::fmul:
        case Instruction::fdiv:
            break;
        default:
            goto can_not_count;
        }
    }

    // 6. Count the trips by emulation.
    new_emulate();
    for (count = 0; judge(); ++count) {
        if (count > Threshold)
            goto can_not_count;
        if (Type == INT and delta->get_instr_type() == Instruction::sdiv) {
            // An emulated division that traps would crash the compiler itself.
            using IntT = decltype(emulate.v);
            const IntT k = static_cast<ConstantInt *>(
                               delta->get_operand(reverse ? 0 : 1))
                               ->get_value();
            const IntT dividend = reverse ? k : emulate.v;
            const IntT divisor = reverse ? emulate.v : k;
            if (divisor == 0 or
                (divisor == -1 and
                 dividend == std::numeric_limits<IntT>::min()))
                goto can_not_count;
        }
        next();
    }
    return;

can_not_count:
    Type = UNDEF;
}

/// Evaluate `control` on the current emulated value: `emulate <op> stop`.
/// Returns false for comparison kinds not listed and for UNDEF.
bool CountedLoop::judge() {
    bool flag = false;
    switch (Type) {
    case INT: {
        switch (static_cast<CmpInst *>(control)->get_cmp_op()) {
        case CmpInst::EQ:
            flag = emulate.v == stop.v;
            break;
        case CmpInst::NE:
            flag = emulate.v != stop.v;
            break;
        case CmpInst::GT:
            flag = emulate.v > stop.v;
            break;
        case CmpInst::GE:
            flag = emulate.v >= stop.v;
            break;
        case CmpInst::LT:
            flag = emulate.v < stop.v;
            break;
        case CmpInst::LE:
            flag = emulate.v <= stop.v;
            break;
        }
        break;
    }
    case FLOAT: {
        switch (static_cast<FCmpInst *>(control)->get_cmp_op()) {
        case FCmpInst::EQ:
            flag = emulate.fv == stop.fv;
            break;
        case FCmpInst::NE:
            flag = emulate.fv != stop.fv;
            break;
        case FCmpInst::GT:
            flag = emulate.fv > stop.fv;
            break;
        case FCmpInst::GE:
            flag = emulate.fv >= stop.fv;
            break;
        case FCmpInst::LT:
            flag = emulate.fv < stop.fv;
            break;
        case FCmpInst::LE:
            flag = emulate.fv <= stop.fv;
            break;
        }
        break;
    }
    case UNDEF:
        assert(false);
        break;
    }
    return flag;
}

/// Advance `emulate` by one trip, applying `delta` (the induction step).
/// When `reverse` is set the constant is the left operand (`const op iv`).
void CountedLoop::next() {
    switch (Type) {
    case INT: {
        using IntT = decltype(emulate.v);
        const IntT op2 = static_cast<ConstantInt *>(
                             delta->get_operand(reverse ? 0 : 1))
                             ->get_value();
        switch (delta->get_instr_type()) {
        case Instruction::add:
            emulate.v = wrap_add(emulate.v, op2);
            break;
        case Instruction::sub:
            emulate.v = reverse ? wrap_sub(op2, emulate.v)
                                : wrap_sub(emulate.v, op2);
            break;
        case Instruction::mul:
            emulate.v = wrap_mul(emulate.v, op2);
            break;
        case Instruction::sdiv:
            // The constructor rejects loops whose emulation would trap here.
            emulate.v = reverse ? op2 / emulate.v : emulate.v / op2;
            break;
        default:
            assert(false && "not implemented");
            break;
        }
        break;
    }
    case FLOAT: {
        const float op2 = static_cast<ConstantFP *>(
                              delta->get_operand(reverse ? 0 : 1))
                              ->get_value();
        switch (delta->get_instr_type()) {
        case Instruction::fadd:
            emulate.fv += op2;
            break;
        case Instruction::fsub:
            // Written directly (not as -(v - c)) to get the same signed zero
            // as the computed `c - v`.
            emulate.fv = reverse ? op2 - emulate.fv : emulate.fv - op2;
            break;
        case Instruction::fmul:
            emulate.fv *= op2;
            break;
        case Instruction::fdiv:
            emulate.fv = reverse ? op2 / emulate.fv : emulate.fv / op2;
            break;
        default:
            assert(false && "not implemented");
            break;
        }
        break;
    }
    case UNDEF:
        assert(false);
        break;
    }
}

/// Select the back edges e -> b that form simple loops. b must have exactly
/// two predecessors and two successors. Every other block on the chain
/// walked back from e (via first predecessors) to b must have one
/// predecessor and one successor; the only exception is a second successor
/// that is a neg_idx_except block. Each returned loop lists the header first
/// and the latch last.
vector<SimpleLoop> LoopUnroll::check_sloops(const BackEdgeList &belist) const {
    vector<SimpleLoop> sloops;
    for (auto [e, b] : belist) {
        if (b->get_succ_basic_blocks().size() != 2 or
            b->get_pre_basic_blocks().size() != 2)
            continue;
        SimpleLoop sl;
        bool flag = true;
        for (auto p = e; p != b; p = *p->get_pre_basic_blocks().begin()) {
            if (p->get_pre_basic_blocks().size() != 1) {
                flag = false;
                break;
            }
            const auto &succ_bbs = p->get_succ_basic_blocks();
            if (succ_bbs.size() != 1) {
                assert(succ_bbs.size() == 2);
                flag = false;
                for (auto bb : succ_bbs) {
                    if (is_neg_block(bb)) {
                        flag = true;
                        break;
                    }
                }
            }
            if (not flag)
                break;
            sl.insert(sl.begin(), p);
        }
        if (not flag)
            continue;
        sl.insert(sl.begin(), b);
        sloops.push_back(sl);
    }
    return sloops;
}

/// All back edges of `func`, found by a DFS from its entry block.
BackEdgeList LoopUnroll::detect_back(Function *func) {
    BackEdgeSearcher search(func->get_entry_block());
    return search.edges;
}

/// DFS step: an edge to a visited block that is still on the current DFS
/// path closes a cycle and is recorded as a back edge. Edges to visited
/// blocks off the path are cross/forward edges and are ignored.
void BackEdgeSearcher::dfsrun(BasicBlock *bb) {
    vis[bb] = true;
    path.push_back(bb);
    for (auto succ : bb->get_succ_basic_blocks()) {
        if (not vis[succ]) {
            dfsrun(succ);
        } else if (find(path.rbegin(), path.rend(), succ) != path.rend()) {
            Edge edge(bb, succ);
            edges.push_back(edge);
        }
    }
    path.pop_back();
}

/// Append a copy of `instr` (from loop block `bb`) to `BB` and record the
/// mapping in old2new. Returns the block that subsequent copies go into: a
/// conditional branch to a neg_idx_except block splits the chain into a
/// fresh block.
///
/// `init` is true while copying the first trip (or the final header copy of a
/// loop that runs zero times): header phis then take their preheader value.
BasicBlock *
LoopUnroll::copy_instruction(Instruction &instr, BasicBlock *bb,
                             BasicBlock *BB, BasicBlock *pre, BasicBlock *succ,
                             SimpleLoop &sl, CountedLoop &cl, bool init) {
    Value *n = nullptr;
    auto b = sl.front();
    auto e = sl.back();
    auto func = b->get_parent();
    switch (instr.get_instr_type()) {
    case Instruction::ret:
        if (instr.get_num_operand() == 0)
            ReturnInst::create_void_ret(BB);
        else
            ReturnInst::create_ret(op(instr, 0), BB);
        break;
    case Instruction::br:
        if (bb == b) {
            // The header's exit test is resolved at compile time: drop it.
            assert((instr.get_operand(1) == succ or
                    instr.get_operand(2) == succ) &&
                   "unexpected structure");
        } else if (instr.get_num_operand() == 3) {
            // Body block branching to a neg_idx_except block: keep that edge,
            // continue the chain in a fresh block on the other edge.
            auto if_true = static_cast<BasicBlock *>(op(instr, 1));
            auto if_false = static_cast<BasicBlock *>(op(instr, 2));
            auto newBB = BasicBlock::create(m_, "", func);
            if (is_neg_block(if_true)) {
                BranchInst::create_cond_br(op(instr, 0), if_true, newBB, BB);
            } else {
                assert(is_neg_block(if_false) && "unexpected structure");
                BranchInst::create_cond_br(op(instr, 0), newBB, if_false, BB);
            }
            BB = newBB;
        } else {
            // Fall-through to the next loop block (or back to the header).
            assert(instr.get_num_operand() == 1 and
                   find(sl.begin(), sl.end(), instr.get_operand(0)) !=
                       sl.end());
        }
        break;
    case Instruction::add:
        n = BinaryInst::create_add(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::sub:
        n = BinaryInst::create_sub(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::mul:
        n = BinaryInst::create_mul(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::sdiv:
        n = BinaryInst::create_sdiv(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::fadd:
        n = BinaryInst::create_fadd(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::fsub:
        n = BinaryInst::create_fsub(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::fmul:
        n = BinaryInst::create_fmul(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::fdiv:
        n = BinaryInst::create_fdiv(op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::alloca:
        n = AllocaInst::create_alloca(
            static_cast<AllocaInst *>(&instr)->get_alloca_type(), BB);
        break;
    case Instruction::load:
        n = LoadInst::create_load(
            static_cast<LoadInst *>(&instr)->get_load_type(), op(instr, 0), BB);
        break;
    case Instruction::store:
        StoreInst::create_store(op(instr, 0), op(instr, 1), BB);
        break;
    case Instruction::cmp:
        n = CmpInst::create_cmp(static_cast<CmpInst *>(&instr)->get_cmp_op(),
                                op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::fcmp:
        n = FCmpInst::create_fcmp(static_cast<FCmpInst *>(&instr)->get_cmp_op(),
                                  op(instr, 0), op(instr, 1), BB, m_);
        break;
    case Instruction::phi: {
        // Phis only live in the header. Each copy becomes an identity
        // arithmetic instruction on the value for this trip; the identity is
        // chosen by the phi's own type, not the loop counter's type.
        assert(bb == b);
        Value *v = nullptr;
        if (&instr == cl.control->get_operand(0)) {
            // The induction variable: its value on this trip is cl.emulate.
            if (cl.Type == CountedLoop::FLOAT)
                v = CONSTFP(cl.emulate.fv);
            else
                v = CONSTINT(cl.emulate.v);
        } else {
            // First trip: value from the preheader; later trips: value from
            // the latch, which op() resolves to the previous trip's copy.
            BasicBlock *from = init ? pre : e;
            for (int k = 1; k < instr.get_num_operand(); k += 2) {
                if (instr.get_operand(k) == from) {
                    v = op(instr, k - 1);
                    break;
                }
            }
            assert(v != nullptr &&
                   "phi has no incoming value from the expected block");
        }
        if (instr.get_type()->is_float_type())
            // `v - 0.0` is an exact identity (keeps -0.0, unlike `0.0 + v`).
            n = BinaryInst::create_fsub(v, CONSTFP(0), BB, m_);
        else
            n = BinaryInst::create_add(CONSTINT(0), v, BB, m_);
        break;
    }
    case Instruction::call: {
        std::vector<Value *> args;
        for (int i = 1; i < instr.get_num_operand(); ++i)
            args.push_back(op(instr, i));
        n = CallInst::create(static_cast<Function *>(op(instr, 0)), args, BB);
        break;
    }
    case Instruction::getelementptr: {
        std::vector<Value *> idxs;
        for (int i = 1; i < instr.get_num_operand(); ++i)
            idxs.push_back(op(instr, i));
        n = GetElementPtrInst::create_gep(op(instr, 0), idxs, BB);
        break;
    }
    case Instruction::zext:
        n = ZextInst::create_zext(
            op(instr, 0), static_cast<ZextInst *>(&instr)->get_dest_type(), BB);
        break;
    case Instruction::fptosi:
        n = FpToSiInst::create_fptosi(
            op(instr, 0), static_cast<FpToSiInst *>(&instr)->get_dest_type(),
            BB);
        break;
    case Instruction::sitofp:
        n = SiToFpInst::create_sitofp(
            op(instr, 0), static_cast<SiToFpInst *>(&instr)->get_dest_type(),
            BB);
        break;
    default:
        assert(false && "unsupported instruction in a loop being unrolled");
        break;
    }
    if (not instr.is_void() and n != nullptr)
        old2new[&instr] = n;
    return BB;
}