Commit 805d36ad authored by lxq's avatar lxq

ready to finish all the functional test!

parent e23c2e2e
...@@ -161,4 +161,10 @@ op6: <8, 8> ...@@ -161,4 +161,10 @@ op6: <8, 8>
- `tests/4-ir-opt/testcases/GVN/performance` - `tests/4-ir-opt/testcases/GVN/performance`
## 局限性
- GEP的取巧设计
- 未考虑指令寻址的立即数
- -
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
#include <string> #include <string>
#define __PRINT_ORI__ #define __PRINT_ORI__
#define __RO_PART__ // #define __RO_PART__
#define __PRINT_COMMENT__ #define __PRINT_COMMENT__
// #a = 8, #t = 9, reserve $t0, $t1 for temporary // #a = 8, #t = 9, reserve $t0, $t1 for temporary
#define R_USABLE 17 - 2 #define R_USABLE (17 - 2)
// #fa = 8, #ft=16, reserve $ft0, $ft1 for temporary
#define FR_USABLE (24 - 2)
#define ARG_R 8 #define ARG_R 8
#include <map> #include <map>
...@@ -40,7 +42,12 @@ using std::vector; ...@@ -40,7 +42,12 @@ using std::vector;
class CodeGen { class CodeGen {
public: public:
CodeGen(Module *m_) : m(m_), LRA(m_, phi_map), RA(R_USABLE, ARG_R) {} CodeGen(Module *m_)
: cmp_zext_cnt(0)
, m(m_)
, LRA(m_, phi_map)
, RA_int(R_USABLE, false)
, RA_float(FR_USABLE, true) {}
string print() { string print() {
string result; string result;
...@@ -77,7 +84,8 @@ class CodeGen { ...@@ -77,7 +84,8 @@ class CodeGen {
vector<string> output; vector<string> output;
// register allocation // register allocation
LRA::LiveRangeAnalyzer LRA; LRA::LiveRangeAnalyzer LRA;
RA::RegAllocator RA; LRA::LVITS LVITS_int, LVITS_float;
RA::RegAllocator RA_int, RA_float;
// some instruction has lvalue, but is stack-allocated, // some instruction has lvalue, but is stack-allocated,
// we need this variable to track the reg name which has rvalue. // we need this variable to track the reg name which has rvalue.
// this variable is maintain by gencopy() and LoadInst. // this variable is maintain by gencopy() and LoadInst.
...@@ -147,7 +155,7 @@ class CodeGen { ...@@ -147,7 +155,7 @@ class CodeGen {
gencopy(lhs_reg, rhs_reg, is_float); gencopy(lhs_reg, rhs_reg, is_float);
return true; return true;
} }
void gencopy(string lhs_reg, string rhs_reg, bool is_float = false) { void gencopy(string lhs_reg, string rhs_reg, bool is_float) {
if (rhs_reg != lhs_reg) { if (rhs_reg != lhs_reg) {
if (is_float) if (is_float)
output.push_back("fmov.s " + lhs_reg + ", " + rhs_reg); output.push_back("fmov.s " + lhs_reg + ", " + rhs_reg);
...@@ -215,7 +223,9 @@ class CodeGen { ...@@ -215,7 +223,9 @@ class CodeGen {
return true; return true;
if (instr->is_fcmp() or instr->is_cmp() or instr->is_zext()) if (instr->is_fcmp() or instr->is_cmp() or instr->is_zext())
return true; return true;
if (RA.get().find(instr) != RA.get().end()) auto regmap = (instr->get_type()->is_float_type() ? RA_float.get()
: RA_int.get());
if (regmap.find(instr) != regmap.end())
return true; return true;
return false; return false;
...@@ -225,17 +235,23 @@ class CodeGen { ...@@ -225,17 +235,23 @@ class CodeGen {
return (is_float ? "$ft" : "$t") + to_string(i); return (is_float ? "$ft" : "$t") + to_string(i);
} }
static string regname(int i, bool is_float = false) { static string regname(uint i, bool is_float) {
string name; string name;
if (is_float) { if (is_float) {
assert(false && "not implemented!"); // assert(false && "not implemented!");
if (1 <= i and i <= 8)
name = "$fa" + to_string(i - 1);
else if (9 <= i and i <= FR_USABLE)
name = "$ft" + to_string(i - 9 + 2);
else
name = "WRONG_REG_" + to_string(i);
} else { } else {
if (1 <= i and i <= 8) if (1 <= i and i <= 8)
name = "$a" + to_string(i - 1); name = "$a" + to_string(i - 1);
else if (9 <= i and i <= R_USABLE) else if (9 <= i and i <= R_USABLE)
name = "$t" + to_string(i - 9 + 2); name = "$t" + to_string(i - 9 + 2);
else else
name = "WRONG_REG" + to_string(i); name = "WRONG_REG_" + to_string(i);
} }
return name; return name;
} }
......
...@@ -15,12 +15,14 @@ using std::string; ...@@ -15,12 +15,14 @@ using std::string;
using std::to_string; using std::to_string;
using std::vector; using std::vector;
#define UNINITIAL -1
#define __LRA_PRINT__ #define __LRA_PRINT__
namespace LRA { namespace LRA {
struct Interval { struct Interval {
Interval(int a = -1, int b = -1) : i(a), j(b) {} Interval(int a = UNINITIAL, int b = UNINITIAL) : i(a), j(b) {}
int i; // 0 means uninitialized int i; // 0 means uninitialized
int j; int j;
}; };
...@@ -49,15 +51,17 @@ class LiveRangeAnalyzer { ...@@ -49,15 +51,17 @@ class LiveRangeAnalyzer {
// void run(); // void run();
void run(Function *); void run(Function *);
void clear(); void clear();
void print(Function *func, bool printSet = false, bool printInt = false) const; void print(Function *func,
string print_liveSet(const LiveSet &ls) const { bool printSet = false,
bool printInt = false) const;
static string print_liveSet(const LiveSet &ls) {
string s = "[ "; string s = "[ ";
for (auto k : ls) for (auto k : ls)
s += k->get_name() + " "; s += k->get_name() + " ";
s += "]"; s += "]";
return s; return s;
} }
string print_interval(Interval &i) const { static string print_interval(const Interval &i) {
return "<" + to_string(i.i) + ", " + to_string(i.j) + ">"; return "<" + to_string(i.i) + ", " + to_string(i.j) + ">";
} }
const LVITS &get() { return liveIntervals; } const LVITS &get() { return liveIntervals; }
...@@ -91,8 +95,12 @@ class LiveRangeAnalyzer { ...@@ -91,8 +95,12 @@ class LiveRangeAnalyzer {
LiveSet transferFunction(Instruction *); LiveSet transferFunction(Instruction *);
public: public:
const decltype(instr_id) &get_instr_id() { return instr_id; } const decltype(instr_id) &get_instr_id() const { return instr_id; }
const decltype(intervalmap) &get_interval_map() { return intervalmap; } const decltype(intervalmap) &get_interval_map() const {
return intervalmap;
}
const decltype(IN) &get_in_set() const { return IN; }
const decltype(OUT) &get_out_set() const { return OUT; }
}; };
} // namespace LRA } // namespace LRA
#endif #endif
...@@ -14,24 +14,27 @@ using namespace LRA; ...@@ -14,24 +14,27 @@ using namespace LRA;
namespace RA { namespace RA {
#define MAXR 32 #define MAXR 32
#define ARG_MAX_R 8
struct ActiveCMP { struct ActiveCMP {
bool operator()(LiveInterval const &lhs, LiveInterval const &rhs) const { bool operator()(LiveInterval const &lhs, LiveInterval const &rhs) const {
if (lhs.first.j != rhs.first.j) if (lhs.first.j != rhs.first.j)
return lhs.first.j < rhs.first.j; return lhs.first.j < rhs.first.j;
else else if (lhs.first.i != rhs.first.i)
return lhs.first.i < rhs.first.i; return lhs.first.i < rhs.first.i;
else
return lhs.second < rhs.second;
} }
}; };
class RegAllocator { class RegAllocator {
public: public:
RegAllocator(const uint R_, const uint ARG_R_) RegAllocator(const uint R_, bool fl) : FLOAT(fl), R(R_), used{false} {
: R(R_), ARG_MAX_R(ARG_R_), used{false} { cout << "RegAllocator initialize: R=" << R << endl;
assert(R <= MAXR); assert(R <= MAXR);
} }
RegAllocator() = delete; RegAllocator() = delete;
bool no_reg_alloca(Value *v) const; static bool no_reg_alloca(Value *v);
// input set is sorted by increasing start point // input set is sorted by increasing start point
void LinearScan(const LVITS &, Function *); void LinearScan(const LVITS &, Function *);
const map<Value *, int> &get() const { return regmap; } const map<Value *, int> &get() const { return regmap; }
...@@ -42,8 +45,8 @@ class RegAllocator { ...@@ -42,8 +45,8 @@ class RegAllocator {
private: private:
Function *cur_func; Function *cur_func;
const bool FLOAT;
const uint R; const uint R;
const uint ARG_MAX_R;
bool used[MAXR + 1]; // index range: 1 ~ R bool used[MAXR + 1]; // index range: 1 ~ R
map<Value *, int> regmap; map<Value *, int> regmap;
// sorted by increasing end point // sorted by increasing end point
......
This diff is collapsed.
...@@ -125,31 +125,34 @@ LiveRangeAnalyzer::run(Function *func) { ...@@ -125,31 +125,34 @@ LiveRangeAnalyzer::run(Function *func) {
} }
} }
// argument should be in the IN-set of Entry // argument should be in the IN-set of Entry
assert(IN.find(0) == IN.end() and OUT.find(0) == OUT.end() &&
"no instr_id will be mapped to 0");
IN[0] = OUT[0] = {};
for (auto arg : func->get_args()) for (auto arg : func->get_args())
IN[1].insert(arg); IN[0].insert(arg);
make_interval(func); make_interval(func);
#ifdef __LRA_PRINT__ #ifdef __LRA_PRINT__
print(func, true, true); print(func, false, true);
#endif #endif
} }
void void
LiveRangeAnalyzer::make_interval(Function *) { LiveRangeAnalyzer::make_interval(Function *) {
for (int time = 1; time <= ir_cnt; ++time) { for (int time = 0; time <= ir_cnt; ++time) {
for (auto op : IN.at(time)) { for (auto op : IN.at(time)) {
auto &interval = intervalmap[op]; auto &interval = intervalmap[op];
if (interval.i == -1) // uninitialized if (interval.i == UNINITIAL) // uninitialized
interval.i = time - 1; interval.i = interval.j = time;
else else
interval.j = time - 1; interval.j = time;
} }
for (auto op : OUT.at(time)) { for (auto op : OUT.at(time)) {
auto &interval = intervalmap[op]; auto &interval = intervalmap[op];
if (interval.i == -1) // uninitialized if (interval.i == UNINITIAL) // uninitialized
interval.i = time; interval.i = interval.j = time + 1;
else else
interval.j = time; interval.j = time + 1;
} }
} }
for (auto &[op, interval] : intervalmap) for (auto &[op, interval] : intervalmap)
...@@ -197,6 +200,11 @@ LiveRangeAnalyzer::print(Function *func, ...@@ -197,6 +200,11 @@ LiveRangeAnalyzer::print(Function *func,
bool printSet, bool printSet,
bool printInt) const { // for debug bool printInt) const { // for debug
cout << "Function " << func->get_name() << endl; cout << "Function " << func->get_name() << endl;
cout << "0. Entry" << endl;
if (printSet) {
cout << "\tin-set: " + print_liveSet(IN.at(0)) << "\n";
cout << "\tout-set: " + print_liveSet(OUT.at(0)) << "\n";
}
for (auto &bb : func->get_basic_blocks()) { for (auto &bb : func->get_basic_blocks()) {
for (auto &instr : bb.get_instructions()) { for (auto &instr : bb.get_instructions()) {
if (instr.is_phi()) // ignore phi if (instr.is_phi()) // ignore phi
...@@ -220,8 +228,7 @@ LiveRangeAnalyzer::print(Function *func, ...@@ -220,8 +228,7 @@ LiveRangeAnalyzer::print(Function *func,
} }
} }
// normal ir // normal ir
cout << instr_id.at(&instr) << ". " << instr.print() << " # " cout << instr_id.at(&instr) << ". " << instr.print() << endl;
<< &instr << endl;
if (not printSet) if (not printSet)
continue; continue;
auto idx = instr_id.at(&instr); auto idx = instr_id.at(&instr);
......
...@@ -10,6 +10,9 @@ using std::for_each; ...@@ -10,6 +10,9 @@ using std::for_each;
using namespace RA; using namespace RA;
#define ASSERT_CMPINST_USED_ONCE(cmpinst) \
(assert(cmpinst->get_use_list().size() == 1))
int int
get_arg_id(Argument *arg) { get_arg_id(Argument *arg) {
auto args = arg->get_parent()->get_args(); auto args = arg->get_parent()->get_args();
...@@ -23,7 +26,7 @@ get_arg_id(Argument *arg) { ...@@ -23,7 +26,7 @@ get_arg_id(Argument *arg) {
} }
bool bool
RegAllocator::no_reg_alloca(Value *v) const { RegAllocator::no_reg_alloca(Value *v) {
auto instr = dynamic_cast<Instruction *>(v); auto instr = dynamic_cast<Instruction *>(v);
auto arg = dynamic_cast<Argument *>(v); auto arg = dynamic_cast<Argument *>(v);
if (instr) { if (instr) {
...@@ -31,21 +34,26 @@ RegAllocator::no_reg_alloca(Value *v) const { ...@@ -31,21 +34,26 @@ RegAllocator::no_reg_alloca(Value *v) const {
if (instr->is_alloca() or instr->is_cmp() or instr->is_fcmp()) if (instr->is_alloca() or instr->is_cmp() or instr->is_fcmp())
return true; return true;
else if (instr->is_zext()) { // only alloca for true use else if (instr->is_zext()) { // only alloca for true use
for (auto use : instr->get_use_list()) bool alloc;
if (not dynamic_cast<Instruction *>(use.val_)->is_br()) { ASSERT_CMPINST_USED_ONCE(instr);
auto instr = static_cast<Instruction *>(use.val_); auto use_ins = dynamic_cast<Instruction *>(
if (instr->is_cmp()) { // special case for cmp again instr->get_use_list().begin()->val_);
auto cmp = static_cast<CmpInst *>(instr); // assert(use_ins != nullptr && "should only be instruction?");
assert(cmp->get_cmp_op() == CmpInst::NE); if (use_ins->is_cmp() and
auto uses = instr->get_use_list(); static_cast<CmpInst *>(use_ins)->get_cmp_op() == CmpInst::NE) {
assert(uses.size() == 1 and // this case:
dynamic_cast<Instruction *>(uses.begin()->val_) // %op0 = icmp slt i32 1, 2
->is_br()); // %op1 = zext i1 %op0 to i32
// %op2 = icmp ne i32 %op1, 0 # <- if judges to here
// br i1 %op2, label %label3, label %label5
ASSERT_CMPINST_USED_ONCE(use_ins);
auto use2_ins = dynamic_cast<Instruction *>(
use_ins->get_use_list().begin()->val_);
alloc = not(use2_ins->is_br());
} else
alloc = true;
} else return not(alloc);
return false;
}
return true;
} else // then always allocate } else // then always allocate
return false; return false;
} }
...@@ -64,19 +72,22 @@ RegAllocator::reset(Function *func) { ...@@ -64,19 +72,22 @@ RegAllocator::reset(Function *func) {
} }
int int
RegAllocator::ReserveForArg(const LVITS &Liveints) { RegAllocator::ReserveForArg(const LVITS &liveints) {
auto args = cur_func->get_args(); auto args = cur_func->get_args();
auto it_int = Liveints.begin(); auto it_int = liveints.begin();
auto it_arg = args.begin(); auto it_arg = args.begin();
int reg; int reg;
for (reg = 1; reg <= args.size() and reg <= ARG_MAX_R; ++reg) { for (reg = 1; reg <= args.size() and reg <= ARG_MAX_R; ++reg) {
auto arg = *it_arg; auto arg = *it_arg;
auto liveint = *it_int; if (not(FLOAT ^ arg->get_type()->is_float_type())) {
assert(arg == liveint.second && "arg should be in order in liveints"); auto liveint = *it_int;
assert(arg == liveint.second &&
"arg should be in order in liveints");
used[reg] = true; used[reg] = true;
regmap[arg] = reg; regmap[arg] = reg;
active.insert(liveint); active.insert(liveint);
}
++it_arg, ++it_int; ++it_arg, ++it_int;
} }
return reg; return reg;
...@@ -88,6 +99,8 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -88,6 +99,8 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
ReserveForArg(liveints); ReserveForArg(liveints);
int reg; int reg;
for (auto liveint : liveints) { for (auto liveint : liveints) {
if (FLOAT ^ liveint.second->get_type()->is_float_type())
continue;
if (dynamic_cast<Argument *>(liveint.second)) { if (dynamic_cast<Argument *>(liveint.second)) {
continue; continue;
} }
...@@ -99,6 +112,13 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -99,6 +112,13 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
else { else {
for (reg = 1; reg <= R and used[reg]; ++reg) for (reg = 1; reg <= R and used[reg]; ++reg)
; ;
if (reg == 16) {
for (auto [interval, v] : active) {
cout << "already allocated: " << v->get_name() << " ~ "
<< regmap.at(v) << endl;
}
assert(false);
}
used[reg] = true; used[reg] = true;
regmap[liveint.second] = reg; regmap[liveint.second] = reg;
active.insert(liveint); active.insert(liveint);
...@@ -108,10 +128,12 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -108,10 +128,12 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
void void
RegAllocator::ExpireOldIntervals(LiveInterval liveint) { RegAllocator::ExpireOldIntervals(LiveInterval liveint) {
auto it = active.begin(); auto it = active.begin();
for (; it != active.end() and it->first.j < liveint.first.i; ++it) for (; it != active.end() and it->first.j < liveint.first.i; ++it)
used[regmap.at(it->second)] = false; used[regmap.at(it->second)] = false;
active.erase(active.begin(), it); active.erase(active.begin(), it);
} }
void void
...@@ -122,6 +144,7 @@ RegAllocator::SpillAtInterval(LiveInterval liveint) { ...@@ -122,6 +144,7 @@ RegAllocator::SpillAtInterval(LiveInterval liveint) {
if (spill.first.j > liveint.first.j) { if (spill.first.j > liveint.first.j) {
// cancel reg allocation for spill // cancel reg allocation for spill
regmap[liveint.second] = regmap.at(spill.second); regmap[liveint.second] = regmap.at(spill.second);
active.insert(liveint);
active.erase(spill); active.erase(spill);
regmap.erase(spill.second); regmap.erase(spill.second);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment