Commit 805d36ad authored by lxq's avatar lxq

ready to finish all the functional test!

parent e23c2e2e
...@@ -161,4 +161,10 @@ op6: <8, 8> ...@@ -161,4 +161,10 @@ op6: <8, 8>
- `tests/4-ir-opt/testcases/GVN/performance` - `tests/4-ir-opt/testcases/GVN/performance`
## 局限性
- GEP的取巧设计
- 未考虑指令寻址的立即数
- -
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
#include <string> #include <string>
#define __PRINT_ORI__ #define __PRINT_ORI__
#define __RO_PART__ // #define __RO_PART__
#define __PRINT_COMMENT__ #define __PRINT_COMMENT__
// #a = 8, #t = 9, reserve $t0, $t1 for temporary // #a = 8, #t = 9, reserve $t0, $t1 for temporary
#define R_USABLE 17 - 2 #define R_USABLE (17 - 2)
// #fa = 8, #ft=16, reserve $ft0, $ft1 for temporary
#define FR_USABLE (24 - 2)
#define ARG_R 8 #define ARG_R 8
#include <map> #include <map>
...@@ -40,7 +42,12 @@ using std::vector; ...@@ -40,7 +42,12 @@ using std::vector;
class CodeGen { class CodeGen {
public: public:
CodeGen(Module *m_) : m(m_), LRA(m_, phi_map), RA(R_USABLE, ARG_R) {} CodeGen(Module *m_)
: cmp_zext_cnt(0)
, m(m_)
, LRA(m_, phi_map)
, RA_int(R_USABLE, false)
, RA_float(FR_USABLE, true) {}
string print() { string print() {
string result; string result;
...@@ -77,7 +84,8 @@ class CodeGen { ...@@ -77,7 +84,8 @@ class CodeGen {
vector<string> output; vector<string> output;
// register allocation // register allocation
LRA::LiveRangeAnalyzer LRA; LRA::LiveRangeAnalyzer LRA;
RA::RegAllocator RA; LRA::LVITS LVITS_int, LVITS_float;
RA::RegAllocator RA_int, RA_float;
// some instruction has lvalue, but is stack-allocated, // some instruction has lvalue, but is stack-allocated,
// we need this variable to track the reg name which has rvalue. // we need this variable to track the reg name which has rvalue.
// this variable is maintain by gencopy() and LoadInst. // this variable is maintain by gencopy() and LoadInst.
...@@ -147,7 +155,7 @@ class CodeGen { ...@@ -147,7 +155,7 @@ class CodeGen {
gencopy(lhs_reg, rhs_reg, is_float); gencopy(lhs_reg, rhs_reg, is_float);
return true; return true;
} }
void gencopy(string lhs_reg, string rhs_reg, bool is_float = false) { void gencopy(string lhs_reg, string rhs_reg, bool is_float) {
if (rhs_reg != lhs_reg) { if (rhs_reg != lhs_reg) {
if (is_float) if (is_float)
output.push_back("fmov.s " + lhs_reg + ", " + rhs_reg); output.push_back("fmov.s " + lhs_reg + ", " + rhs_reg);
...@@ -215,7 +223,9 @@ class CodeGen { ...@@ -215,7 +223,9 @@ class CodeGen {
return true; return true;
if (instr->is_fcmp() or instr->is_cmp() or instr->is_zext()) if (instr->is_fcmp() or instr->is_cmp() or instr->is_zext())
return true; return true;
if (RA.get().find(instr) != RA.get().end()) auto regmap = (instr->get_type()->is_float_type() ? RA_float.get()
: RA_int.get());
if (regmap.find(instr) != regmap.end())
return true; return true;
return false; return false;
...@@ -225,17 +235,23 @@ class CodeGen { ...@@ -225,17 +235,23 @@ class CodeGen {
return (is_float ? "$ft" : "$t") + to_string(i); return (is_float ? "$ft" : "$t") + to_string(i);
} }
static string regname(int i, bool is_float = false) { static string regname(uint i, bool is_float) {
string name; string name;
if (is_float) { if (is_float) {
assert(false && "not implemented!"); // assert(false && "not implemented!");
if (1 <= i and i <= 8)
name = "$fa" + to_string(i - 1);
else if (9 <= i and i <= FR_USABLE)
name = "$ft" + to_string(i - 9 + 2);
else
name = "WRONG_REG_" + to_string(i);
} else { } else {
if (1 <= i and i <= 8) if (1 <= i and i <= 8)
name = "$a" + to_string(i - 1); name = "$a" + to_string(i - 1);
else if (9 <= i and i <= R_USABLE) else if (9 <= i and i <= R_USABLE)
name = "$t" + to_string(i - 9 + 2); name = "$t" + to_string(i - 9 + 2);
else else
name = "WRONG_REG" + to_string(i); name = "WRONG_REG_" + to_string(i);
} }
return name; return name;
} }
......
...@@ -15,12 +15,14 @@ using std::string; ...@@ -15,12 +15,14 @@ using std::string;
using std::to_string; using std::to_string;
using std::vector; using std::vector;
#define UNINITIAL -1
#define __LRA_PRINT__ #define __LRA_PRINT__
namespace LRA { namespace LRA {
struct Interval { struct Interval {
Interval(int a = -1, int b = -1) : i(a), j(b) {} Interval(int a = UNINITIAL, int b = UNINITIAL) : i(a), j(b) {}
int i; // 0 means uninitialized int i; // 0 means uninitialized
int j; int j;
}; };
...@@ -49,15 +51,17 @@ class LiveRangeAnalyzer { ...@@ -49,15 +51,17 @@ class LiveRangeAnalyzer {
// void run(); // void run();
void run(Function *); void run(Function *);
void clear(); void clear();
void print(Function *func, bool printSet = false, bool printInt = false) const; void print(Function *func,
string print_liveSet(const LiveSet &ls) const { bool printSet = false,
bool printInt = false) const;
static string print_liveSet(const LiveSet &ls) {
string s = "[ "; string s = "[ ";
for (auto k : ls) for (auto k : ls)
s += k->get_name() + " "; s += k->get_name() + " ";
s += "]"; s += "]";
return s; return s;
} }
string print_interval(Interval &i) const { static string print_interval(const Interval &i) {
return "<" + to_string(i.i) + ", " + to_string(i.j) + ">"; return "<" + to_string(i.i) + ", " + to_string(i.j) + ">";
} }
const LVITS &get() { return liveIntervals; } const LVITS &get() { return liveIntervals; }
...@@ -91,8 +95,12 @@ class LiveRangeAnalyzer { ...@@ -91,8 +95,12 @@ class LiveRangeAnalyzer {
LiveSet transferFunction(Instruction *); LiveSet transferFunction(Instruction *);
public: public:
const decltype(instr_id) &get_instr_id() { return instr_id; } const decltype(instr_id) &get_instr_id() const { return instr_id; }
const decltype(intervalmap) &get_interval_map() { return intervalmap; } const decltype(intervalmap) &get_interval_map() const {
return intervalmap;
}
const decltype(IN) &get_in_set() const { return IN; }
const decltype(OUT) &get_out_set() const { return OUT; }
}; };
} // namespace LRA } // namespace LRA
#endif #endif
...@@ -14,24 +14,27 @@ using namespace LRA; ...@@ -14,24 +14,27 @@ using namespace LRA;
namespace RA { namespace RA {
#define MAXR 32 #define MAXR 32
#define ARG_MAX_R 8
struct ActiveCMP { struct ActiveCMP {
bool operator()(LiveInterval const &lhs, LiveInterval const &rhs) const { bool operator()(LiveInterval const &lhs, LiveInterval const &rhs) const {
if (lhs.first.j != rhs.first.j) if (lhs.first.j != rhs.first.j)
return lhs.first.j < rhs.first.j; return lhs.first.j < rhs.first.j;
else else if (lhs.first.i != rhs.first.i)
return lhs.first.i < rhs.first.i; return lhs.first.i < rhs.first.i;
else
return lhs.second < rhs.second;
} }
}; };
class RegAllocator { class RegAllocator {
public: public:
RegAllocator(const uint R_, const uint ARG_R_) RegAllocator(const uint R_, bool fl) : FLOAT(fl), R(R_), used{false} {
: R(R_), ARG_MAX_R(ARG_R_), used{false} { cout << "RegAllocator initialize: R=" << R << endl;
assert(R <= MAXR); assert(R <= MAXR);
} }
RegAllocator() = delete; RegAllocator() = delete;
bool no_reg_alloca(Value *v) const; static bool no_reg_alloca(Value *v);
// input set is sorted by increasing start point // input set is sorted by increasing start point
void LinearScan(const LVITS &, Function *); void LinearScan(const LVITS &, Function *);
const map<Value *, int> &get() const { return regmap; } const map<Value *, int> &get() const { return regmap; }
...@@ -42,8 +45,8 @@ class RegAllocator { ...@@ -42,8 +45,8 @@ class RegAllocator {
private: private:
Function *cur_func; Function *cur_func;
const bool FLOAT;
const uint R; const uint R;
const uint ARG_MAX_R;
bool used[MAXR + 1]; // index range: 1 ~ R bool used[MAXR + 1]; // index range: 1 ~ R
map<Value *, int> regmap; map<Value *, int> regmap;
// sorted by increasing end point // sorted by increasing end point
......
...@@ -8,11 +8,14 @@ ...@@ -8,11 +8,14 @@
#include "Type.h" #include "Type.h"
#include "Value.h" #include "Value.h"
#include "ast.hpp" #include "ast.hpp"
#include "regalloc.hpp"
#include "syntax_analyzer.h"
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <deque> #include <deque>
#include <memory>
#include <ostream> #include <ostream>
#include <sstream> #include <sstream>
#include <string> #include <string>
...@@ -43,7 +46,7 @@ CodeGen::getRegName(Value *v, int i) const { ...@@ -43,7 +46,7 @@ CodeGen::getRegName(Value *v, int i) const {
bool find; bool find;
string name; string name;
bool is_float = v->get_type()->is_float_type(); bool is_float = v->get_type()->is_float_type();
auto regmap = RA.get(); auto regmap = (is_float ? RA_float.get() : RA_int.get());
if (regmap.find(v) == regmap.end()) { if (regmap.find(v) == regmap.end()) {
name = tmpregname(i, is_float); name = tmpregname(i, is_float);
find = false; find = false;
...@@ -75,7 +78,6 @@ CodeGen::getPhiMap() { ...@@ -75,7 +78,6 @@ CodeGen::getPhiMap() {
void void
CodeGen::run() { CodeGen::run() {
// TODO: implement
// 以下内容生成 int main() { return 0; } 的汇编代码 // 以下内容生成 int main() { return 0; } 的汇编代码
getPhiMap(); getPhiMap();
output.push_back(".text"); output.push_back(".text");
...@@ -94,14 +96,19 @@ CodeGen::run() { ...@@ -94,14 +96,19 @@ CodeGen::run() {
for (auto &func : m->get_functions()) { for (auto &func : m->get_functions()) {
if (not func.is_declaration()) { if (not func.is_declaration()) {
LRA.run(&func); LRA.run(&func);
RA.LinearScan(LRA.get(), &func); RA_int.LinearScan(LRA.get(), &func);
RA_float.LinearScan(LRA.get(), &func);
std::cout << "register map for function: " << func.get_name() std::cout << "integer register map for function: "
<< func.get_name() << std::endl;
RA_int.print([](int i) { return regname(i, false); });
std::cout << "float register map for function: " << func.get_name()
<< std::endl; << std::endl;
RA.print([](int i) { return regname(i); }); RA_float.print([](int i) { return regname(i, true); });
auto regmap = RA.get();
for (auto [_, op] : LRA.get()) { for (auto [_, op] : LRA.get()) {
auto regmap = op->get_type()->is_float_type() ? RA_float.get()
: RA_int.get();
if (regmap.find(op) == regmap.end()) if (regmap.find(op) == regmap.end())
std::cout << "no reg belongs to " << op->get_name() std::cout << "no reg belongs to " << op->get_name()
<< std::endl; << std::endl;
...@@ -177,8 +184,6 @@ CodeGen::ptrContent2reg(Value *ptr, string dest_reg) { ...@@ -177,8 +184,6 @@ CodeGen::ptrContent2reg(Value *ptr, string dest_reg) {
assert(false && "unknown type"); assert(false && "unknown type");
} }
void IR2assem(ZextInst *);
string string
CodeGen::value2reg(Value *v, int i, string recommend) { CodeGen::value2reg(Value *v, int i, string recommend) {
bool is_float = v->get_type()->is_float_type(); bool is_float = v->get_type()->is_float_type();
...@@ -193,7 +198,6 @@ CodeGen::value2reg(Value *v, int i, string recommend) { ...@@ -193,7 +198,6 @@ CodeGen::value2reg(Value *v, int i, string recommend) {
if (v == CONST_0) if (v == CONST_0)
return "$zero"; return "$zero";
auto constant = static_cast<Constant *>(v); auto constant = static_cast<Constant *>(v);
#ifdef __RO_PART__
if (ROdata.find(constant) == ROdata.end()) if (ROdata.find(constant) == ROdata.end())
ROdata[constant] = ".LC" + to_string(ROdata.size()); ROdata[constant] = ".LC" + to_string(ROdata.size());
string instr_ir, addr = ROdata[constant]; string instr_ir, addr = ROdata[constant];
...@@ -206,34 +210,6 @@ CodeGen::value2reg(Value *v, int i, string recommend) { ...@@ -206,34 +210,6 @@ CodeGen::value2reg(Value *v, int i, string recommend) {
// bug here: maybe // bug here: maybe
output.push_back("la.local " + tmp_ireg + ", " + addr); output.push_back("la.local " + tmp_ireg + ", " + addr);
output.push_back(instr_ir + " " + reg_name + ", " + tmp_ireg + ", 0"); output.push_back(instr_ir + " " + reg_name + ", " + tmp_ireg + ", 0");
#else
if (dynamic_cast<ConstantInt *>(constant)) {
int k = static_cast<ConstantInt *>(constant)->get_value();
if ((k & 0xfff) != k) {
output.push_back("lu12i.w " + reg_name + ", " +
to_string(k >> 12));
output.push_back("ori " + reg_name + ", " + reg_name + ", " +
to_string(k & 0xfff));
} else
output.push_back("ori " + reg_name + ", $r0, " + to_string(k));
} else if (dynamic_cast<ConstantFP *>(constant)) {
// move the binary code to int-reg, then use movgr2fr to move the
// value to float-reg
float k = static_cast<ConstantFP *>(constant)->get_value();
int hex_int = *(uint32_t *)&k;
if ((hex_int & 0xfff) != hex_int)
output.push_back("lu12i.w " + tmp_ireg + ", " +
to_string(hex_int >> 12));
if (hex_int & 0xfff)
output.push_back("ori " + tmp_ireg + ", " + tmp_ireg + ", " +
to_string(hex_int & 0xfff));
output.push_back("movgr2fr.w " + reg_name + ", " + tmp_ireg);
// output.push_back("ffint.s.w " + reg_name + ", " + reg_name);
} else
assert(false && "wait for completion");
#endif
} else if (dynamic_cast<GlobalVariable *>(v)) { } else if (dynamic_cast<GlobalVariable *>(v)) {
output.push_back("la.local " + reg_name + ", " + v->get_name()); output.push_back("la.local " + reg_name + ", " + v->get_name());
} else if (dynamic_cast<AllocaInst *>(v)) { } else if (dynamic_cast<AllocaInst *>(v)) {
...@@ -250,7 +226,7 @@ CodeGen::value2reg(Value *v, int i, string recommend) { ...@@ -250,7 +226,7 @@ CodeGen::value2reg(Value *v, int i, string recommend) {
if (*iter == v) if (*iter == v)
break; break;
if (id <= ARG_R) if (id <= ARG_R)
return regname(ARG_R); return regname(ARG_R, is_float);
else { else {
string instr_ir = is_float ? "fld" : "ld"; string instr_ir = is_float ? "fld" : "ld";
auto suff = suffix(v->get_type()); auto suff = suffix(v->get_type());
...@@ -270,7 +246,7 @@ CodeGen::value2reg(Value *v, int i, string recommend) { ...@@ -270,7 +246,7 @@ CodeGen::value2reg(Value *v, int i, string recommend) {
makeSureInRange(instr_ir + suff, makeSureInRange(instr_ir + suff,
reg_name, reg_name,
FP, FP,
off.at(v), -off.at(v),
instr_ir + "x" + suff, instr_ir + "x" + suff,
12, 12,
reg_name); reg_name);
...@@ -356,9 +332,10 @@ CodeGen::IR2assem(FpToSiInst *instr) { ...@@ -356,9 +332,10 @@ CodeGen::IR2assem(FpToSiInst *instr) {
assert(instr->get_operand(0)->get_type() == m->get_float_type()); assert(instr->get_operand(0)->get_type() == m->get_float_type());
assert(instr->get_dest_type() == m->get_int32_type()); assert(instr->get_dest_type() == m->get_int32_type());
string f_reg = value2reg(instr->get_operand(0)); string f_reg = value2reg(instr->get_operand(0));
string f_treg = "$ft0";
auto [i_reg, _] = getRegName(instr); auto [i_reg, _] = getRegName(instr);
output.push_back("ftintrz.w.s " + f_reg + ", " + f_reg); output.push_back("ftintrz.w.s " + f_treg + ", " + f_reg);
output.push_back("movfr2gr.s " + i_reg + ", " + f_reg); output.push_back("movfr2gr.s " + i_reg + ", " + f_treg);
gencopy(instr, i_reg); gencopy(instr, i_reg);
} }
void void
...@@ -394,7 +371,7 @@ CodeGen::bool2branch(Instruction *instr) { ...@@ -394,7 +371,7 @@ CodeGen::bool2branch(Instruction *instr) {
->is_zext()) { ->is_zext()) {
// something like: // something like:
// %op0 = icmp slt i32 1, 2 # deepest // %op0 = icmp slt i32 1, 2 # deepest
// %op1 = zext i1 %op0 to i32 // %op1 = zext i1 %op0 to i32 # this zext has no register
// %op2 = icmp ne i32 %op1, 0 // %op2 = icmp ne i32 %op1, 0
// br i1 %op2, label %label3, label %label5 // br i1 %op2, label %label3, label %label5
auto deepest = static_cast<Instruction *>( auto deepest = static_cast<Instruction *>(
...@@ -547,13 +524,16 @@ CodeGen::pass_arguments(CallInst *instr) { ...@@ -547,13 +524,16 @@ CodeGen::pass_arguments(CallInst *instr) {
string t0_contained; string t0_contained;
for (auto arg_id : order) { for (auto arg_id : order) {
auto arg_value = instr->get_operand(arg_id); auto arg_value = instr->get_operand(arg_id);
auto t_reg = arg_value->get_type()->is_float_type() ? "$ft0" : "$t0"; auto arg_is_float = arg_value->get_type()->is_float_type();
auto arg_reg_aid = regname(arg_id, arg_is_float);
auto t_reg = arg_is_float ? "$ft0" : "$t0";
// the value is actually in v_reg, and the arg corresponds to
// arg_reg_aid(only right when arg_id <= 8)
v_reg = value2reg(arg_value, 1); v_reg = value2reg(arg_value, 1);
if (backup[arg_id]) { // still relied by some argument due to cycle if (backup[arg_id]) { // a_id still relied by some argument due to cycle
auto a_id = regname(arg_id); assert(not wroten[arg_reg_aid]);
assert(not wroten[a_id]); gencopy(t_reg, arg_reg_aid, arg_is_float);
gencopy(t_reg, a_id); t0_contained = arg_reg_aid;
t0_contained = a_id;
} }
// in case that the src register has been wroten // in case that the src register has been wroten
if (wroten[v_reg]) { if (wroten[v_reg]) {
...@@ -562,7 +542,7 @@ CodeGen::pass_arguments(CallInst *instr) { ...@@ -562,7 +542,7 @@ CodeGen::pass_arguments(CallInst *instr) {
} }
if (arg_id <= ARG_R) { // pass by register if (arg_id <= ARG_R) { // pass by register
gencopy(regname(arg_id), v_reg); gencopy(arg_reg_aid, v_reg, arg_is_float);
} else { // pass by stack } else { // pass by stack
instr_ir = (arg_value->get_type()->is_float_type() ? "fst" : "st"); instr_ir = (arg_value->get_type()->is_float_type() ? "fst" : "st");
suff = suffix(arg_value->get_type()); suff = suffix(arg_value->get_type());
...@@ -575,7 +555,7 @@ CodeGen::pass_arguments(CallInst *instr) { ...@@ -575,7 +555,7 @@ CodeGen::pass_arguments(CallInst *instr) {
* $sp, " + to_string(func_arg_off.at(func).at(arg_id))); * $sp, " + to_string(func_arg_off.at(func).at(arg_id)));
*/ */
} }
wroten[regname(arg_id)] = true; wroten[arg_reg_aid] = true;
} }
} }
...@@ -585,24 +565,47 @@ CodeGen::IR2assem(CallInst *instr) { ...@@ -585,24 +565,47 @@ CodeGen::IR2assem(CallInst *instr) {
auto func_argN = func_arg_N.at(func); auto func_argN = func_arg_N.at(func);
// analyze the registers that need to be stored // analyze the registers that need to be stored
int cur_i = LRA.get_instr_id().at(instr); int cur_i = LRA.get_instr_id().at(instr);
auto regmap = RA.get(); auto regmap_int = RA_int.get();
auto regmap_float = RA_float.get();
// //
int storeN = 0; int storeN = 0;
vector<std::tuple<Value *, string, int>> store_record; vector<std::tuple<Value *, string, int>> store_record;
for (auto [op, interval] : LRA.get_interval_map()) { for (auto [op, interval] : LRA.get_interval_map()) {
if (RA.no_reg_alloca(op)) if (RA::RegAllocator::no_reg_alloca(op))
continue; continue;
if (not instr->get_function_type() auto [name, find] = getRegName(op);
->get_return_type() if (not find)
->is_void_type() and
regmap.find(instr) != regmap.end() and regmap.at(instr) == 1)
continue; continue;
if (interval.i < cur_i and cur_i <= interval.j) { auto op_type = op->get_type();
auto op_is_float = op_type->is_float_type();
if (not instr->get_function_type()->get_return_type()->is_void_type()) {
// if the called function return a value, and the mapped register is
// just $a0/$fa0, there is no need for restore
bool ret_float =
instr->get_function_type()->get_return_type()->is_float_type();
auto &regmap_ret = (ret_float ? regmap_float : regmap_int);
if (regmap_ret.find(instr) != regmap_ret.end() and
regmap_ret.at(instr) == 1)
if (name == (ret_float ? "$fa0" : "a0"))
continue;
}
bool restore = false;
if (interval.i == interval.j)
;
else if (interval.i < cur_i and cur_i < interval.j)
restore = true;
else if (interval.i == cur_i) {
auto inset = LRA.get_in_set().at(cur_i);
restore = inset.find(op) != inset.end();
} else
;
if (restore) {
cout << "At point " << cur_i << ", restore for " << op->get_name() cout << "At point " << cur_i << ", restore for " << op->get_name()
<< ", interval " << LRA.print_interval(interval) << endl; << ", interval " << LRA.print_interval(interval) << endl;
int tplen = typeLen(op->get_type()); int tplen = typeLen(op_type);
storeN = ALIGN(storeN, tplen) + tplen; storeN = ALIGN(storeN, tplen) + tplen;
auto name = regname(regmap.at(op), op->get_type()->is_float_type()); auto name = regname(
(op_is_float ? regmap_float : regmap_int).at(op), op_is_float);
store_record.push_back({op, name, storeN}); store_record.push_back({op, name, storeN});
} }
} }
...@@ -717,7 +720,7 @@ CodeGen::IR2assem(ReturnInst *instr) { ...@@ -717,7 +720,7 @@ CodeGen::IR2assem(ReturnInst *instr) {
void void
CodeGen::IR2assem(ZextInst *instr) { CodeGen::IR2assem(ZextInst *instr) {
if (RA.no_reg_alloca(instr)) if (RA::RegAllocator::no_reg_alloca(instr))
return; return;
assert(instr->get_num_operand() == 1); assert(instr->get_num_operand() == 1);
auto cmp_instr = instr->get_operand(0); auto cmp_instr = instr->get_operand(0);
...@@ -726,8 +729,6 @@ CodeGen::IR2assem(ZextInst *instr) { ...@@ -726,8 +729,6 @@ CodeGen::IR2assem(ZextInst *instr) {
assert(icmp_instr or fcmp_instr); assert(icmp_instr or fcmp_instr);
auto [dest_reg, _] = getRegName(instr); auto [dest_reg, _] = getRegName(instr);
auto reg1 = value2reg(icmp_instr->get_operand(0), 0);
auto reg2 = value2reg(icmp_instr->get_operand(1), 1);
string instr_ir; string instr_ir;
bool reverse = false, flip = false, check = false; bool reverse = false, flip = false, check = false;
...@@ -758,6 +759,8 @@ CodeGen::IR2assem(ZextInst *instr) { ...@@ -758,6 +759,8 @@ CodeGen::IR2assem(ZextInst *instr) {
flip = true; flip = true;
break; break;
} }
auto reg1 = value2reg(icmp_instr->get_operand(0), 0);
auto reg2 = value2reg(icmp_instr->get_operand(1), 1);
output.push_back( output.push_back(
instr_ir + " " + dest_reg + ", " + instr_ir + " " + dest_reg + ", " +
(reverse ? (reg2 + ", " + reg1) : (reg1 + ", " + reg2))); (reverse ? (reg2 + ", " + reg1) : (reg1 + ", " + reg2)));
...@@ -776,7 +779,7 @@ CodeGen::IR2assem(ZextInst *instr) { ...@@ -776,7 +779,7 @@ CodeGen::IR2assem(ZextInst *instr) {
instr_ir = "fcmp.ceq.s"; instr_ir = "fcmp.ceq.s";
break; break;
case FCmpInst::NE: case FCmpInst::NE:
instr_ir = "fcmp.cun.s"; instr_ir = "fcmp.cne.s";
break; break;
case FCmpInst::GT: case FCmpInst::GT:
instr_ir = "fcmp.cle.s"; instr_ir = "fcmp.cle.s";
...@@ -793,11 +796,14 @@ CodeGen::IR2assem(ZextInst *instr) { ...@@ -793,11 +796,14 @@ CodeGen::IR2assem(ZextInst *instr) {
instr_ir = "fcmp.cle.s"; instr_ir = "fcmp.cle.s";
break; break;
} }
auto reg1 = value2reg(fcmp_instr->get_operand(0), 0);
auto reg2 = value2reg(fcmp_instr->get_operand(1), 1);
string cmp_reg = "$fcc0"; string cmp_reg = "$fcc0";
auto label = "cmp_zext_" + to_string(++cmp_zext_cnt) + ":"; auto label = "cmp_zext_" + to_string(++cmp_zext_cnt);
output.push_back(instr_ir + " " + cmp_reg + ", " + reg1 + ", " + reg2); output.push_back(instr_ir + " " + cmp_reg + ", " + reg1 + ", " + reg2);
output.push_back("or " + dest_reg + ", $zero, $zero"); output.push_back("or " + dest_reg + ", $zero, $zero");
output.push_back("bceqz " + cmp_reg + ", " + label); output.push_back((reverse ? "bcnez " : "bceqz ") + cmp_reg + ", " +
label);
output.push_back("addi.w " + dest_reg + ", $zero, 1"); output.push_back("addi.w " + dest_reg + ", $zero, 1");
output.push_back(label + ":"); output.push_back(label + ":");
} }
......
...@@ -125,31 +125,34 @@ LiveRangeAnalyzer::run(Function *func) { ...@@ -125,31 +125,34 @@ LiveRangeAnalyzer::run(Function *func) {
} }
} }
// argument should be in the IN-set of Entry // argument should be in the IN-set of Entry
assert(IN.find(0) == IN.end() and OUT.find(0) == OUT.end() &&
"no instr_id will be mapped to 0");
IN[0] = OUT[0] = {};
for (auto arg : func->get_args()) for (auto arg : func->get_args())
IN[1].insert(arg); IN[0].insert(arg);
make_interval(func); make_interval(func);
#ifdef __LRA_PRINT__ #ifdef __LRA_PRINT__
print(func, true, true); print(func, false, true);
#endif #endif
} }
void void
LiveRangeAnalyzer::make_interval(Function *) { LiveRangeAnalyzer::make_interval(Function *) {
for (int time = 1; time <= ir_cnt; ++time) { for (int time = 0; time <= ir_cnt; ++time) {
for (auto op : IN.at(time)) { for (auto op : IN.at(time)) {
auto &interval = intervalmap[op]; auto &interval = intervalmap[op];
if (interval.i == -1) // uninitialized if (interval.i == UNINITIAL) // uninitialized
interval.i = time - 1; interval.i = interval.j = time;
else else
interval.j = time - 1; interval.j = time;
} }
for (auto op : OUT.at(time)) { for (auto op : OUT.at(time)) {
auto &interval = intervalmap[op]; auto &interval = intervalmap[op];
if (interval.i == -1) // uninitialized if (interval.i == UNINITIAL) // uninitialized
interval.i = time; interval.i = interval.j = time + 1;
else else
interval.j = time; interval.j = time + 1;
} }
} }
for (auto &[op, interval] : intervalmap) for (auto &[op, interval] : intervalmap)
...@@ -197,6 +200,11 @@ LiveRangeAnalyzer::print(Function *func, ...@@ -197,6 +200,11 @@ LiveRangeAnalyzer::print(Function *func,
bool printSet, bool printSet,
bool printInt) const { // for debug bool printInt) const { // for debug
cout << "Function " << func->get_name() << endl; cout << "Function " << func->get_name() << endl;
cout << "0. Entry" << endl;
if (printSet) {
cout << "\tin-set: " + print_liveSet(IN.at(0)) << "\n";
cout << "\tout-set: " + print_liveSet(OUT.at(0)) << "\n";
}
for (auto &bb : func->get_basic_blocks()) { for (auto &bb : func->get_basic_blocks()) {
for (auto &instr : bb.get_instructions()) { for (auto &instr : bb.get_instructions()) {
if (instr.is_phi()) // ignore phi if (instr.is_phi()) // ignore phi
...@@ -220,8 +228,7 @@ LiveRangeAnalyzer::print(Function *func, ...@@ -220,8 +228,7 @@ LiveRangeAnalyzer::print(Function *func,
} }
} }
// normal ir // normal ir
cout << instr_id.at(&instr) << ". " << instr.print() << " # " cout << instr_id.at(&instr) << ". " << instr.print() << endl;
<< &instr << endl;
if (not printSet) if (not printSet)
continue; continue;
auto idx = instr_id.at(&instr); auto idx = instr_id.at(&instr);
......
...@@ -10,6 +10,9 @@ using std::for_each; ...@@ -10,6 +10,9 @@ using std::for_each;
using namespace RA; using namespace RA;
#define ASSERT_CMPINST_USED_ONCE(cmpinst) \
(assert(cmpinst->get_use_list().size() == 1))
int int
get_arg_id(Argument *arg) { get_arg_id(Argument *arg) {
auto args = arg->get_parent()->get_args(); auto args = arg->get_parent()->get_args();
...@@ -23,7 +26,7 @@ get_arg_id(Argument *arg) { ...@@ -23,7 +26,7 @@ get_arg_id(Argument *arg) {
} }
bool bool
RegAllocator::no_reg_alloca(Value *v) const { RegAllocator::no_reg_alloca(Value *v) {
auto instr = dynamic_cast<Instruction *>(v); auto instr = dynamic_cast<Instruction *>(v);
auto arg = dynamic_cast<Argument *>(v); auto arg = dynamic_cast<Argument *>(v);
if (instr) { if (instr) {
...@@ -31,21 +34,26 @@ RegAllocator::no_reg_alloca(Value *v) const { ...@@ -31,21 +34,26 @@ RegAllocator::no_reg_alloca(Value *v) const {
if (instr->is_alloca() or instr->is_cmp() or instr->is_fcmp()) if (instr->is_alloca() or instr->is_cmp() or instr->is_fcmp())
return true; return true;
else if (instr->is_zext()) { // only alloca for true use else if (instr->is_zext()) { // only alloca for true use
for (auto use : instr->get_use_list()) bool alloc;
if (not dynamic_cast<Instruction *>(use.val_)->is_br()) { ASSERT_CMPINST_USED_ONCE(instr);
auto instr = static_cast<Instruction *>(use.val_); auto use_ins = dynamic_cast<Instruction *>(
if (instr->is_cmp()) { // special case for cmp again instr->get_use_list().begin()->val_);
auto cmp = static_cast<CmpInst *>(instr); // assert(use_ins != nullptr && "should only be instruction?");
assert(cmp->get_cmp_op() == CmpInst::NE); if (use_ins->is_cmp() and
auto uses = instr->get_use_list(); static_cast<CmpInst *>(use_ins)->get_cmp_op() == CmpInst::NE) {
assert(uses.size() == 1 and // this case:
dynamic_cast<Instruction *>(uses.begin()->val_) // %op0 = icmp slt i32 1, 2
->is_br()); // %op1 = zext i1 %op0 to i32
// %op2 = icmp ne i32 %op1, 0 # <- if judges to here
// br i1 %op2, label %label3, label %label5
ASSERT_CMPINST_USED_ONCE(use_ins);
auto use2_ins = dynamic_cast<Instruction *>(
use_ins->get_use_list().begin()->val_);
alloc = not(use2_ins->is_br());
} else
alloc = true;
} else return not(alloc);
return false;
}
return true;
} else // then always allocate } else // then always allocate
return false; return false;
} }
...@@ -64,19 +72,22 @@ RegAllocator::reset(Function *func) { ...@@ -64,19 +72,22 @@ RegAllocator::reset(Function *func) {
} }
int int
RegAllocator::ReserveForArg(const LVITS &Liveints) { RegAllocator::ReserveForArg(const LVITS &liveints) {
auto args = cur_func->get_args(); auto args = cur_func->get_args();
auto it_int = Liveints.begin(); auto it_int = liveints.begin();
auto it_arg = args.begin(); auto it_arg = args.begin();
int reg; int reg;
for (reg = 1; reg <= args.size() and reg <= ARG_MAX_R; ++reg) { for (reg = 1; reg <= args.size() and reg <= ARG_MAX_R; ++reg) {
auto arg = *it_arg; auto arg = *it_arg;
auto liveint = *it_int; if (not(FLOAT ^ arg->get_type()->is_float_type())) {
assert(arg == liveint.second && "arg should be in order in liveints"); auto liveint = *it_int;
assert(arg == liveint.second &&
"arg should be in order in liveints");
used[reg] = true; used[reg] = true;
regmap[arg] = reg; regmap[arg] = reg;
active.insert(liveint); active.insert(liveint);
}
++it_arg, ++it_int; ++it_arg, ++it_int;
} }
return reg; return reg;
...@@ -88,6 +99,8 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -88,6 +99,8 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
ReserveForArg(liveints); ReserveForArg(liveints);
int reg; int reg;
for (auto liveint : liveints) { for (auto liveint : liveints) {
if (FLOAT ^ liveint.second->get_type()->is_float_type())
continue;
if (dynamic_cast<Argument *>(liveint.second)) { if (dynamic_cast<Argument *>(liveint.second)) {
continue; continue;
} }
...@@ -99,6 +112,13 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -99,6 +112,13 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
else { else {
for (reg = 1; reg <= R and used[reg]; ++reg) for (reg = 1; reg <= R and used[reg]; ++reg)
; ;
if (reg == 16) {
for (auto [interval, v] : active) {
cout << "already allocated: " << v->get_name() << " ~ "
<< regmap.at(v) << endl;
}
assert(false);
}
used[reg] = true; used[reg] = true;
regmap[liveint.second] = reg; regmap[liveint.second] = reg;
active.insert(liveint); active.insert(liveint);
...@@ -108,10 +128,12 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) { ...@@ -108,10 +128,12 @@ RegAllocator::LinearScan(const LVITS &liveints, Function *func) {
void void
RegAllocator::ExpireOldIntervals(LiveInterval liveint) { RegAllocator::ExpireOldIntervals(LiveInterval liveint) {
auto it = active.begin(); auto it = active.begin();
for (; it != active.end() and it->first.j < liveint.first.i; ++it) for (; it != active.end() and it->first.j < liveint.first.i; ++it)
used[regmap.at(it->second)] = false; used[regmap.at(it->second)] = false;
active.erase(active.begin(), it); active.erase(active.begin(), it);
} }
void void
...@@ -122,6 +144,7 @@ RegAllocator::SpillAtInterval(LiveInterval liveint) { ...@@ -122,6 +144,7 @@ RegAllocator::SpillAtInterval(LiveInterval liveint) {
if (spill.first.j > liveint.first.j) { if (spill.first.j > liveint.first.j) {
// cancel reg allocation for spill // cancel reg allocation for spill
regmap[liveint.second] = regmap.at(spill.second); regmap[liveint.second] = regmap.at(spill.second);
active.insert(liveint);
active.erase(spill); active.erase(spill);
regmap.erase(spill.second); regmap.erase(spill.second);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment