Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
2
2022fall-Compiler_CMinus
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Operations
Operations
Metrics
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
李晓奇
2022fall-Compiler_CMinus
Commits
f26d91aa
Commit
f26d91aa
authored
Dec 04, 2022
by
李晓奇
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
finish a lot... bugs
parent
efbf4233
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1321 additions
and
816 deletions
+1321
-816
Reports/4.2-gvn/report.md
Reports/4.2-gvn/report.md
+254
-2
include/lightir/Instruction.h
include/lightir/Instruction.h
+95
-90
include/optimization/GVN.h
include/optimization/GVN.h
+148
-29
src/cminusfc/.gitignore
src/cminusfc/.gitignore
+1
-0
src/cminusfc/cminusf_builder.cpp
src/cminusfc/cminusf_builder.cpp
+318
-553
src/lightir/Instruction.cpp
src/lightir/Instruction.cpp
+231
-84
src/optimization/GVN.cpp
src/optimization/GVN.cpp
+273
-58
tests/4-ir-opt/testcases/GVN/functional/.gitignore
tests/4-ir-opt/testcases/GVN/functional/.gitignore
+1
-0
No files found.
Reports/4.2-gvn/report.md
View file @
f26d91aa
...
@@ -8,12 +8,265 @@
...
@@ -8,12 +8,265 @@
## 实验难点
## 实验难点
实验中遇到哪些挑战
### 对于函数`detectEquivalences(G)`
```
cpp
detectEquivalences
(
G
)
PIN1
=
{}
// “1” is the first statement in the program
POUT1
=
transferFunction
(
PIN1
)
for
each
statement
s
other
than
the
first
statement
in
the
program
POUTs
=
Top
while
changes
to
any
POUT
occur
// i.e. changes in equivalences
for
each
statement
s
other
than
the
first
statement
in
the
program
if
s
appears
in
block
b
that
has
two
predecessors
then
PINs
=
Join
(
POUTs1
,
POUTs2
)
// s1 and s2 are last statements in respective predecessors
else
PINs
=
POUTs3
// s3 the statement just before s
POUTs
=
transferFunction
(
PINs
)
// apply transferFunction on each statement in the block
```
-
POUT、PIN是对于语句的的map,但是实现中是对于基本块的map,怎么处理区别?
注意到基本块的pin、pout对应着首条语句的pin和末尾语句的pout,所以没有本质上的不同,主要是怎么补全中间语句的pin、pout。
中间语句的pin、pout应该不会被显式保存,只在对基本块执行转移函数时,逐条计算。所以应该没啥问题。
-
怎么设计top?
尽管算法伪代码中,top是对每一条语句的赋值的,但是任意一条语句的都有所属基本块,只要对基本块的pout标注top、就能保证基本块的pin合乎逻辑,逐条执行转移函数后,每条语句的pin都合法。
``
- `while changes to any POUT occur`
这句话如此轻松,但是在实现中出现了问题:这涉及到怎么判断分区相等,做不好会出现死循环。
如果单纯比较index恐怕不行:因为我目前的`transferFunction`会生成一个新的等价类,赋予一个新的编号。
所以我使用对等价类中的`members`做相等的判断,如果两个分区中的等价类都能找到相等的,那么就判断分区相等。
> 但是这样依赖于set的排序,
>
> - 对于分区,这个排序依赖于等价类中index的值,侥幸来说每条指令按顺序执行,即使index不一致(每次转移函数分配最新的index),大小顺序也是一致的?
>
> - 对于等价类,里边是`Value*`指针,这个指针的等价就是语义上的等价,没问题。
改过来后发现还是死循环,DEBUG很久意识到我的分区相等写的有问题:`return members_ == other.members_`,这样会挨个儿比较`members_`指针的值,而不是我预想的比较指针指向的等价类。所以要手动遍历一下。
这个遍历也不是很直观的,因为要做两层解引用:
```cpp
for (; it1 != p1.end(); ++it1, ++it2)
if (not(**it1 == **it2))
return false;
```
`*it`是等价类的指针,`**it`才是等价类。
除此之外,我还忘了`++it2`,淦。
改完这些,终于能跑通第一版本了……
### 等价类设计
> ```cpp
> shared_ptr<Expression>
> GVN::valueExpr(Instruction *instr);
> ```
等价类涉及到针对指令设计、常量折叠、递归计算等,很复杂,需要拎出来单独讨论。
首先注意到传给传进去的参数是指令类型(指针)。需要`transferFunction`处理的指令,一次赋值的右式自然可以由指令类型得到,所以没有毛病。
指令有多种类型,也即lightIR中有多种赋值方式,而非单纯的伪代码中的二元运算。但是即使是已经在算法中讨论过的二元运算,写起来也不容易。
#### 生成等价类的逻辑
首先讨论生成等价类的逻辑:
- 二元运算`y op z`
如果y、z都是常量,那么可以进行常量折叠,得到**一个**`ConstantExpression`类型。
否则返回的是一个`BinaryExpression`类型,这个类型需要左右操作数,都是`Expression`类型,如果操作数是常量,使用`ConstantExpression`创建新的值表达式,否则递归调用`valueExpr`。
> 这里为什么递归调用不会走到未定义?
- phi函数
我们逻辑上把其改变成了copy,`transferFunction`应该维护这一点,所以认为没有phi函数传入。
具体是`transferFunction`函数中的这几句话:
```cpp
GVN::partitions
GVN::transferFunction(Instruction *x, Value *e, partitions pin)
{
auto e_instr = dynamic_cast<Instruction *>(e);
auto e_const = dynamic_cast<Constant *>(e);
assert((not e or e_instr or e_const) &&
"A value must be from an instruction or constant");
// ……
if (e) {
if (e_const)
ve = ConstantExpression::create(e_const);
else
ve = valueExpr(e_instr, &pin);
} else
ve = valueExpr(x, &pin);
}
```
- 比较函数:和二元运算一致。
- 类型转换:添加新的表达式类型`CastExpression`,处理和二元操作类似。
- 其他产生赋值的指令,也会由`transferFunction`传递给`ValueExpr`,这些指令是:`load`、`alloca`和返回值非空的`call`。
这里设计一种值表达式类型:`UniqueExpression`,这个表达式和任何其他都不等价,表现为`equiv`直接返回false。
#### 若干需要梳理的问题
除此之外,关于等价类有几个命题一定要说明,这些是理解GVN的关键。
- 等价类中的ve一定存在吗?
根据等价类的设计,ve一定是存在的,且等价类中`members_`中一定存在元素。
- ve一定唯一吗?
从逻辑上,等价类集合中的元素享有共同的ve产生的结果,ve是唯一的
从实现上,添加
- ve或者vpf能够作为代表吗?
这个问题源于`transferFunction`的一句话:
> if ve or vpf is in a class Ci in POUTs
- 代表元如何选取?什么时候选取?
- 怎么判断等价类相等
根据论坛上这个[例子](http://cscourse.ustc.edu.cn/forum/thread.jsp?forum=91&thread=68),我发现值编号会有非必要的增长,即最终收敛时用到的编号是1、5、6,其中2、3、4都是迭代中用到的但是最后都没有体现。
从这个例子中可以发现,迭代收敛时,是pout中对应分区的`members_`不变,值编号还是会变化的,所以判断等价类相等,应该比较`members_`集合。
- 什么时候给新的值编号?
1. 执行转移函数发现没有可以归属的等价类时
2. 求交巴拉巴拉还没搞清楚那里
- 等价类为空的判定,可以用`members_`做判断吗?
我认为是可以的,存在情况:两个基本块汇合时,交出一个集合,它的members为空,但是ve不空:
```
BB0: y =
z =
; pout: [{v1,y}, {v2,z}]
BB1: x1 = y + z ; preds = BB0
; pout: [{v1,y}, {v2,z}, {v3, x1, v1+v2}]
BB2: x2 = y + z ; preds = BB0
; pout: [{v1,y}, {v2,z}, {v4, x2, v1+v2}]
BB3: ... ; preds = BB1, BB2
```
进入BB3时,对BB1和BB2的pout中等价类分别取交,执行`Intersect(Ci={v3, x1, v1+v2}, Cj={v4, x2, v1+v2}`,得到的是`{v1+v2}`,`members_`中没有成员,但是ve存在,我认为这种情况也应该判定结果为空集。
因为最终反映到IR上,我们直接关心的都是`members_`中的成员,所以用`members_`的空代表等价类的空我觉得是合适的。
| 类型 | 处理 |
| ---------------- | ------------------------- |
| 二元运算 | 依伪代码 |
| 没有返回值的 | 不处理 |
| phi函数 | 本块中的不管,后继块的做考虑 |
| 函数调用 | 有返回值的考虑等价类,**注意后边纯函数的处理** |
| 产生指针(GEP、alloca) | 先为返回值单独新建等价类 |
| 类型转换及0扩展 | 单独新建等价类 |
| 比较 | 根据op和操作数递归判断,应该和二元运算类似 |
| load | 单独新建等价类 |
### 对于函数`Intersect(Ci, Cj )`
伪代码中,一个等价类是一个集合,求交后,v~k~是自然而然的在或不在新集合中,但是实现中并不这么简单,如何对这样的两个结构体进行逻辑上的取交:
```cpp
struct CongruenceClass {
size_t index_;
Value *leader_;
std::shared_ptr<GVNExpression::Expression> value_expr_;
std::shared_ptr<GVNExpression::PhiExpression> value_phi_;
std::set<Value *> members_;
}
```
- 首先是Intersect的伪代码描述中,下边这句怎么在实现中体现?
> C~k~ does not have value number
解决:index不同的自然就交不出v~i~,相同的时候才有v~i~即value number.
隐患:transferFunction要对v~i~有继承性,即不能每次涉及一个等价类都新开一个value number
- 根据上一条讨论,得出一个观点:在伪代码中,一个等价类集合会出现的内容,在实现中的对应如下:
- 普通变量:`members_`
- 值表达式:`value_expr_`
- phi函数:`value_phi_`
- 值编号v~i~:`index`
因此,伪代码中对集合的求交,对应到实现上,就是对以上四个域求交。
### 对于函数`valuePhiFunc(ve, P)`
- 输入就一个分区,实际上要用到两个前驱的分区,怎么处理?
传入基本块做参数。
- 二元运算的顺序颠倒怎么办?
> 如phi(x1,y1)+phi(x2,y2)和phi(x1,y1)+phi(y2,x2)
>
> 对于第二种,单纯寻找x1+y2和y1+x2肯定不合逻辑
应该不会有这种情况:phi表达式,追根溯源是在join时产生的:
```cpp
GVN::join(const partitions &P1, const partitions &P2);
```
这里的分区是严格按照`
Basicblock
`的方法`
get_pre_basic_blocks()
`的返回顺序传入的,所以不会有担心的情况发生。
### 所见非所得
两方面:
- 论文中伪代码对于基本块的关注不是很多,但是我们的实现一定要围绕基本块进行,这个的解决对策说到了,其实和单条语句没有本质差异。
- 我们针对lightir优化,经常深入API中陷入到实现细节,还容易被C++的语言特性绊住,但是应该尝试从直观上理解:我们究竟要实现什么样的效果。这个就得去看ll文件找灵感。然后这还不够,回过头来看代码又呆住了,因为ll语句和那些类型对应不上,这个就是容易忽视的一点:看lightir类型的print函数,这个能帮我们直观地串联起实现细节和表象的ll文件。
## 实验设计
## 实验设计
### join_helper
### _TOP
### transferFunction(Basicblock)
###
实现思路,相应代码,优化前后的IR对比(举一个例子)并辅以简单说明
实现思路,相应代码,优化前后的IR对比(举一个例子)并辅以简单说明
### 思考题
### 思考题
1. 请简要分析你的算法复杂度
1. 请简要分析你的算法复杂度
2. `
std::shared_ptr
`
如果存在环形引用,则无法正确释放内存,你的 Expression 类是否存在 circular reference?
2. `
std::shared_ptr
`
如果存在环形引用,则无法正确释放内存,你的 Expression 类是否存在 circular reference?
3.
尽管本次实验已经写了很多代码,但是在算法上和工程上仍然可以对 GVN 进行改进,请简述你的 GVN 实现可以改进的地方
3.
尽管本次实验已经写了很多代码,但是在算法上和工程上仍然可以对 GVN 进行改进,请简述你的 GVN 实现可以改进的地方
...
@@ -25,4 +278,3 @@
...
@@ -25,4 +278,3 @@
## 实验反馈(可选 不会评分)
## 实验反馈(可选 不会评分)
对本次实验的建议
对本次实验的建议
include/lightir/Instruction.h
View file @
f26d91aa
...
@@ -9,8 +9,7 @@
...
@@ -9,8 +9,7 @@
class
BasicBlock
;
class
BasicBlock
;
class
Function
;
class
Function
;
class
Instruction
:
public
User
,
public
llvm
::
ilist_node
<
Instruction
>
class
Instruction
:
public
User
,
public
llvm
::
ilist_node
<
Instruction
>
{
{
public:
public:
enum
OpID
{
enum
OpID
{
// Terminator Instructions
// Terminator Instructions
...
@@ -49,11 +48,11 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
...
@@ -49,11 +48,11 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
Instruction
(
const
Instruction
&
)
=
delete
;
Instruction
(
const
Instruction
&
)
=
delete
;
virtual
~
Instruction
()
=
default
;
virtual
~
Instruction
()
=
default
;
inline
const
BasicBlock
*
get_parent
()
const
{
return
parent_
;
}
inline
const
BasicBlock
*
get_parent
()
const
{
return
parent_
;
}
inline
BasicBlock
*
get_parent
()
{
return
parent_
;
}
inline
BasicBlock
*
get_parent
()
{
return
parent_
;
}
void
set_parent
(
BasicBlock
*
parent
)
{
this
->
parent_
=
parent
;
}
void
set_parent
(
BasicBlock
*
parent
)
{
this
->
parent_
=
parent
;
}
// Return the function this instruction belongs to.
// Return the function this instruction belongs to.
Function
*
get_function
();
Function
*
get_function
();
Module
*
get_module
();
Module
*
get_module
();
OpID
get_instr_type
()
const
{
return
op_id_
;
}
OpID
get_instr_type
()
const
{
return
op_id_
;
}
// clang-format off
// clang-format off
...
@@ -86,8 +85,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
...
@@ -86,8 +85,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
// clang-format on
// clang-format on
std
::
string
get_instr_op_name
()
{
return
get_instr_op_name
(
op_id_
);
}
std
::
string
get_instr_op_name
()
{
return
get_instr_op_name
(
op_id_
);
}
bool
is_void
()
bool
is_void
()
{
{
return
((
op_id_
==
ret
)
||
(
op_id_
==
br
)
||
(
op_id_
==
store
)
||
return
((
op_id_
==
ret
)
||
(
op_id_
==
br
)
||
(
op_id_
==
store
)
||
(
op_id_
==
call
&&
this
->
get_type
()
->
is_void_type
()));
(
op_id_
==
call
&&
this
->
get_type
()
->
is_void_type
()));
}
}
...
@@ -108,6 +106,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
...
@@ -108,6 +106,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool
is_fsub
()
{
return
op_id_
==
fsub
;
}
bool
is_fsub
()
{
return
op_id_
==
fsub
;
}
bool
is_fmul
()
{
return
op_id_
==
fmul
;
}
bool
is_fmul
()
{
return
op_id_
==
fmul
;
}
bool
is_fdiv
()
{
return
op_id_
==
fdiv
;
}
bool
is_fdiv
()
{
return
op_id_
==
fdiv
;
}
bool
is_fp2si
()
{
return
op_id_
==
fptosi
;
}
bool
is_fp2si
()
{
return
op_id_
==
fptosi
;
}
bool
is_si2fp
()
{
return
op_id_
==
sitofp
;
}
bool
is_si2fp
()
{
return
op_id_
==
sitofp
;
}
...
@@ -118,8 +117,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
...
@@ -118,8 +117,7 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool
is_gep
()
{
return
op_id_
==
getelementptr
;
}
bool
is_gep
()
{
return
op_id_
==
getelementptr
;
}
bool
is_zext
()
{
return
op_id_
==
zext
;
}
bool
is_zext
()
{
return
op_id_
==
zext
;
}
bool
isBinary
()
bool
isBinary
()
{
{
return
(
is_add
()
||
is_sub
()
||
is_mul
()
||
is_div
()
||
is_fadd
()
||
return
(
is_add
()
||
is_sub
()
||
is_mul
()
||
is_div
()
||
is_fadd
()
||
is_fsub
()
||
is_fmul
()
||
is_fdiv
())
&&
is_fsub
()
||
is_fmul
()
||
is_fdiv
())
&&
(
get_num_operand
()
==
2
);
(
get_num_operand
()
==
2
);
...
@@ -128,39 +126,34 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
...
@@ -128,39 +126,34 @@ class Instruction : public User, public llvm::ilist_node<Instruction>
bool
isTerminator
()
{
return
is_br
()
||
is_ret
();
}
bool
isTerminator
()
{
return
is_br
()
||
is_ret
();
}
private:
private:
OpID
op_id_
;
OpID
op_id_
;
unsigned
num_ops_
;
unsigned
num_ops_
;
BasicBlock
*
parent_
;
BasicBlock
*
parent_
;
};
};
namespace
detail
namespace
detail
{
{
template
<
typename
T
>
template
<
typename
T
>
struct
tag
{
struct
tag
using
type
=
T
;
{
};
using
type
=
T
;
template
<
typename
...
Ts
>
};
struct
select_last
{
template
<
typename
...
Ts
>
// Use a fold-expression to fold the comma operator over the parameter
struct
select_last
// pack.
{
using
type
=
typename
decltype
((
tag
<
Ts
>
{},
...))
::
type
;
// Use a fold-expression to fold the comma operator over the parameter
};
// pack.
template
<
typename
...
Ts
>
using
type
=
typename
decltype
((
tag
<
Ts
>
{},
...))
::
type
;
using
select_last_t
=
typename
select_last
<
Ts
...
>::
type
;
};
template
<
typename
...
Ts
>
using
select_last_t
=
typename
select_last
<
Ts
...
>::
type
;
};
// namespace detail
};
// namespace detail
template
<
class
>
template
<
class
>
inline
constexpr
bool
always_false_v
=
false
;
inline
constexpr
bool
always_false_v
=
false
;
template
<
typename
Inst
>
template
<
typename
Inst
>
class
BaseInst
:
public
Instruction
class
BaseInst
:
public
Instruction
{
{
protected:
protected:
template
<
typename
...
Args
>
template
<
typename
...
Args
>
static
Inst
*
create
(
Args
&&
...
args
)
static
Inst
*
create
(
Args
&&
...
args
)
{
{
if
constexpr
(
std
::
is_same_v
<
if
constexpr
(
std
::
is_same_v
<
std
::
decay_t
<
detail
::
select_last_t
<
Args
...
>>
,
std
::
decay_t
<
detail
::
select_last_t
<
Args
...
>>
,
BasicBlock
*>
)
{
BasicBlock
*>
)
{
...
@@ -171,13 +164,10 @@ class BaseInst : public Instruction
...
@@ -171,13 +164,10 @@ class BaseInst : public Instruction
}
}
template
<
typename
...
Args
>
template
<
typename
...
Args
>
BaseInst
(
Args
&&
...
args
)
:
Instruction
(
std
::
forward
<
Args
>
(
args
)...)
BaseInst
(
Args
&&
...
args
)
:
Instruction
(
std
::
forward
<
Args
>
(
args
)...)
{}
{
}
};
};
class
BinaryInst
:
public
BaseInst
<
BinaryInst
>
class
BinaryInst
:
public
BaseInst
<
BinaryInst
>
{
{
friend
BaseInst
<
BinaryInst
>
;
friend
BaseInst
<
BinaryInst
>
;
private:
private:
...
@@ -185,35 +175,51 @@ class BinaryInst : public BaseInst<BinaryInst>
...
@@ -185,35 +175,51 @@ class BinaryInst : public BaseInst<BinaryInst>
public:
public:
// create add instruction, auto insert to bb
// create add instruction, auto insert to bb
static
BinaryInst
*
create_add
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_add
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
// create sub instruction, auto insert to bb
// create sub instruction, auto insert to bb
static
BinaryInst
*
create_sub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_sub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
// create mul instruction, auto insert to bb
// create mul instruction, auto insert to bb
static
BinaryInst
*
create_mul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_mul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
// create Div instruction, auto insert to bb
// create Div instruction, auto insert to bb
static
BinaryInst
*
create_sdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_sdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
// create fadd instruction, auto insert to bb
// create fadd instruction, auto insert to bb
static
BinaryInst
*
create_fadd
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_fadd
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
// create fsub instruction, auto insert to bb
// create fsub instruction, auto insert to bb
static
BinaryInst
*
create_fsub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_fsub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
// create fmul instruction, auto insert to bb
// create fmul instruction, auto insert to bb
static
BinaryInst
*
create_fmul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_fmul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
// create fDiv instruction, auto insert to bb
// create fDiv instruction, auto insert to bb
static
BinaryInst
*
create_fdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
static
BinaryInst
*
create_fdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
virtual
std
::
string
print
()
override
;
virtual
std
::
string
print
()
override
;
...
@@ -222,8 +228,7 @@ class BinaryInst : public BaseInst<BinaryInst>
...
@@ -222,8 +228,7 @@ class BinaryInst : public BaseInst<BinaryInst>
void
assertValid
();
void
assertValid
();
};
};
class
CmpInst
:
public
BaseInst
<
CmpInst
>
class
CmpInst
:
public
BaseInst
<
CmpInst
>
{
{
friend
BaseInst
<
CmpInst
>
;
friend
BaseInst
<
CmpInst
>
;
public:
public:
...
@@ -240,7 +245,10 @@ class CmpInst : public BaseInst<CmpInst>
...
@@ -240,7 +245,10 @@ class CmpInst : public BaseInst<CmpInst>
CmpInst
(
Type
*
ty
,
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
);
CmpInst
(
Type
*
ty
,
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
);
public:
public:
static
CmpInst
*
create_cmp
(
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
,
static
CmpInst
*
create_cmp
(
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
,
Module
*
m
);
Module
*
m
);
CmpOp
get_cmp_op
()
{
return
cmp_op_
;
}
CmpOp
get_cmp_op
()
{
return
cmp_op_
;
}
...
@@ -253,8 +261,7 @@ class CmpInst : public BaseInst<CmpInst>
...
@@ -253,8 +261,7 @@ class CmpInst : public BaseInst<CmpInst>
void
assertValid
();
void
assertValid
();
};
};
class
FCmpInst
:
public
BaseInst
<
FCmpInst
>
class
FCmpInst
:
public
BaseInst
<
FCmpInst
>
{
{
friend
BaseInst
<
FCmpInst
>
;
friend
BaseInst
<
FCmpInst
>
;
public:
public:
...
@@ -271,8 +278,11 @@ class FCmpInst : public BaseInst<FCmpInst>
...
@@ -271,8 +278,11 @@ class FCmpInst : public BaseInst<FCmpInst>
FCmpInst
(
Type
*
ty
,
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
);
FCmpInst
(
Type
*
ty
,
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
);
public:
public:
static
FCmpInst
*
create_fcmp
(
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
static
FCmpInst
*
create_fcmp
(
CmpOp
op
,
BasicBlock
*
bb
,
Module
*
m
);
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
,
Module
*
m
);
CmpOp
get_cmp_op
()
{
return
cmp_op_
;
}
CmpOp
get_cmp_op
()
{
return
cmp_op_
;
}
...
@@ -284,33 +294,36 @@ class FCmpInst : public BaseInst<FCmpInst>
...
@@ -284,33 +294,36 @@ class FCmpInst : public BaseInst<FCmpInst>
void
assert_valid
();
void
assert_valid
();
};
};
class
CallInst
:
public
BaseInst
<
CallInst
>
class
CallInst
:
public
BaseInst
<
CallInst
>
{
{
friend
BaseInst
<
CallInst
>
;
friend
BaseInst
<
CallInst
>
;
protected:
protected:
CallInst
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
BasicBlock
*
bb
);
CallInst
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
BasicBlock
*
bb
);
public:
public:
static
CallInst
*
create
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
static
CallInst
*
create
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
BasicBlock
*
bb
);
BasicBlock
*
bb
);
FunctionType
*
get_function_type
()
const
;
FunctionType
*
get_function_type
()
const
;
virtual
std
::
string
print
()
override
;
virtual
std
::
string
print
()
override
;
};
};
class
BranchInst
:
public
BaseInst
<
BranchInst
>
class
BranchInst
:
public
BaseInst
<
BranchInst
>
{
{
friend
BaseInst
<
BranchInst
>
;
friend
BaseInst
<
BranchInst
>
;
private:
private:
BranchInst
(
Value
*
cond
,
BasicBlock
*
if_true
,
BasicBlock
*
if_false
,
BranchInst
(
Value
*
cond
,
BasicBlock
*
if_true
,
BasicBlock
*
if_false
,
BasicBlock
*
bb
);
BasicBlock
*
bb
);
BranchInst
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
);
BranchInst
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
);
public:
public:
static
BranchInst
*
create_cond_br
(
Value
*
cond
,
BasicBlock
*
if_true
,
static
BranchInst
*
create_cond_br
(
Value
*
cond
,
BasicBlock
*
if_false
,
BasicBlock
*
bb
);
BasicBlock
*
if_true
,
BasicBlock
*
if_false
,
BasicBlock
*
bb
);
static
BranchInst
*
create_br
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
);
static
BranchInst
*
create_br
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
);
bool
is_cond_br
()
const
;
bool
is_cond_br
()
const
;
...
@@ -318,8 +331,7 @@ class BranchInst : public BaseInst<BranchInst>
...
@@ -318,8 +331,7 @@ class BranchInst : public BaseInst<BranchInst>
virtual
std
::
string
print
()
override
;
virtual
std
::
string
print
()
override
;
};
};
class
ReturnInst
:
public
BaseInst
<
ReturnInst
>
class
ReturnInst
:
public
BaseInst
<
ReturnInst
>
{
{
friend
BaseInst
<
ReturnInst
>
;
friend
BaseInst
<
ReturnInst
>
;
private:
private:
...
@@ -329,13 +341,12 @@ class ReturnInst : public BaseInst<ReturnInst>
...
@@ -329,13 +341,12 @@ class ReturnInst : public BaseInst<ReturnInst>
public:
public:
static
ReturnInst
*
create_ret
(
Value
*
val
,
BasicBlock
*
bb
);
static
ReturnInst
*
create_ret
(
Value
*
val
,
BasicBlock
*
bb
);
static
ReturnInst
*
create_void_ret
(
BasicBlock
*
bb
);
static
ReturnInst
*
create_void_ret
(
BasicBlock
*
bb
);
bool
is_void_ret
()
const
;
bool
is_void_ret
()
const
;
virtual
std
::
string
print
()
override
;
virtual
std
::
string
print
()
override
;
};
};
class
GetElementPtrInst
:
public
BaseInst
<
GetElementPtrInst
>
class
GetElementPtrInst
:
public
BaseInst
<
GetElementPtrInst
>
{
{
friend
BaseInst
<
GetElementPtrInst
>
;
friend
BaseInst
<
GetElementPtrInst
>
;
private:
private:
...
@@ -343,9 +354,10 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
...
@@ -343,9 +354,10 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
public:
public:
static
Type
*
get_element_type
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
);
static
Type
*
get_element_type
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
);
static
GetElementPtrInst
*
create_gep
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
,
static
GetElementPtrInst
*
create_gep
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
,
BasicBlock
*
bb
);
BasicBlock
*
bb
);
Type
*
get_element_type
()
const
;
Type
*
get_element_type
()
const
;
virtual
std
::
string
print
()
override
;
virtual
std
::
string
print
()
override
;
...
@@ -353,8 +365,7 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
...
@@ -353,8 +365,7 @@ class GetElementPtrInst : public BaseInst<GetElementPtrInst>
Type
*
element_ty_
;
Type
*
element_ty_
;
};
};
class
StoreInst
:
public
BaseInst
<
StoreInst
>
class
StoreInst
:
public
BaseInst
<
StoreInst
>
{
{
friend
BaseInst
<
StoreInst
>
;
friend
BaseInst
<
StoreInst
>
;
private:
private:
...
@@ -369,8 +380,7 @@ class StoreInst : public BaseInst<StoreInst>
...
@@ -369,8 +380,7 @@ class StoreInst : public BaseInst<StoreInst>
virtual
std
::
string
print
()
override
;
virtual
std
::
string
print
()
override
;
};
};
class
LoadInst
:
public
BaseInst
<
LoadInst
>
class
LoadInst
:
public
BaseInst
<
LoadInst
>
{
{
friend
BaseInst
<
LoadInst
>
;
friend
BaseInst
<
LoadInst
>
;
private:
private:
...
@@ -378,15 +388,14 @@ class LoadInst : public BaseInst<LoadInst>
...
@@ -378,15 +388,14 @@ class LoadInst : public BaseInst<LoadInst>
public:
public:
static
LoadInst
*
create_load
(
Type
*
ty
,
Value
*
ptr
,
BasicBlock
*
bb
);
static
LoadInst
*
create_load
(
Type
*
ty
,
Value
*
ptr
,
BasicBlock
*
bb
);
Value
*
get_lval
()
{
return
this
->
get_operand
(
0
);
}
Value
*
get_lval
()
{
return
this
->
get_operand
(
0
);
}
Type
*
get_load_type
()
const
;
Type
*
get_load_type
()
const
;
virtual
std
::
string
print
()
override
;
virtual
std
::
string
print
()
override
;
};
};
class
AllocaInst
:
public
BaseInst
<
AllocaInst
>
class
AllocaInst
:
public
BaseInst
<
AllocaInst
>
{
{
friend
BaseInst
<
AllocaInst
>
;
friend
BaseInst
<
AllocaInst
>
;
private:
private:
...
@@ -403,8 +412,7 @@ class AllocaInst : public BaseInst<AllocaInst>
...
@@ -403,8 +412,7 @@ class AllocaInst : public BaseInst<AllocaInst>
Type
*
alloca_ty_
;
Type
*
alloca_ty_
;
};
};
class
ZextInst
:
public
BaseInst
<
ZextInst
>
class
ZextInst
:
public
BaseInst
<
ZextInst
>
{
{
friend
BaseInst
<
ZextInst
>
;
friend
BaseInst
<
ZextInst
>
;
private:
private:
...
@@ -421,8 +429,7 @@ class ZextInst : public BaseInst<ZextInst>
...
@@ -421,8 +429,7 @@ class ZextInst : public BaseInst<ZextInst>
Type
*
dest_ty_
;
Type
*
dest_ty_
;
};
};
class
FpToSiInst
:
public
BaseInst
<
FpToSiInst
>
class
FpToSiInst
:
public
BaseInst
<
FpToSiInst
>
{
{
friend
BaseInst
<
FpToSiInst
>
;
friend
BaseInst
<
FpToSiInst
>
;
private:
private:
...
@@ -439,8 +446,7 @@ class FpToSiInst : public BaseInst<FpToSiInst>
...
@@ -439,8 +446,7 @@ class FpToSiInst : public BaseInst<FpToSiInst>
Type
*
dest_ty_
;
Type
*
dest_ty_
;
};
};
class
SiToFpInst
:
public
BaseInst
<
SiToFpInst
>
class
SiToFpInst
:
public
BaseInst
<
SiToFpInst
>
{
{
friend
BaseInst
<
SiToFpInst
>
;
friend
BaseInst
<
SiToFpInst
>
;
private:
private:
...
@@ -457,25 +463,24 @@ class SiToFpInst : public BaseInst<SiToFpInst>
...
@@ -457,25 +463,24 @@ class SiToFpInst : public BaseInst<SiToFpInst>
Type
*
dest_ty_
;
Type
*
dest_ty_
;
};
};
class
PhiInst
:
public
BaseInst
<
PhiInst
>
class
PhiInst
:
public
BaseInst
<
PhiInst
>
{
{
friend
BaseInst
<
PhiInst
>
;
friend
BaseInst
<
PhiInst
>
;
private:
private:
PhiInst
(
OpID
op
,
std
::
vector
<
Value
*>
vals
,
PhiInst
(
OpID
op
,
std
::
vector
<
BasicBlock
*>
val_bbs
,
Type
*
ty
,
BasicBlock
*
bb
);
std
::
vector
<
Value
*>
vals
,
std
::
vector
<
BasicBlock
*>
val_bbs
,
Type
*
ty
,
BasicBlock
*
bb
);
PhiInst
(
Type
*
ty
,
OpID
op
,
unsigned
num_ops
,
BasicBlock
*
bb
)
PhiInst
(
Type
*
ty
,
OpID
op
,
unsigned
num_ops
,
BasicBlock
*
bb
)
:
BaseInst
<
PhiInst
>
(
ty
,
op
,
num_ops
,
bb
)
:
BaseInst
<
PhiInst
>
(
ty
,
op
,
num_ops
,
bb
)
{}
{
}
Value
*
l_val_
;
Value
*
l_val_
;
public:
public:
static
PhiInst
*
create_phi
(
Type
*
ty
,
BasicBlock
*
bb
);
static
PhiInst
*
create_phi
(
Type
*
ty
,
BasicBlock
*
bb
);
Value
*
get_lval
()
{
return
l_val_
;
}
Value
*
get_lval
()
{
return
l_val_
;
}
void
set_lval
(
Value
*
l_val
)
{
l_val_
=
l_val
;
}
void
set_lval
(
Value
*
l_val
)
{
l_val_
=
l_val
;
}
void
add_phi_pair_operand
(
Value
*
val
,
Value
*
pre_bb
)
void
add_phi_pair_operand
(
Value
*
val
,
Value
*
pre_bb
)
{
{
this
->
add_operand
(
val
);
this
->
add_operand
(
val
);
this
->
add_operand
(
pre_bb
);
this
->
add_operand
(
pre_bb
);
}
}
...
...
include/optimization/GVN.h
View file @
f26d91aa
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
class
GVN
;
namespace
GVNExpression
{
namespace
GVNExpression
{
// fold the constant value
// fold the constant value
...
@@ -35,12 +36,13 @@ class ConstFolder {
...
@@ -35,12 +36,13 @@ class ConstFolder {
/**
/**
* for constructor of class derived from `Expression`, we make it public
* for constructor of class derived from `Expression`, we make it public
* because `std::make_shared` needs the constructor to be publicly available,
* because `std::make_shared` needs the constructor to be publicly available,
* but you should call the static factory method `create` instead the constructor itself to get the desired data
* but you should call the static factory method `create` instead the
* constructor itself to get the desired data
*/
*/
class
Expression
{
class
Expression
{
public:
public:
// TODO: you need to extend expression types according to testcases
// TODO: you need to extend expression types according to testcases
enum
gvn_expr_t
{
e_constant
,
e_bin
,
e_phi
};
enum
gvn_expr_t
{
e_constant
,
e_bin
,
e_phi
,
e_cast
,
e_gep
,
e_unique
};
Expression
(
gvn_expr_t
t
)
:
expr_type
(
t
)
{}
Expression
(
gvn_expr_t
t
)
:
expr_type
(
t
)
{}
virtual
~
Expression
()
=
default
;
virtual
~
Expression
()
=
default
;
virtual
std
::
string
print
()
=
0
;
virtual
std
::
string
print
()
=
0
;
...
@@ -50,15 +52,21 @@ class Expression {
...
@@ -50,15 +52,21 @@ class Expression {
gvn_expr_t
expr_type
;
gvn_expr_t
expr_type
;
};
};
bool
operator
==
(
const
std
::
shared_ptr
<
Expression
>
&
lhs
,
const
std
::
shared_ptr
<
Expression
>
&
rhs
);
bool
operator
==
(
const
std
::
shared_ptr
<
Expression
>
&
lhs
,
bool
operator
==
(
const
GVNExpression
::
Expression
&
lhs
,
const
GVNExpression
::
Expression
&
rhs
);
const
std
::
shared_ptr
<
Expression
>
&
rhs
);
bool
operator
==
(
const
GVNExpression
::
Expression
&
lhs
,
const
GVNExpression
::
Expression
&
rhs
);
class
ConstantExpression
:
public
Expression
{
class
ConstantExpression
:
public
Expression
{
public:
public:
static
std
::
shared_ptr
<
ConstantExpression
>
create
(
Constant
*
c
)
{
return
std
::
make_shared
<
ConstantExpression
>
(
c
);
}
static
std
::
shared_ptr
<
ConstantExpression
>
create
(
Constant
*
c
)
{
return
std
::
make_shared
<
ConstantExpression
>
(
c
);
}
virtual
std
::
string
print
()
{
return
c_
->
print
();
}
virtual
std
::
string
print
()
{
return
c_
->
print
();
}
// we leverage the fact that constants in lightIR have unique addresses
// we leverage the fact that constants in lightIR have unique addresses
bool
equiv
(
const
ConstantExpression
*
other
)
const
{
return
c_
==
other
->
c_
;
}
bool
equiv
(
const
ConstantExpression
*
other
)
const
{
return
c_
==
other
->
c_
;
}
ConstantExpression
(
Constant
*
c
)
:
Expression
(
e_constant
),
c_
(
c
)
{}
ConstantExpression
(
Constant
*
c
)
:
Expression
(
e_constant
),
c_
(
c
)
{}
private:
private:
...
@@ -67,24 +75,31 @@ class ConstantExpression : public Expression {
...
@@ -67,24 +75,31 @@ class ConstantExpression : public Expression {
// arithmetic expression
// arithmetic expression
class
BinaryExpression
:
public
Expression
{
class
BinaryExpression
:
public
Expression
{
friend
class
::
GVN
;
public:
public:
static
std
::
shared_ptr
<
BinaryExpression
>
create
(
Instruction
::
OpID
op
,
static
std
::
shared_ptr
<
BinaryExpression
>
create
(
std
::
shared_ptr
<
Expression
>
lhs
,
Instruction
::
OpID
op
,
std
::
shared_ptr
<
Expression
>
rhs
)
{
std
::
shared_ptr
<
Expression
>
lhs
,
std
::
shared_ptr
<
Expression
>
rhs
)
{
return
std
::
make_shared
<
BinaryExpression
>
(
op
,
lhs
,
rhs
);
return
std
::
make_shared
<
BinaryExpression
>
(
op
,
lhs
,
rhs
);
}
}
virtual
std
::
string
print
()
{
virtual
std
::
string
print
()
{
return
"("
+
Instruction
::
get_instr_op_name
(
op_
)
+
" "
+
lhs_
->
print
()
+
" "
+
rhs_
->
print
()
+
")"
;
return
"("
+
Instruction
::
get_instr_op_name
(
op_
)
+
" "
+
lhs_
->
print
()
+
" "
+
rhs_
->
print
()
+
")"
;
}
}
bool
equiv
(
const
BinaryExpression
*
other
)
const
{
bool
equiv
(
const
BinaryExpression
*
other
)
const
{
if
(
op_
==
other
->
op_
and
*
lhs_
==
*
other
->
lhs_
and
*
rhs_
==
*
other
->
rhs_
)
if
(
op_
==
other
->
op_
and
*
lhs_
==
*
other
->
lhs_
and
*
rhs_
==
*
other
->
rhs_
)
return
true
;
return
true
;
else
else
return
false
;
return
false
;
}
}
BinaryExpression
(
Instruction
::
OpID
op
,
std
::
shared_ptr
<
Expression
>
lhs
,
std
::
shared_ptr
<
Expression
>
rhs
)
BinaryExpression
(
Instruction
::
OpID
op
,
std
::
shared_ptr
<
Expression
>
lhs
,
std
::
shared_ptr
<
Expression
>
rhs
)
:
Expression
(
e_bin
),
op_
(
op
),
lhs_
(
lhs
),
rhs_
(
rhs
)
{}
:
Expression
(
e_bin
),
op_
(
op
),
lhs_
(
lhs
),
rhs_
(
rhs
)
{}
private:
private:
...
@@ -93,33 +108,122 @@ class BinaryExpression : public Expression {
...
@@ -93,33 +108,122 @@ class BinaryExpression : public Expression {
};
};
class
PhiExpression
:
public
Expression
{
class
PhiExpression
:
public
Expression
{
friend
class
::
GVN
;
public:
public:
static
std
::
shared_ptr
<
PhiExpression
>
create
(
std
::
shared_ptr
<
Expression
>
lhs
,
std
::
shared_ptr
<
Expression
>
rhs
)
{
static
std
::
shared_ptr
<
PhiExpression
>
create
(
std
::
shared_ptr
<
Expression
>
lhs
,
std
::
shared_ptr
<
Expression
>
rhs
)
{
return
std
::
make_shared
<
PhiExpression
>
(
lhs
,
rhs
);
return
std
::
make_shared
<
PhiExpression
>
(
lhs
,
rhs
);
}
}
virtual
std
::
string
print
()
{
return
"(phi "
+
lhs_
->
print
()
+
" "
+
rhs_
->
print
()
+
")"
;
}
virtual
std
::
string
print
()
{
return
"(phi "
+
lhs_
->
print
()
+
" "
+
rhs_
->
print
()
+
")"
;
}
bool
equiv
(
const
PhiExpression
*
other
)
const
{
bool
equiv
(
const
PhiExpression
*
other
)
const
{
if
(
*
lhs_
==
*
other
->
lhs_
and
*
rhs_
==
*
other
->
rhs_
)
if
(
*
lhs_
==
*
other
->
lhs_
and
*
rhs_
==
*
other
->
rhs_
)
return
true
;
return
true
;
else
else
return
false
;
return
false
;
}
}
PhiExpression
(
std
::
shared_ptr
<
Expression
>
lhs
,
std
::
shared_ptr
<
Expression
>
rhs
)
PhiExpression
(
std
::
shared_ptr
<
Expression
>
lhs
,
std
::
shared_ptr
<
Expression
>
rhs
)
:
Expression
(
e_phi
),
lhs_
(
lhs
),
rhs_
(
rhs
)
{}
:
Expression
(
e_phi
),
lhs_
(
lhs
),
rhs_
(
rhs
)
{}
private:
private:
std
::
shared_ptr
<
Expression
>
lhs_
,
rhs_
;
std
::
shared_ptr
<
Expression
>
lhs_
,
rhs_
;
};
};
// type cast expression
class
CastExpression
:
public
Expression
{
public:
static
std
::
shared_ptr
<
CastExpression
>
create
(
Instruction
::
OpID
op
,
std
::
shared_ptr
<
Expression
>
src
,
Type
*
dest_type
)
{
return
std
::
make_shared
<
CastExpression
>
(
op
,
src
,
dest_type
);
}
virtual
std
::
string
print
()
{
return
"("
+
dest_ty_
->
print
()
+
" "
+
Instruction
::
get_instr_op_name
(
op_
)
+
" "
+
src_
->
print
()
+
")"
;
}
bool
equiv
(
const
CastExpression
*
other
)
const
{
return
op_
==
other
->
op_
and
src_
==
other
->
src_
and
dest_ty_
==
other
->
dest_ty_
;
}
CastExpression
(
Instruction
::
OpID
op
,
std
::
shared_ptr
<
Expression
>
src
,
Type
*
dest_type
)
:
Expression
(
e_cast
),
op_
(
op
),
src_
(
src
),
dest_ty_
(
dest_type
)
{}
private:
Instruction
::
OpID
op_
;
std
::
shared_ptr
<
Expression
>
src_
;
Type
*
dest_ty_
;
};
// type cast expression
class
GEPExpression
:
public
Expression
{
public:
static
std
::
shared_ptr
<
GEPExpression
>
create
(
std
::
shared_ptr
<
Expression
>
ptr
,
std
::
vector
<
std
::
shared_ptr
<
Expression
>>
&
idxs
)
{
return
std
::
make_shared
<
GEPExpression
>
(
ptr
,
idxs
);
}
virtual
std
::
string
print
()
{
std
::
string
ret
=
"(GEP "
+
ptr_
->
print
();
for
(
auto
idx
:
idxs_
)
ret
+=
" "
+
idx
->
print
();
return
ret
+
")"
;
}
bool
equiv
(
const
GEPExpression
*
other
)
const
{
if
(
idxs_
.
size
()
!=
other
->
idxs_
.
size
())
return
false
;
for
(
int
i
=
0
;
i
!=
idxs_
.
size
();
++
i
)
if
(
not
(
idxs_
[
i
]
==
other
->
idxs_
[
i
]))
return
false
;
return
ptr_
==
other
->
ptr_
;
}
GEPExpression
(
std
::
shared_ptr
<
Expression
>
ptr
,
std
::
vector
<
std
::
shared_ptr
<
Expression
>>
&
idxs
)
:
Expression
(
e_gep
),
ptr_
(
ptr
),
idxs_
(
idxs
)
{}
private:
std
::
shared_ptr
<
Expression
>
ptr_
;
std
::
vector
<
std
::
shared_ptr
<
Expression
>>
idxs_
;
};
// unique expression: not equal to any one else
class
UniqueExpression
:
public
Expression
{
public:
static
std
::
shared_ptr
<
UniqueExpression
>
create
(
Instruction
*
instr
)
{
return
std
::
make_shared
<
UniqueExpression
>
(
instr
);
}
virtual
std
::
string
print
()
{
return
"(UNIQUE "
+
instr_
->
print
()
+
")"
;
}
bool
equiv
(
const
UniqueExpression
*
other
)
const
{
return
false
;
}
UniqueExpression
(
Instruction
*
instr
)
:
Expression
(
e_unique
),
instr_
(
instr
)
{}
private:
std
::
shared_ptr
<
Instruction
>
instr_
;
};
}
// namespace GVNExpression
}
// namespace GVNExpression
/**
/**
* Congruence class in each partitions
* Congruence class in each partitions
* note: for constant propagation, you might need to add other fields
* note: for constant propagation, you might need to add other fields
* and for load/store redundancy detection, you most certainly need to modify the class
* and for load/store redundancy detection, you most certainly need to modify
* the class
*/
*/
struct
CongruenceClass
{
struct
CongruenceClass
{
size_t
index_
;
size_t
index_
;
// representative of the congruence class, used to replace all the members (except itself) when analysis is done
// representative of the congruence class, used to replace all the members
// (except itself) when analysis is done
Value
*
leader_
;
Value
*
leader_
;
// value expression in congruence class
// value expression in congruence class
std
::
shared_ptr
<
GVNExpression
::
Expression
>
value_expr_
;
std
::
shared_ptr
<
GVNExpression
::
Expression
>
value_expr_
;
...
@@ -128,17 +232,22 @@ struct CongruenceClass {
...
@@ -128,17 +232,22 @@ struct CongruenceClass {
// equivalent variables in one congruence class
// equivalent variables in one congruence class
std
::
set
<
Value
*>
members_
;
std
::
set
<
Value
*>
members_
;
CongruenceClass
(
size_t
index
)
:
index_
(
index
),
leader_
{},
value_expr_
{},
value_phi_
{},
members_
{}
{}
CongruenceClass
(
size_t
index
)
:
index_
(
index
),
leader_
{},
value_expr_
{},
value_phi_
{},
members_
{}
{}
bool
operator
<
(
const
CongruenceClass
&
other
)
const
{
return
this
->
index_
<
other
.
index_
;
}
bool
operator
<
(
const
CongruenceClass
&
other
)
const
{
return
this
->
index_
<
other
.
index_
;
}
bool
operator
==
(
const
CongruenceClass
&
other
)
const
;
bool
operator
==
(
const
CongruenceClass
&
other
)
const
;
};
};
namespace
std
{
namespace
std
{
template
<
>
template
<
>
// overload std::less for std::shared_ptr<CongruenceClass>, i.e. how to sort the congruence classes
// overload std::less for std::shared_ptr<CongruenceClass>, i.e. how to sort the
// congruence classes
struct
less
<
std
::
shared_ptr
<
CongruenceClass
>>
{
struct
less
<
std
::
shared_ptr
<
CongruenceClass
>>
{
bool
operator
()(
const
std
::
shared_ptr
<
CongruenceClass
>
&
a
,
const
std
::
shared_ptr
<
CongruenceClass
>
&
b
)
const
{
bool
operator
()(
const
std
::
shared_ptr
<
CongruenceClass
>
&
a
,
const
std
::
shared_ptr
<
CongruenceClass
>
&
b
)
const
{
// nullptrs should never appear in partitions, so we just dereference it
// nullptrs should never appear in partitions, so we just dereference it
return
*
a
<
*
b
;
return
*
a
<
*
b
;
}
}
...
@@ -154,17 +263,24 @@ class GVN : public Pass {
...
@@ -154,17 +263,24 @@ class GVN : public Pass {
// init for pass metadata;
// init for pass metadata;
void
initPerFunction
();
void
initPerFunction
();
// fill the following functions according to Pseudocode, **you might need to add more arguments**
// fill the following functions according to Pseudocode, **you might need to
// add more arguments**
void
detectEquivalences
();
void
detectEquivalences
();
partitions
join
(
const
partitions
&
P1
,
const
partitions
&
P2
);
partitions
join
(
const
partitions
&
P1
,
const
partitions
&
P2
);
std
::
shared_ptr
<
CongruenceClass
>
intersect
(
std
::
shared_ptr
<
CongruenceClass
>
,
std
::
shared_ptr
<
CongruenceClass
>
);
std
::
shared_ptr
<
CongruenceClass
>
intersect
(
std
::
shared_ptr
<
CongruenceClass
>
,
std
::
shared_ptr
<
CongruenceClass
>
);
partitions
transferFunction
(
Instruction
*
x
,
Value
*
e
,
partitions
pin
);
partitions
transferFunction
(
Instruction
*
x
,
Value
*
e
,
partitions
pin
);
partitions
transferFunction
(
BasicBlock
*
bb
);
partitions
transferFunction
(
BasicBlock
*
bb
);
std
::
shared_ptr
<
GVNExpression
::
PhiExpression
>
valuePhiFunc
(
std
::
shared_ptr
<
GVNExpression
::
Expression
>
,
std
::
shared_ptr
<
GVNExpression
::
PhiExpression
>
valuePhiFunc
(
const
partitions
&
);
std
::
shared_ptr
<
GVNExpression
::
Expression
>
,
std
::
shared_ptr
<
GVNExpression
::
Expression
>
valueExpr
(
Instruction
*
instr
);
BasicBlock
*
bb
);
std
::
shared_ptr
<
GVNExpression
::
Expression
>
getVN
(
const
partitions
&
pout
,
std
::
shared_ptr
<
GVNExpression
::
Expression
>
valueExpr
(
std
::
shared_ptr
<
GVNExpression
::
Expression
>
ve
);
Instruction
*
instr
,
partitions
*
part
=
nullptr
);
std
::
shared_ptr
<
GVNExpression
::
Expression
>
getVN
(
const
partitions
&
pout
,
std
::
shared_ptr
<
GVNExpression
::
Expression
>
ve
);
// replace cc members with leader
// replace cc members with leader
void
replace_cc_members
();
void
replace_cc_members
();
...
@@ -180,6 +296,7 @@ class GVN : public Pass {
...
@@ -180,6 +296,7 @@ class GVN : public Pass {
// self add
// self add
//
//
std
::
uint64_t
new_number
()
{
return
next_value_number_
++
;
}
std
::
uint64_t
new_number
()
{
return
next_value_number_
++
;
}
static
int
pretend_copy_stmt
(
Instruction
*
inst
,
BasicBlock
*
bb
);
private:
private:
bool
dump_json_
;
bool
dump_json_
;
...
@@ -191,7 +308,9 @@ class GVN : public Pass {
...
@@ -191,7 +308,9 @@ class GVN : public Pass {
std
::
unique_ptr
<
DeadCode
>
dce_
;
std
::
unique_ptr
<
DeadCode
>
dce_
;
// self add
// self add
std
::
map
<
Instruction
*
,
bool
>
_TOP
;
std
::
map
<
BasicBlock
*
,
bool
>
_TOP
;
partitions
join_helper
(
BasicBlock
*
pre1
,
BasicBlock
*
pre2
);
BasicBlock
*
curr_bb
;
};
};
bool
operator
==
(
const
GVN
::
partitions
&
p1
,
const
GVN
::
partitions
&
p2
);
bool
operator
==
(
const
GVN
::
partitions
&
p1
,
const
GVN
::
partitions
&
p2
);
src/cminusfc/.gitignore
0 → 100644
View file @
f26d91aa
cminusf_builder_stu.cpp
src/cminusfc/cminusf_builder.cpp
View file @
f26d91aa
...
@@ -5,22 +5,20 @@
...
@@ -5,22 +5,20 @@
#include "cminusf_builder.hpp"
#include "cminusf_builder.hpp"
#include "logging.hpp"
#define CONST_FP(num) ConstantFP::get((float)num, module.get())
#define CONST_FP(num) ConstantFP::get((float)num, module.get())
#define CONST_INT(num) ConstantInt::get(num, module.get())
#define CONST_INT(num) ConstantInt::get(num, module.get())
// TODO: Global Variable Declarations
// You can define global variables here
// You can define global variables here
// to store state. You can expand these
// to store state
// definitions if you need to.
//
the latest return
value
//
store temporary
value
Value
*
cur_value
=
nullptr
;
Value
*
tmp_val
=
nullptr
;
//
if var is assignment's left part, LV is tr
ue
//
whether require lval
ue
bool
LV
=
false
;
bool
require_lvalue
=
false
;
// function that is being built
// function that is being built
Function
*
cur_fun
=
nullptr
;
Function
*
cur_fun
=
nullptr
;
// detect scope pre-enter (for elegance only)
bool
pre_enter_scope
=
false
;
// types
// types
Type
*
VOID_T
;
Type
*
VOID_T
;
...
@@ -30,9 +28,22 @@ Type *INT32PTR_T;
...
@@ -30,9 +28,22 @@ Type *INT32PTR_T;
Type
*
FLOAT_T
;
Type
*
FLOAT_T
;
Type
*
FLOATPTR_T
;
Type
*
FLOATPTR_T
;
// initializer
bool
ConstantZero
*
I32Initializer
;
promote
(
IRBuilder
*
builder
,
Value
**
l_val_p
,
Value
**
r_val_p
)
{
ConstantZero
*
FloatInitializer
;
bool
is_int
;
auto
&
l_val
=
*
l_val_p
;
auto
&
r_val
=
*
r_val_p
;
if
(
l_val
->
get_type
()
==
r_val
->
get_type
())
{
is_int
=
l_val
->
get_type
()
->
is_integer_type
();
}
else
{
is_int
=
false
;
if
(
l_val
->
get_type
()
->
is_integer_type
())
l_val
=
builder
->
create_sitofp
(
l_val
,
FLOAT_T
);
else
r_val
=
builder
->
create_sitofp
(
r_val
,
FLOAT_T
);
}
return
is_int
;
}
/*
/*
* use CMinusfBuilder::Scope to construct scopes
* use CMinusfBuilder::Scope to construct scopes
...
@@ -42,182 +53,61 @@ ConstantZero *FloatInitializer;
...
@@ -42,182 +53,61 @@ ConstantZero *FloatInitializer;
* scope.find: find and return the value bound to the name
* scope.find: find and return the value bound to the name
*/
*/
void
error_exit
(
std
::
string
s
)
{
void
LOG_ERROR
<<
s
;
CminusfBuilder
::
visit
(
ASTProgram
&
node
)
{
std
::
abort
();
}
// This function makes sure that
// 1. 2 values have same type
// 2. type is either i32 or float
void
CminusfBuilder
::
biop_type_check
(
Value
*&
lvalue
,
Value
*&
rvalue
,
std
::
string
util
)
{
if
(
Type
::
is_eq_type
(
lvalue
->
get_type
(),
rvalue
->
get_type
()))
{
if
(
lvalue
->
get_type
()
->
is_integer_type
()
or
lvalue
->
get_type
()
->
is_float_type
())
{
// check for i1
if
(
Type
::
is_eq_type
(
lvalue
->
get_type
(),
INT1_T
))
{
lvalue
=
builder
->
create_zext
(
lvalue
,
INT32_T
);
rvalue
=
builder
->
create_zext
(
rvalue
,
INT32_T
);
}
}
else
error_exit
(
"not supported type cast for "
+
util
);
return
;
}
// only support cast between int and float: i32, i1, float
//
// case that integer and float is mixed, directly cast integer to float
if
(
lvalue
->
get_type
()
->
is_integer_type
()
and
rvalue
->
get_type
()
->
is_float_type
())
lvalue
=
builder
->
create_sitofp
(
lvalue
,
FLOAT_T
);
else
if
(
lvalue
->
get_type
()
->
is_float_type
()
and
rvalue
->
get_type
()
->
is_integer_type
())
rvalue
=
builder
->
create_sitofp
(
rvalue
,
FLOAT_T
);
else
if
(
lvalue
->
get_type
()
->
is_integer_type
()
and
rvalue
->
get_type
()
->
is_integer_type
())
{
// case that I32 and I1 mixed
if
(
Type
::
is_eq_type
(
lvalue
->
get_type
(),
INT1_T
))
lvalue
=
builder
->
create_zext
(
lvalue
,
INT32_T
);
else
rvalue
=
builder
->
create_zext
(
rvalue
,
INT32_T
);
}
else
{
// we only support computing among i1, i32 and float
error_exit
(
"not supported type cast for "
+
util
);
}
}
// this function makes sure value is a bool type
void
CminusfBuilder
::
cast_to_i1
(
Value
*&
value
)
{
assert
(
value
->
get_type
()
->
is_integer_type
()
or
value
->
get_type
()
->
is_float_type
());
if
(
value
->
get_type
()
->
is_float_type
())
// value = builder->create_fptosi(value, INT1_T);
value
=
builder
->
create_fcmp_ne
(
value
,
CONST_FP
(
0
));
else
if
(
Type
::
is_eq_type
(
value
->
get_type
(),
INT32_T
))
value
=
builder
->
create_icmp_ne
(
value
,
CONST_INT
(
0
));
}
void
CminusfBuilder
::
visit
(
ASTProgram
&
node
)
{
VOID_T
=
Type
::
get_void_type
(
module
.
get
());
VOID_T
=
Type
::
get_void_type
(
module
.
get
());
INT1_T
=
Type
::
get_int1_type
(
module
.
get
());
INT1_T
=
Type
::
get_int1_type
(
module
.
get
());
INT32_T
=
Type
::
get_int32_type
(
module
.
get
());
INT32_T
=
Type
::
get_int32_type
(
module
.
get
());
INT32PTR_T
=
Type
::
get_int32_ptr_type
(
module
.
get
());
INT32PTR_T
=
Type
::
get_int32_ptr_type
(
module
.
get
());
FLOAT_T
=
Type
::
get_float_type
(
module
.
get
());
FLOAT_T
=
Type
::
get_float_type
(
module
.
get
());
FLOATPTR_T
=
Type
::
get_float_ptr_type
(
module
.
get
());
FLOATPTR_T
=
Type
::
get_float_ptr_type
(
module
.
get
());
I32Initializer
=
ConstantZero
::
get
(
INT32_T
,
builder
->
get_module
());
FloatInitializer
=
ConstantZero
::
get
(
FLOAT_T
,
builder
->
get_module
());
for
(
auto
decl
:
node
.
declarations
)
{
for
(
auto
decl
:
node
.
declarations
)
{
decl
->
accept
(
*
this
);
decl
->
accept
(
*
this
);
}
}
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTNum
&
node
)
{
CminusfBuilder
::
visit
(
ASTNum
&
node
)
{
//!TODO: This function is empty now.
if
(
node
.
type
==
TYPE_INT
)
// Add some code here.
tmp_val
=
CONST_INT
(
node
.
i_val
);
//
else
switch
(
node
.
type
)
{
tmp_val
=
CONST_FP
(
node
.
f_val
);
case
TYPE_INT
:
cur_value
=
CONST_INT
(
node
.
i_val
);
return
;
case
TYPE_FLOAT
:
cur_value
=
CONST_FP
(
node
.
f_val
);
return
;
default:
error_exit
(
"ASTNum is not int or float"
);
}
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTVarDeclaration
&
node
)
{
CminusfBuilder
::
visit
(
ASTVarDeclaration
&
node
)
{
//!TODO: This function is empty now.
Type
*
var_type
;
// Add some code here.
if
(
node
.
type
==
TYPE_INT
)
bool
global
=
(
builder
->
get_insert_block
()
==
nullptr
);
var_type
=
Type
::
get_int32_type
(
module
.
get
());
if
(
node
.
num
)
{
else
// declares an array
var_type
=
Type
::
get_float_type
(
module
.
get
());
//
if
(
node
.
num
==
nullptr
)
{
// get array size
if
(
scope
.
in_global
())
{
node
.
num
->
accept
(
*
this
);
auto
initializer
=
ConstantZero
::
get
(
var_type
,
module
.
get
());
//
auto
var
=
GlobalVariable
::
create
(
// !no type cast here!
node
.
id
,
module
.
get
(),
var_type
,
false
,
initializer
);
if
(
not
(
node
.
num
->
type
==
TYPE_INT
))
scope
.
push
(
node
.
id
,
var
);
error_exit
(
"size of array has non-integer type"
);
}
else
{
auto
var
=
builder
->
create_alloca
(
var_type
);
int
size
=
node
.
num
->
i_val
;
scope
.
push
(
node
.
id
,
var
);
if
(
size
<=
0
)
error_exit
(
"array size["
+
std
::
to_string
(
size
)
+
"] <= 0"
);
switch
(
node
.
type
)
{
case
TYPE_INT
:
{
auto
I32Array_T
=
Type
::
get_array_type
(
INT32_T
,
size
);
if
(
global
)
cur_value
=
GlobalVariable
::
create
(
node
.
id
,
builder
->
get_module
(),
I32Array_T
,
false
,
I32Initializer
);
else
cur_value
=
builder
->
create_alloca
(
I32Array_T
);
break
;
}
case
TYPE_FLOAT
:
{
auto
FloatArray_T
=
Type
::
get_array_type
(
FLOAT_T
,
size
);
if
(
global
)
cur_value
=
GlobalVariable
::
create
(
node
.
id
,
builder
->
get_module
(),
FloatArray_T
,
false
,
FloatInitializer
);
else
cur_value
=
builder
->
create_alloca
(
FloatArray_T
);
break
;
}
default:
error_exit
(
"Variable type(not array) is not int or float"
);
}
}
assert
(
cur_value
->
get_type
()
->
is_pointer_type
()
&&
"IF SEE THIS: API ERROR"
);
}
else
{
}
else
{
// flat int or float type
auto
*
array_type
=
ArrayType
::
get
(
var_type
,
node
.
num
->
i_val
);
switch
(
node
.
type
)
{
if
(
scope
.
in_global
())
{
case
TYPE_INT
:
auto
initializer
=
ConstantZero
::
get
(
array_type
,
module
.
get
());
if
(
global
)
auto
var
=
GlobalVariable
::
create
(
cur_value
=
GlobalVariable
::
create
(
node
.
id
,
builder
->
get_module
(),
INT32_T
,
false
,
I32Initializer
);
node
.
id
,
module
.
get
(),
array_type
,
false
,
initializer
);
else
scope
.
push
(
node
.
id
,
var
);
cur_value
=
builder
->
create_alloca
(
INT32_T
);
}
else
{
break
;
auto
var
=
builder
->
create_alloca
(
array_type
);
scope
.
push
(
node
.
id
,
var
);
case
TYPE_FLOAT
:
if
(
global
)
cur_value
=
GlobalVariable
::
create
(
node
.
id
,
builder
->
get_module
(),
FLOAT_T
,
false
,
FloatInitializer
);
else
{
/* Beautiful is better than ugly.
* Explicit is better than implicit.
* Simple is better than complex.
* Complex is better than complicated.
* Flat is better than nested.
* Sparse is better than dense.
* Readability counts.
* Special cases aren't special enough to break the rules.
* Although practicality beats purity.
* Errors should never pass silently.
* Unless explicitly silenced.
* In the face of ambiguity, refuse the temptation to guess.
* There should be one-- and preferably only one --obvious way to do it.
* Although that way may not be obvious at first unless you're Dutch.
* Now is better than never.
* Although never is often better than *right* now.
* If the implementation is hard to explain, it's a bad idea.
* If the implementation is easy to explain, it may be a good idea.
* Namespaces are one honking great idea -- let's do more of those! */
// cur_value = builder->create_alloca(INT32_T);
cur_value
=
builder
->
create_alloca
(
FLOAT_T
);
}
break
;
default:
error_exit
(
"Variable type(not array) is not int or float"
);
}
}
}
}
if
(
not
scope
.
push
(
node
.
id
,
cur_value
))
error_exit
(
"variable redefined: "
+
node
.
id
);
LOG_DEBUG
<<
"add entry: "
<<
node
.
id
<<
" "
<<
cur_value
;
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTFunDeclaration
&
node
)
{
CminusfBuilder
::
visit
(
ASTFunDeclaration
&
node
)
{
FunctionType
*
fun_type
;
FunctionType
*
fun_type
;
Type
*
ret_type
;
Type
*
ret_type
;
std
::
vector
<
Type
*>
param_types
;
std
::
vector
<
Type
*>
param_types
;
...
@@ -229,46 +119,54 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
...
@@ -229,46 +119,54 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
ret_type
=
VOID_T
;
ret_type
=
VOID_T
;
for
(
auto
&
param
:
node
.
params
)
{
for
(
auto
&
param
:
node
.
params
)
{
//!TODO: Please accomplish param_types.
if
(
param
->
type
==
TYPE_INT
)
{
//
if
(
param
->
isarray
)
{
// First make function BB, which needs this param type,
param_types
.
push_back
(
INT32PTR_T
);
// then set_insert_point, we can call accept to gen code,
}
else
{
switch
(
param
->
type
)
{
param_types
.
push_back
(
INT32_T
);
case
TYPE_INT
:
}
param_types
.
push_back
(
param
->
isarray
?
INT32PTR_T
:
INT32_T
);
}
else
{
break
;
if
(
param
->
isarray
)
{
case
TYPE_FLOAT
:
param_types
.
push_back
(
FLOATPTR_T
);
param_types
.
push_back
(
param
->
isarray
?
FLOATPTR_T
:
FLOAT_T
);
}
else
{
break
;
param_types
.
push_back
(
FLOAT_T
);
case
TYPE_VOID
:
}
if
(
not
param_types
.
empty
())
error_exit
(
"function parameters weird"
);
break
;
}
}
}
}
fun_type
=
FunctionType
::
get
(
ret_type
,
param_types
);
fun_type
=
FunctionType
::
get
(
ret_type
,
param_types
);
auto
fun
=
Function
::
create
(
fun_type
,
node
.
id
,
module
.
get
());
auto
fun
=
Function
::
create
(
fun_type
,
node
.
id
,
module
.
get
());
cur_fun
=
fun
;
scope
.
push
(
node
.
id
,
fun
);
scope
.
push
(
node
.
id
,
fun
);
cur_fun
=
fun
;
auto
funBB
=
BasicBlock
::
create
(
module
.
get
(),
"entry"
,
fun
);
auto
funBB
=
BasicBlock
::
create
(
module
.
get
(),
"entry"
,
fun
);
builder
->
set_insert_point
(
funBB
);
builder
->
set_insert_point
(
funBB
);
scope
.
enter
();
scope
.
enter
();
pre_enter_scope
=
true
;
std
::
vector
<
Value
*>
args
;
std
::
vector
<
Value
*>
args
;
for
(
auto
arg
=
fun
->
arg_begin
();
arg
!=
fun
->
arg_end
();
arg
++
)
{
for
(
auto
arg
=
fun
->
arg_begin
();
arg
!=
fun
->
arg_end
();
arg
++
)
{
args
.
push_back
(
*
arg
);
args
.
push_back
(
*
arg
);
}
}
for
(
int
i
=
0
;
i
<
node
.
params
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
node
.
params
.
size
();
++
i
)
{
//!TODO: You need to deal with params
if
(
node
.
params
[
i
]
->
isarray
)
{
// and store them in the scope.
Value
*
array_alloc
;
cur_value
=
args
[
i
];
if
(
node
.
params
[
i
]
->
type
==
TYPE_INT
)
node
.
params
[
i
]
->
accept
(
*
this
);
array_alloc
=
builder
->
create_alloca
(
INT32PTR_T
);
else
array_alloc
=
builder
->
create_alloca
(
FLOATPTR_T
);
builder
->
create_store
(
args
[
i
],
array_alloc
);
scope
.
push
(
node
.
params
[
i
]
->
id
,
array_alloc
);
}
else
{
Value
*
alloc
;
if
(
node
.
params
[
i
]
->
type
==
TYPE_INT
)
alloc
=
builder
->
create_alloca
(
INT32_T
);
else
alloc
=
builder
->
create_alloca
(
FLOAT_T
);
builder
->
create_store
(
args
[
i
],
alloc
);
scope
.
push
(
node
.
params
[
i
]
->
id
,
alloc
);
}
}
}
node
.
compound_stmt
->
accept
(
*
this
);
node
.
compound_stmt
->
accept
(
*
this
);
//
default return value
//
can't deal with return in both blocks
if
(
builder
->
get_insert_block
()
->
get_terminator
()
==
nullptr
)
{
if
(
builder
->
get_insert_block
()
->
get_terminator
()
==
nullptr
)
{
if
(
cur_fun
->
get_return_type
()
->
is_void_type
())
if
(
cur_fun
->
get_return_type
()
->
is_void_type
())
builder
->
create_void_ret
();
builder
->
create_void_ret
();
...
@@ -280,40 +178,17 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
...
@@ -280,40 +178,17 @@ void CminusfBuilder::visit(ASTFunDeclaration &node) {
scope
.
exit
();
scope
.
exit
();
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTParam
&
node
)
{
CminusfBuilder
::
visit
(
ASTParam
&
node
)
{}
//!TODO: This function is empty now.
// If the parameter is int|float, copy and store them
auto
param_value
=
cur_value
;
switch
(
node
.
type
)
{
case
TYPE_INT
:
{
if
(
node
.
isarray
)
cur_value
=
builder
->
create_alloca
(
INT32PTR_T
);
else
cur_value
=
builder
->
create_alloca
(
INT32_T
);
break
;
}
case
TYPE_FLOAT
:
{
if
(
node
.
isarray
)
cur_value
=
builder
->
create_alloca
(
FLOATPTR_T
);
else
cur_value
=
builder
->
create_alloca
(
FLOAT_T
);
break
;
}
case
TYPE_VOID
:
return
;
}
scope
.
push
(
node
.
id
,
cur_value
);
builder
->
create_store
(
param_value
,
cur_value
);
}
// Done?
void
void
CminusfBuilder
::
visit
(
ASTCompoundStmt
&
node
)
{
CminusfBuilder
::
visit
(
ASTCompoundStmt
&
node
)
{
//!TODO: This function is not complete.
bool
need_exit_scope
=
!
pre_enter_scope
;
// You may need to add some code here
if
(
pre_enter_scope
)
{
// to deal with complex statements.
pre_enter_scope
=
false
;
}
else
{
scope
.
enter
();
scope
.
enter
();
}
for
(
auto
&
decl
:
node
.
local_declarations
)
{
for
(
auto
&
decl
:
node
.
local_declarations
)
{
decl
->
accept
(
*
this
);
decl
->
accept
(
*
this
);
...
@@ -325,413 +200,303 @@ void CminusfBuilder::visit(ASTCompoundStmt &node) {
...
@@ -325,413 +200,303 @@ void CminusfBuilder::visit(ASTCompoundStmt &node) {
break
;
break
;
}
}
scope
.
exit
();
if
(
need_exit_scope
)
{
scope
.
exit
();
}
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTExpressionStmt
&
node
)
{
CminusfBuilder
::
visit
(
ASTExpressionStmt
&
node
)
{
//!TODO: This function is empty now.
if
(
node
.
expression
!=
nullptr
)
// Add some code here.
if
(
node
.
expression
)
node
.
expression
->
accept
(
*
this
);
node
.
expression
->
accept
(
*
this
);
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTSelectionStmt
&
node
)
{
CminusfBuilder
::
visit
(
ASTSelectionStmt
&
node
)
{
//!TODO: This function is empty now.
// Add some code here.
scope
.
enter
();
node
.
expression
->
accept
(
*
this
);
node
.
expression
->
accept
(
*
this
);
auto
cond
=
cur_value
;
auto
ret_val
=
tmp_val
;
cast_to_i1
(
cond
);
auto
trueBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
BasicBlock
*
falseBB
{};
auto
ifBB
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
auto
contBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
auto
endBB
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
Value
*
cond_val
;
if
(
node
.
else_statement
)
{
if
(
ret_val
->
get_type
()
->
is_integer_type
())
auto
elseBB
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
cond_val
=
builder
->
create_icmp_ne
(
ret_val
,
CONST_INT
(
0
));
builder
->
create_cond_br
(
cond
,
ifBB
,
elseBB
);
else
cond_val
=
builder
->
create_fcmp_ne
(
ret_val
,
CONST_FP
(
0.
));
builder
->
set_insert_point
(
ifBB
);
node
.
if_statement
->
accept
(
*
this
);
builder
->
create_br
(
endBB
);
builder
->
set_insert_point
(
elseBB
);
node
.
else_statement
->
accept
(
*
this
);
builder
->
create_br
(
endBB
);
builder
->
set_insert_point
(
endBB
);
if
(
node
.
else_statement
==
nullptr
)
{
builder
->
create_cond_br
(
cond_val
,
trueBB
,
contBB
);
}
else
{
}
else
{
builder
->
create_cond_br
(
cond
,
ifBB
,
endBB
);
falseBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
builder
->
create_cond_br
(
cond_val
,
trueBB
,
falseBB
);
}
builder
->
set_insert_point
(
trueBB
);
node
.
if_statement
->
accept
(
*
this
);
builder
->
set_insert_point
(
ifBB
);
if
(
builder
->
get_insert_block
()
->
get_terminator
()
==
nullptr
)
node
.
if_statement
->
accept
(
*
this
);
builder
->
create_br
(
contBB
);
builder
->
create_br
(
endBB
);
builder
->
set_insert_point
(
endBB
);
if
(
node
.
else_statement
==
nullptr
)
{
// falseBB->erase_from_parent(); // did not clean up memory
}
else
{
builder
->
set_insert_point
(
falseBB
);
node
.
else_statement
->
accept
(
*
this
);
if
(
builder
->
get_insert_block
()
->
get_terminator
()
==
nullptr
)
builder
->
create_br
(
contBB
);
}
}
scope
.
exit
();
}
// Done
builder
->
set_insert_point
(
contBB
);
void
CminusfBuilder
::
visit
(
ASTIterationStmt
&
node
)
{
}
//!TODO: This function is empty now.
// Add some code here.
scope
.
enter
();
auto
HEAD
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
auto
BODY
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
auto
END
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
builder
->
create_br
(
HEAD
);
builder
->
set_insert_point
(
HEAD
);
void
CminusfBuilder
::
visit
(
ASTIterationStmt
&
node
)
{
auto
exprBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
if
(
builder
->
get_insert_block
()
->
get_terminator
()
==
nullptr
)
builder
->
create_br
(
exprBB
);
builder
->
set_insert_point
(
exprBB
);
node
.
expression
->
accept
(
*
this
);
node
.
expression
->
accept
(
*
this
);
auto
cond
=
cur_value
;
auto
ret_val
=
tmp_val
;
cast_to_i1
(
cond
);
auto
trueBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
builder
->
create_cond_br
(
cond
,
BODY
,
END
);
auto
contBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
Value
*
cond_val
;
if
(
ret_val
->
get_type
()
->
is_integer_type
())
cond_val
=
builder
->
create_icmp_ne
(
ret_val
,
CONST_INT
(
0
));
else
cond_val
=
builder
->
create_fcmp_ne
(
ret_val
,
CONST_FP
(
0.
));
builder
->
set_insert_point
(
BODY
);
builder
->
create_cond_br
(
cond_val
,
trueBB
,
contBB
);
builder
->
set_insert_point
(
trueBB
);
node
.
statement
->
accept
(
*
this
);
node
.
statement
->
accept
(
*
this
);
builder
->
create_br
(
HEAD
);
if
(
builder
->
get_insert_block
()
->
get_terminator
()
==
nullptr
)
builder
->
create_br
(
exprBB
);
builder
->
set_insert_point
(
END
);
builder
->
set_insert_point
(
contBB
);
scope
.
exit
();
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTReturnStmt
&
node
)
{
CminusfBuilder
::
visit
(
ASTReturnStmt
&
node
)
{
if
(
node
.
expression
==
nullptr
)
{
if
(
node
.
expression
==
nullptr
)
{
builder
->
create_void_ret
();
builder
->
create_void_ret
();
}
else
{
}
else
{
//!TODO: The given code is incomplete.
auto
fun_ret_type
=
cur_fun
->
get_function_type
()
->
get_return_type
();
// You need to solve other return cases (e.g. return an integer).
//
node
.
expression
->
accept
(
*
this
);
node
.
expression
->
accept
(
*
this
);
// type cast
if
(
fun_ret_type
!=
tmp_val
->
get_type
())
{
// return type can only be int, float or void
if
(
fun_ret_type
->
is_integer_type
())
if
(
not
Type
::
is_eq_type
(
cur_fun
->
get_return_type
(),
cur_value
->
get_type
()))
{
tmp_val
=
builder
->
create_fptosi
(
tmp_val
,
INT32_T
);
if
(
not
cur_value
->
get_type
()
->
is_integer_type
()
and
not
cur_value
->
get_type
()
->
is_float_type
())
error_exit
(
"unsupported return type"
);
if
(
cur_value
->
get_type
()
->
is_float_type
())
cur_value
=
builder
->
create_fptosi
(
cur_value
,
INT32_T
);
else
if
(
cur_fun
->
get_return_type
()
->
is_float_type
())
cur_value
=
builder
->
create_sitofp
(
cur_value
,
FLOAT_T
);
else
else
cur_value
=
builder
->
create_zext
(
cur_value
,
INT32
_T
);
tmp_val
=
builder
->
create_sitofp
(
tmp_val
,
FLOAT
_T
);
}
}
builder
->
create_ret
(
cur_value
);
builder
->
create_ret
(
tmp_val
);
LOG_DEBUG
<<
"create ret:
\n
"
<<
builder
->
get_module
()
->
print
();
}
}
}
}
// Done
void
// if LV is marked, return memory addr
CminusfBuilder
::
visit
(
ASTVar
&
node
)
{
// else return value stored inside
auto
var
=
scope
.
find
(
node
.
id
);
void
CminusfBuilder
::
visit
(
ASTVar
&
node
)
{
assert
(
var
!=
nullptr
);
//!TODO: This function is empty now.
auto
is_int
=
// Add some code here.
var
->
get_type
()
->
get_pointer_element_type
()
->
is_integer_type
();
//
auto
is_float
=
// First it's pointer type, the pointed elements have 3 cases:
var
->
get_type
()
->
get_pointer_element_type
()
->
is_float_type
();
// 1. int or float
auto
is_ptr
=
// 2. [i32 x n] or [float x n]
var
->
get_type
()
->
get_pointer_element_type
()
->
is_pointer_type
();
// 3. int*
bool
should_return_lvalue
=
require_lvalue
;
auto
memory
=
scope
.
find
(
node
.
id
);
require_lvalue
=
false
;
Value
*
addr
;
if
(
node
.
expression
==
nullptr
)
{
if
(
memory
==
nullptr
)
if
(
should_return_lvalue
)
{
error_exit
(
"variable "
+
node
.
id
+
" not declared"
);
tmp_val
=
var
;
LOG_DEBUG
<<
"find entry: "
<<
node
.
id
<<
" "
<<
memory
;
require_lvalue
=
false
;
assert
(
memory
->
get_type
()
->
is_pointer_type
());
auto
element_type
=
memory
->
get_type
()
->
get_pointer_element_type
();
if
(
node
.
expression
)
{
// e.g. int a[10]; // mem is [i32 x 10]*
bool
old_LV
=
LV
;
LV
=
false
;
node
.
expression
->
accept
(
*
this
);
LV
=
old_LV
;
// subscription type cast
if
(
not
Type
::
is_eq_type
(
cur_value
->
get_type
(),
INT32_T
))
{
if
(
Type
::
is_eq_type
(
cur_value
->
get_type
(),
FLOAT_T
))
cur_value
=
builder
->
create_fptosi
(
cur_value
,
INT32_T
);
else
error_exit
(
"bad type for subscription"
);
}
auto
cond
=
builder
->
create_icmp_lt
(
cur_value
,
CONST_INT
(
0
));
auto
except_func
=
scope
.
find
(
"neg_idx_except"
);
auto
TBB
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
auto
passBB
=
BasicBlock
::
create
(
builder
->
get_module
(),
""
,
cur_fun
);
builder
->
create_cond_br
(
cond
,
TBB
,
passBB
);
builder
->
set_insert_point
(
TBB
);
builder
->
create_call
(
except_func
,
{});
builder
->
create_br
(
passBB
);
builder
->
set_insert_point
(
passBB
);
// Now the subscription is in cur_value, which is good value
// We should focus on the var type:
//
// assert it's pointer type
if
(
element_type
->
is_float_type
()
or
element_type
->
is_integer_type
())
{
// 1. int or float
error_exit
(
"invalid types for array subscript"
);
}
else
if
(
element_type
->
is_array_type
())
{
// 2. [i32 x n] or [float x n]
//
// addr is actually &memory[0][cur_value]
addr
=
builder
->
create_gep
(
memory
,
{
CONST_INT
(
0
),
cur_value
});
}
else
if
(
element_type
->
is_pointer_type
())
{
// 3. int*
// what to do:
// - int**->int*
// - seek addr
//
// addr is actually &(*memory)[cur_value]
addr
=
builder
->
create_load
(
memory
);
addr
=
builder
->
create_gep
(
addr
,
{
cur_value
});
}
// The logic for this part is the same as `int|float` without subscription.
// Cause we have subscription to find the particular element(int or float),
// we make `addr` its memory address.
}
else
{
// e.g. int a; // a is i32*
if
(
element_type
->
is_float_type
()
or
element_type
->
is_integer_type
())
{
// 1. int or float
// addr is the element's addr
addr
=
memory
;
}
else
{
}
else
{
if
(
LV
)
if
(
is_int
||
is_float
||
is_ptr
)
{
error_exit
(
"error: pointer or array type is not assignable"
);
tmp_val
=
builder
->
create_load
(
var
);
// For array* or pointer* type, the right-value behaviour is quite special,
}
else
{
// so treat them apart.
tmp_val
=
if
(
element_type
->
is_array_type
())
{
builder
->
create_gep
(
var
,
{
CONST_INT
(
0
),
CONST_INT
(
0
)});
// 2. [i32 x n] or [float x n]
// addr is the first element's address in the array
cur_value
=
builder
->
create_gep
(
memory
,
{
CONST_INT
(
0
),
CONST_INT
(
0
)});
}
else
if
(
element_type
->
is_pointer_type
())
{
// 3. int*
// addr is the content in the memory, which is actually pointer type
cur_value
=
builder
->
create_load
(
memory
);
}
}
return
;
}
}
}
if
(
LV
)
{
LOG_INFO
<<
"directly return addr"
<<
node
.
id
;
cur_value
=
addr
;
}
else
{
}
else
{
LOG_INFO
<<
"create load for var: "
<<
node
.
id
;
node
.
expression
->
accept
(
*
this
);
cur_value
=
builder
->
create_load
(
addr
);
auto
val
=
tmp_val
;
Value
*
is_neg
;
auto
exceptBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
auto
contBB
=
BasicBlock
::
create
(
module
.
get
(),
""
,
cur_fun
);
if
(
val
->
get_type
()
->
is_float_type
())
val
=
builder
->
create_fptosi
(
val
,
INT32_T
);
is_neg
=
builder
->
create_icmp_lt
(
val
,
CONST_INT
(
0
));
builder
->
create_cond_br
(
is_neg
,
exceptBB
,
contBB
);
builder
->
set_insert_point
(
exceptBB
);
auto
neg_idx_except_fun
=
scope
.
find
(
"neg_idx_except"
);
builder
->
create_call
(
static_cast
<
Function
*>
(
neg_idx_except_fun
),
{});
if
(
cur_fun
->
get_return_type
()
->
is_void_type
())
builder
->
create_void_ret
();
else
if
(
cur_fun
->
get_return_type
()
->
is_float_type
())
builder
->
create_ret
(
CONST_FP
(
0.
));
else
builder
->
create_ret
(
CONST_INT
(
0
));
builder
->
set_insert_point
(
contBB
);
Value
*
tmp_ptr
;
if
(
is_int
||
is_float
)
tmp_ptr
=
builder
->
create_gep
(
var
,
{
val
});
else
if
(
is_ptr
)
{
auto
array_load
=
builder
->
create_load
(
var
);
tmp_ptr
=
builder
->
create_gep
(
array_load
,
{
val
});
}
else
tmp_ptr
=
builder
->
create_gep
(
var
,
{
CONST_INT
(
0
),
val
});
if
(
should_return_lvalue
)
{
tmp_val
=
tmp_ptr
;
require_lvalue
=
false
;
}
else
{
tmp_val
=
builder
->
create_load
(
tmp_ptr
);
}
}
}
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTAssignExpression
&
node
)
{
CminusfBuilder
::
visit
(
ASTAssignExpression
&
node
)
{
//!TODO: This function is empty now.
// Add some code here.
//
LV
=
true
;
node
.
var
->
accept
(
*
this
);
LV
=
false
;
auto
addr
=
cur_value
;
node
.
expression
->
accept
(
*
this
);
node
.
expression
->
accept
(
*
this
);
auto
expr_result
=
tmp_val
;
assert
(
addr
->
get_type
()
->
get_pointer_element_type
()
!=
nullptr
);
require_lvalue
=
true
;
// type cast: left is a pointer type, pointed to i32 or float
node
.
var
->
accept
(
*
this
);
if
(
not
Type
::
is_eq_type
(
addr
->
get_type
()
->
get_pointer_element_type
(),
cur_value
->
get_type
()))
{
auto
var_addr
=
tmp_val
;
if
(
cur_value
->
get_type
()
->
is_float_type
())
if
(
var_addr
->
get_type
()
->
get_pointer_element_type
()
!=
cur_value
=
builder
->
create_fptosi
(
cur_value
,
INT32_T
);
expr_result
->
get_type
())
{
else
if
(
addr
->
get_type
()
->
get_pointer_element_type
()
->
is_float_type
())
if
(
expr_result
->
get_type
()
==
INT32_T
)
cur_value
=
builder
->
create_sitofp
(
cur_value
,
FLOAT_T
);
expr_result
=
builder
->
create_sitofp
(
expr_result
,
FLOAT_T
);
else
if
(
Type
::
is_eq_type
(
cur_value
->
get_type
(),
INT1_T
))
cur_value
=
builder
->
create_zext
(
cur_value
,
INT32_T
);
else
else
e
rror_exit
(
"bad type for assignment"
);
e
xpr_result
=
builder
->
create_fptosi
(
expr_result
,
INT32_T
);
}
}
// gen code
builder
->
create_store
(
expr_result
,
var_addr
);
builder
->
create_store
(
cur_value
,
addr
)
;
tmp_val
=
expr_result
;
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTSimpleExpression
&
node
)
{
CminusfBuilder
::
visit
(
ASTSimpleExpression
&
node
)
{
//!TODO: This function is empty now.
if
(
node
.
additive_expression_r
==
nullptr
)
{
// Add some code here.
//
if
(
node
.
additive_expression_r
)
{
node
.
additive_expression_l
->
accept
(
*
this
);
node
.
additive_expression_l
->
accept
(
*
this
);
auto
lvalue
=
cur_value
;
}
else
{
node
.
additive_expression_l
->
accept
(
*
this
);
auto
l_val
=
tmp_val
;
node
.
additive_expression_r
->
accept
(
*
this
);
node
.
additive_expression_r
->
accept
(
*
this
);
auto
rvalue
=
cur_value
;
auto
r_val
=
tmp_val
;
// check type
bool
is_int
=
promote
(
&*
builder
,
&
l_val
,
&
r_val
);
biop_type_check
(
lvalue
,
rvalue
,
"cmp"
);
Value
*
cmp
;
bool
float_cmp
=
lvalue
->
get_type
()
->
is_float_type
();
switch
(
node
.
op
)
{
switch
(
node
.
op
)
{
case
OP_L
E
:
{
case
OP_L
T
:
if
(
float_cmp
)
if
(
is_int
)
c
ur_value
=
builder
->
create_fcmp_le
(
lvalue
,
rvalue
);
c
mp
=
builder
->
create_icmp_lt
(
l_val
,
r_val
);
else
else
c
ur_value
=
builder
->
create_icmp_le
(
lvalue
,
rvalue
);
c
mp
=
builder
->
create_fcmp_lt
(
l_val
,
r_val
);
break
;
break
;
}
case
OP_LE
:
case
OP_LT
:
{
if
(
is_int
)
if
(
float_cmp
)
cmp
=
builder
->
create_icmp_le
(
l_val
,
r_val
);
cur_value
=
builder
->
create_fcmp_lt
(
lvalue
,
rvalue
);
else
else
c
ur_value
=
builder
->
create_icmp_lt
(
lvalue
,
rvalue
);
c
mp
=
builder
->
create_fcmp_le
(
l_val
,
r_val
);
break
;
break
;
}
case
OP_GE
:
case
OP_GT
:
{
if
(
is_int
)
if
(
float_cmp
)
cmp
=
builder
->
create_icmp_ge
(
l_val
,
r_val
);
cur_value
=
builder
->
create_fcmp_gt
(
lvalue
,
rvalue
);
else
else
c
ur_value
=
builder
->
create_icmp_gt
(
lvalue
,
rvalue
);
c
mp
=
builder
->
create_fcmp_ge
(
l_val
,
r_val
);
break
;
break
;
}
case
OP_GT
:
case
OP_GE
:
{
if
(
is_int
)
if
(
float_cmp
)
cmp
=
builder
->
create_icmp_gt
(
l_val
,
r_val
);
cur_value
=
builder
->
create_fcmp_ge
(
lvalue
,
rvalue
);
else
else
c
ur_value
=
builder
->
create_icmp_ge
(
lvalue
,
rvalue
);
c
mp
=
builder
->
create_fcmp_gt
(
l_val
,
r_val
);
break
;
break
;
}
case
OP_EQ
:
case
OP_EQ
:
{
if
(
is_int
)
if
(
float_cmp
)
cmp
=
builder
->
create_icmp_eq
(
l_val
,
r_val
);
cur_value
=
builder
->
create_fcmp_eq
(
lvalue
,
rvalue
);
else
else
c
ur_value
=
builder
->
create_icmp_eq
(
lvalue
,
rvalue
);
c
mp
=
builder
->
create_fcmp_eq
(
l_val
,
r_val
);
break
;
break
;
}
case
OP_NEQ
:
case
OP_NEQ
:
{
if
(
is_int
)
if
(
float_cmp
)
cmp
=
builder
->
create_icmp_ne
(
l_val
,
r_val
);
cur_value
=
builder
->
create_fcmp_ne
(
lvalue
,
rvalue
);
else
else
c
ur_value
=
builder
->
create_icmp_ne
(
lvalue
,
rvalue
);
c
mp
=
builder
->
create_fcmp_ne
(
l_val
,
r_val
);
break
;
break
;
}
}
}
}
else
node
.
additive_expression_l
->
accept
(
*
this
);
tmp_val
=
builder
->
create_zext
(
cmp
,
INT32_T
);
}
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTAdditiveExpression
&
node
)
{
CminusfBuilder
::
visit
(
ASTAdditiveExpression
&
node
)
{
//!TODO: This function is empty now.
if
(
node
.
additive_expression
==
nullptr
)
{
// Add some code here.
node
.
term
->
accept
(
*
this
);
//
}
else
{
if
(
node
.
additive_expression
)
{
node
.
additive_expression
->
accept
(
*
this
);
node
.
additive_expression
->
accept
(
*
this
);
auto
l
value
=
cur_value
;
auto
l
_val
=
tmp_val
;
node
.
term
->
accept
(
*
this
);
node
.
term
->
accept
(
*
this
);
auto
rvalue
=
cur_value
;
auto
r_val
=
tmp_val
;
// check type
bool
is_int
=
promote
(
&*
builder
,
&
l_val
,
&
r_val
);
biop_type_check
(
lvalue
,
rvalue
,
"addop"
);
bool
float_type
=
lvalue
->
get_type
()
->
is_float_type
();
// now left and right is the same type
switch
(
node
.
op
)
{
switch
(
node
.
op
)
{
case
OP_PLUS
:
{
case
OP_PLUS
:
if
(
float_type
)
if
(
is_int
)
cur_value
=
builder
->
create_fadd
(
lvalue
,
rvalue
);
tmp_val
=
builder
->
create_iadd
(
l_val
,
r_val
);
else
else
cur_value
=
builder
->
create_iadd
(
lvalue
,
rvalue
);
tmp_val
=
builder
->
create_fadd
(
l_val
,
r_val
);
break
;
break
;
}
case
OP_MINUS
:
case
OP_MINUS
:
{
if
(
is_int
)
if
(
float_type
)
tmp_val
=
builder
->
create_isub
(
l_val
,
r_val
);
cur_value
=
builder
->
create_fsub
(
lvalue
,
rvalue
);
else
else
cur_value
=
builder
->
create_isub
(
lvalue
,
rvalue
);
tmp_val
=
builder
->
create_fsub
(
l_val
,
r_val
);
break
;
break
;
}
}
}
}
else
}
node
.
term
->
accept
(
*
this
);
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTTerm
&
node
)
{
CminusfBuilder
::
visit
(
ASTTerm
&
node
)
{
//!TODO: This function is empty now.
if
(
node
.
term
==
nullptr
)
{
// Add some code here.
node
.
factor
->
accept
(
*
this
);
if
(
node
.
term
)
{
}
else
{
node
.
term
->
accept
(
*
this
);
node
.
term
->
accept
(
*
this
);
auto
l
value
=
cur_value
;
auto
l
_val
=
tmp_val
;
node
.
factor
->
accept
(
*
this
);
node
.
factor
->
accept
(
*
this
);
auto
rvalue
=
cur_value
;
auto
r_val
=
tmp_val
;
// check type
bool
is_int
=
promote
(
&*
builder
,
&
l_val
,
&
r_val
);
biop_type_check
(
lvalue
,
rvalue
,
"mul"
);
bool
float_type
=
lvalue
->
get_type
()
->
is_float_type
();
// now left and right is the same type
switch
(
node
.
op
)
{
switch
(
node
.
op
)
{
case
OP_MUL
:
{
case
OP_MUL
:
if
(
float_type
)
if
(
is_int
)
cur_value
=
builder
->
create_fmul
(
lvalue
,
rvalue
);
tmp_val
=
builder
->
create_imul
(
l_val
,
r_val
);
else
else
cur_value
=
builder
->
create_imul
(
lvalue
,
rvalue
);
tmp_val
=
builder
->
create_fmul
(
l_val
,
r_val
);
break
;
break
;
}
case
OP_DIV
:
case
OP_DIV
:
{
if
(
is_int
)
if
(
float_type
)
tmp_val
=
builder
->
create_isdiv
(
l_val
,
r_val
);
cur_value
=
builder
->
create_fdiv
(
lvalue
,
rvalue
);
else
else
cur_value
=
builder
->
create_isdiv
(
lvalue
,
rvalue
);
tmp_val
=
builder
->
create_fdiv
(
l_val
,
r_val
);
break
;
break
;
}
}
}
}
else
}
node
.
factor
->
accept
(
*
this
);
}
}
// Done
void
void
CminusfBuilder
::
visit
(
ASTCall
&
node
)
{
CminusfBuilder
::
visit
(
ASTCall
&
node
)
{
//!TODO: This function is empty now.
auto
fun
=
static_cast
<
Function
*>
(
scope
.
find
(
node
.
id
));
// Add some code here.
Function
*
func
=
static_cast
<
Function
*>
(
scope
.
find
(
node
.
id
));
std
::
vector
<
Value
*>
args
;
std
::
vector
<
Value
*>
args
;
if
(
func
==
nullptr
)
auto
param_type
=
fun
->
get_function_type
()
->
param_begin
();
error_exit
(
"function "
+
node
.
id
+
" not declared"
);
for
(
auto
&
arg
:
node
.
args
)
{
if
(
node
.
args
.
size
()
!=
func
->
get_num_of_args
())
arg
->
accept
(
*
this
);
error_exit
(
"expect "
+
std
::
to_string
(
func
->
get_num_of_args
())
+
" params, but "
+
if
(
!
tmp_val
->
get_type
()
->
is_pointer_type
()
&&
std
::
to_string
(
node
.
args
.
size
())
+
" is given"
);
*
param_type
!=
tmp_val
->
get_type
())
{
// check every argument
if
(
tmp_val
->
get_type
()
->
is_integer_type
())
for
(
int
i
=
0
;
i
!=
node
.
args
.
size
();
++
i
)
{
tmp_val
=
builder
->
create_sitofp
(
tmp_val
,
FLOAT_T
);
// ith parameter's type
else
Type
*
param_type
=
func
->
get_function_type
()
->
get_param_type
(
i
);
tmp_val
=
builder
->
create_fptosi
(
tmp_val
,
INT32_T
);
node
.
args
[
i
]
->
accept
(
*
this
);
// type cast
if
(
not
Type
::
is_eq_type
(
param_type
,
cur_value
->
get_type
()))
{
if
(
param_type
->
is_pointer_type
())
{
// shouldn't need type cast for pointer, logically
if
(
param_type
->
get_pointer_element_type
()
->
is_integer_type
()
or
param_type
->
get_pointer_element_type
()
->
is_float_type
())
error_exit
(
"BUG HERE: ASTVar return value is not int* or float*"
);
else
error_exit
(
"BUG HERE: function param needs weird pointer type"
);
}
else
if
(
param_type
->
is_integer_type
()
or
param_type
->
is_float_type
())
{
// need type cast between int and float
if
(
not
cur_value
->
get_type
()
->
is_integer_type
()
and
not
cur_value
->
get_type
()
->
is_float_type
())
error_exit
(
"unexpected type cast!"
);
if
(
param_type
->
is_float_type
())
cur_value
=
builder
->
create_sitofp
(
cur_value
,
FLOAT_T
);
else
if
(
param_type
->
is_integer_type
())
if
(
cur_value
->
get_type
()
->
is_integer_type
())
cur_value
=
builder
->
create_zext
(
cur_value
,
INT32_T
);
else
cur_value
=
builder
->
create_fptosi
(
cur_value
,
INT32_T
);
else
error_exit
(
"unexpected type cast!"
);
}
else
error_exit
(
"unexpected case when casting arguments for function call "
+
node
.
id
);
}
}
args
.
push_back
(
tmp_val
);
// now cur_value fits the param type
param_type
++
;
args
.
push_back
(
cur_value
);
}
}
cur_value
=
builder
->
create_call
(
func
,
args
);
tmp_val
=
builder
->
create_call
(
static_cast
<
Function
*>
(
fun
),
args
);
}
}
src/lightir/Instruction.cpp
View file @
f26d91aa
...
@@ -10,7 +10,10 @@
...
@@ -10,7 +10,10 @@
#include <cassert>
#include <cassert>
#include <vector>
#include <vector>
Instruction
::
Instruction
(
Type
*
ty
,
OpID
id
,
unsigned
num_ops
,
BasicBlock
*
parent
)
Instruction
::
Instruction
(
Type
*
ty
,
OpID
id
,
unsigned
num_ops
,
BasicBlock
*
parent
)
:
User
(
ty
,
""
,
num_ops
),
op_id_
(
id
),
num_ops_
(
num_ops
),
parent_
(
parent
)
{
:
User
(
ty
,
""
,
num_ops
),
op_id_
(
id
),
num_ops_
(
num_ops
),
parent_
(
parent
)
{
parent_
->
add_instruction
(
this
);
parent_
->
add_instruction
(
this
);
}
}
...
@@ -18,56 +21,75 @@ Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent
...
@@ -18,56 +21,75 @@ Instruction::Instruction(Type *ty, OpID id, unsigned num_ops, BasicBlock *parent
Instruction
::
Instruction
(
Type
*
ty
,
OpID
id
,
unsigned
num_ops
)
Instruction
::
Instruction
(
Type
*
ty
,
OpID
id
,
unsigned
num_ops
)
:
User
(
ty
,
""
,
num_ops
),
op_id_
(
id
),
num_ops_
(
num_ops
),
parent_
(
nullptr
)
{}
:
User
(
ty
,
""
,
num_ops
),
op_id_
(
id
),
num_ops_
(
num_ops
),
parent_
(
nullptr
)
{}
Function
*
Instruction
::
get_function
()
{
return
parent_
->
get_parent
();
}
Function
*
Instruction
::
get_function
()
{
return
parent_
->
get_parent
();
}
Module
*
Instruction
::
get_module
()
{
return
parent_
->
get_module
();
}
Module
*
Instruction
::
get_module
()
{
return
parent_
->
get_module
();
}
BinaryInst
::
BinaryInst
(
Type
*
ty
,
OpID
id
,
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
)
:
BaseInst
<
BinaryInst
>
(
ty
,
id
,
2
,
bb
)
{
BinaryInst
::
BinaryInst
(
Type
*
ty
,
OpID
id
,
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
)
:
BaseInst
<
BinaryInst
>
(
ty
,
id
,
2
,
bb
)
{
set_operand
(
0
,
v1
);
set_operand
(
0
,
v1
);
set_operand
(
1
,
v2
);
set_operand
(
1
,
v2
);
// assertValid();
// assertValid();
}
}
void
BinaryInst
::
assertValid
()
{
void
BinaryInst
::
assertValid
()
{
assert
(
get_operand
(
0
)
->
get_type
()
->
is_integer_type
());
assert
(
get_operand
(
0
)
->
get_type
()
->
is_integer_type
());
assert
(
get_operand
(
1
)
->
get_type
()
->
is_integer_type
());
assert
(
get_operand
(
1
)
->
get_type
()
->
is_integer_type
());
assert
(
static_cast
<
IntegerType
*>
(
get_operand
(
0
)
->
get_type
())
->
get_num_bits
()
==
assert
(
static_cast
<
IntegerType
*>
(
get_operand
(
1
)
->
get_type
())
->
get_num_bits
());
static_cast
<
IntegerType
*>
(
get_operand
(
0
)
->
get_type
())
->
get_num_bits
()
==
static_cast
<
IntegerType
*>
(
get_operand
(
1
)
->
get_type
())
->
get_num_bits
());
}
}
BinaryInst
*
BinaryInst
::
create_add
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_add
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
add
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
add
,
v1
,
v2
,
bb
);
}
}
BinaryInst
*
BinaryInst
::
create_sub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_sub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
sub
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
sub
,
v1
,
v2
,
bb
);
}
}
BinaryInst
*
BinaryInst
::
create_mul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_mul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
mul
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
mul
,
v1
,
v2
,
bb
);
}
}
BinaryInst
*
BinaryInst
::
create_sdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_sdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
sdiv
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_int32_type
(
m
),
Instruction
::
sdiv
,
v1
,
v2
,
bb
);
}
}
BinaryInst
*
BinaryInst
::
create_fadd
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_fadd
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fadd
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fadd
,
v1
,
v2
,
bb
);
}
}
BinaryInst
*
BinaryInst
::
create_fsub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_fsub
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fsub
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fsub
,
v1
,
v2
,
bb
);
}
}
BinaryInst
*
BinaryInst
::
create_fmul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_fmul
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fmul
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fmul
,
v1
,
v2
,
bb
);
}
}
BinaryInst
*
BinaryInst
::
create_fdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
BinaryInst
*
BinaryInst
::
create_fdiv
(
Value
*
v1
,
Value
*
v2
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fdiv
,
v1
,
v2
,
bb
);
return
create
(
Type
::
get_float_type
(
m
),
Instruction
::
fdiv
,
v1
,
v2
,
bb
);
}
}
std
::
string
BinaryInst
::
print
()
{
std
::
string
BinaryInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -78,7 +100,8 @@ std::string BinaryInst::print() {
...
@@ -78,7 +100,8 @@ std::string BinaryInst::print() {
instr_ir
+=
" "
;
instr_ir
+=
" "
;
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
", "
;
instr_ir
+=
", "
;
if
(
Type
::
is_eq_type
(
this
->
get_operand
(
0
)
->
get_type
(),
this
->
get_operand
(
1
)
->
get_type
()))
{
if
(
Type
::
is_eq_type
(
this
->
get_operand
(
0
)
->
get_type
(),
this
->
get_operand
(
1
)
->
get_type
()))
{
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
false
);
}
else
{
}
else
{
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
true
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
true
);
...
@@ -93,18 +116,27 @@ CmpInst::CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
...
@@ -93,18 +116,27 @@ CmpInst::CmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
// assertValid();
// assertValid();
}
}
void
CmpInst
::
assertValid
()
{
void
CmpInst
::
assertValid
()
{
assert
(
get_operand
(
0
)
->
get_type
()
->
is_integer_type
());
assert
(
get_operand
(
0
)
->
get_type
()
->
is_integer_type
());
assert
(
get_operand
(
1
)
->
get_type
()
->
is_integer_type
());
assert
(
get_operand
(
1
)
->
get_type
()
->
is_integer_type
());
assert
(
static_cast
<
IntegerType
*>
(
get_operand
(
0
)
->
get_type
())
->
get_num_bits
()
==
assert
(
static_cast
<
IntegerType
*>
(
get_operand
(
1
)
->
get_type
())
->
get_num_bits
());
static_cast
<
IntegerType
*>
(
get_operand
(
0
)
->
get_type
())
}
->
get_num_bits
()
==
static_cast
<
IntegerType
*>
(
get_operand
(
1
)
->
get_type
())
->
get_num_bits
());
CmpInst
*
CmpInst
::
create_cmp
(
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
,
Module
*
m
)
{
}
CmpInst
*
CmpInst
::
create_cmp
(
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
m
->
get_int1_type
(),
op
,
lhs
,
rhs
,
bb
);
return
create
(
m
->
get_int1_type
(),
op
,
lhs
,
rhs
,
bb
);
}
}
std
::
string
CmpInst
::
print
()
{
std
::
string
CmpInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -117,7 +149,8 @@ std::string CmpInst::print() {
...
@@ -117,7 +149,8 @@ std::string CmpInst::print() {
instr_ir
+=
" "
;
instr_ir
+=
" "
;
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
", "
;
instr_ir
+=
", "
;
if
(
Type
::
is_eq_type
(
this
->
get_operand
(
0
)
->
get_type
(),
this
->
get_operand
(
1
)
->
get_type
()))
{
if
(
Type
::
is_eq_type
(
this
->
get_operand
(
0
)
->
get_type
(),
this
->
get_operand
(
1
)
->
get_type
()))
{
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
false
);
}
else
{
}
else
{
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
true
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
true
);
...
@@ -132,16 +165,23 @@ FCmpInst::FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
...
@@ -132,16 +165,23 @@ FCmpInst::FCmpInst(Type *ty, CmpOp op, Value *lhs, Value *rhs, BasicBlock *bb)
// assertValid();
// assertValid();
}
}
void
FCmpInst
::
assert_valid
()
{
void
FCmpInst
::
assert_valid
()
{
assert
(
get_operand
(
0
)
->
get_type
()
->
is_float_type
());
assert
(
get_operand
(
0
)
->
get_type
()
->
is_float_type
());
assert
(
get_operand
(
1
)
->
get_type
()
->
is_float_type
());
assert
(
get_operand
(
1
)
->
get_type
()
->
is_float_type
());
}
}
FCmpInst
*
FCmpInst
::
create_fcmp
(
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
,
Module
*
m
)
{
FCmpInst
*
FCmpInst
::
create_fcmp
(
CmpOp
op
,
Value
*
lhs
,
Value
*
rhs
,
BasicBlock
*
bb
,
Module
*
m
)
{
return
create
(
m
->
get_int1_type
(),
op
,
lhs
,
rhs
,
bb
);
return
create
(
m
->
get_int1_type
(),
op
,
lhs
,
rhs
,
bb
);
}
}
std
::
string
FCmpInst
::
print
()
{
std
::
string
FCmpInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -154,7 +194,8 @@ std::string FCmpInst::print() {
...
@@ -154,7 +194,8 @@ std::string FCmpInst::print() {
instr_ir
+=
" "
;
instr_ir
+=
" "
;
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
","
;
instr_ir
+=
","
;
if
(
Type
::
is_eq_type
(
this
->
get_operand
(
0
)
->
get_type
(),
this
->
get_operand
(
1
)
->
get_type
()))
{
if
(
Type
::
is_eq_type
(
this
->
get_operand
(
0
)
->
get_type
(),
this
->
get_operand
(
1
)
->
get_type
()))
{
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
false
);
}
else
{
}
else
{
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
true
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
1
),
true
);
...
@@ -163,7 +204,10 @@ std::string FCmpInst::print() {
...
@@ -163,7 +204,10 @@ std::string FCmpInst::print() {
}
}
CallInst
::
CallInst
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
BasicBlock
*
bb
)
CallInst
::
CallInst
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
BasicBlock
*
bb
)
:
BaseInst
<
CallInst
>
(
func
->
get_return_type
(),
Instruction
::
call
,
args
.
size
()
+
1
,
bb
)
{
:
BaseInst
<
CallInst
>
(
func
->
get_return_type
(),
Instruction
::
call
,
args
.
size
()
+
1
,
bb
)
{
assert
(
func
->
get_num_of_args
()
==
args
.
size
());
assert
(
func
->
get_num_of_args
()
==
args
.
size
());
int
num_ops
=
args
.
size
()
+
1
;
int
num_ops
=
args
.
size
()
+
1
;
set_operand
(
0
,
func
);
set_operand
(
0
,
func
);
...
@@ -172,13 +216,18 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
...
@@ -172,13 +216,18 @@ CallInst::CallInst(Function *func, std::vector<Value *> args, BasicBlock *bb)
}
}
}
}
CallInst
*
CallInst
::
create
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
BasicBlock
*
bb
)
{
CallInst
*
CallInst
::
create
(
Function
*
func
,
std
::
vector
<
Value
*>
args
,
BasicBlock
*
bb
)
{
return
BaseInst
<
CallInst
>::
create
(
func
,
args
,
bb
);
return
BaseInst
<
CallInst
>::
create
(
func
,
args
,
bb
);
}
}
FunctionType
*
CallInst
::
get_function_type
()
const
{
return
static_cast
<
FunctionType
*>
(
get_operand
(
0
)
->
get_type
());
}
FunctionType
*
CallInst
::
get_function_type
()
const
{
return
static_cast
<
FunctionType
*>
(
get_operand
(
0
)
->
get_type
());
}
std
::
string
CallInst
::
print
()
{
std
::
string
CallInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
if
(
!
this
->
is_void
())
{
if
(
!
this
->
is_void
())
{
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
...
@@ -190,7 +239,8 @@ std::string CallInst::print() {
...
@@ -190,7 +239,8 @@ std::string CallInst::print() {
instr_ir
+=
this
->
get_function_type
()
->
get_return_type
()
->
print
();
instr_ir
+=
this
->
get_function_type
()
->
get_return_type
()
->
print
();
instr_ir
+=
" "
;
instr_ir
+=
" "
;
assert
(
dynamic_cast
<
Function
*>
(
this
->
get_operand
(
0
))
&&
"Wrong call operand function"
);
assert
(
dynamic_cast
<
Function
*>
(
this
->
get_operand
(
0
))
&&
"Wrong call operand function"
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
false
);
instr_ir
+=
"("
;
instr_ir
+=
"("
;
for
(
int
i
=
1
;
i
<
this
->
get_num_operand
();
i
++
)
{
for
(
int
i
=
1
;
i
<
this
->
get_num_operand
();
i
++
)
{
...
@@ -204,19 +254,32 @@ std::string CallInst::print() {
...
@@ -204,19 +254,32 @@ std::string CallInst::print() {
return
instr_ir
;
return
instr_ir
;
}
}
BranchInst
::
BranchInst
(
Value
*
cond
,
BasicBlock
*
if_true
,
BasicBlock
*
if_false
,
BasicBlock
*
bb
)
BranchInst
::
BranchInst
(
Value
*
cond
,
:
BaseInst
<
BranchInst
>
(
Type
::
get_void_type
(
if_true
->
get_module
()),
Instruction
::
br
,
3
,
bb
)
{
BasicBlock
*
if_true
,
BasicBlock
*
if_false
,
BasicBlock
*
bb
)
:
BaseInst
<
BranchInst
>
(
Type
::
get_void_type
(
if_true
->
get_module
()),
Instruction
::
br
,
3
,
bb
)
{
set_operand
(
0
,
cond
);
set_operand
(
0
,
cond
);
set_operand
(
1
,
if_true
);
set_operand
(
1
,
if_true
);
set_operand
(
2
,
if_false
);
set_operand
(
2
,
if_false
);
}
}
BranchInst
::
BranchInst
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
)
BranchInst
::
BranchInst
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
)
:
BaseInst
<
BranchInst
>
(
Type
::
get_void_type
(
if_true
->
get_module
()),
Instruction
::
br
,
1
,
bb
)
{
:
BaseInst
<
BranchInst
>
(
Type
::
get_void_type
(
if_true
->
get_module
()),
Instruction
::
br
,
1
,
bb
)
{
set_operand
(
0
,
if_true
);
set_operand
(
0
,
if_true
);
}
}
BranchInst
*
BranchInst
::
create_cond_br
(
Value
*
cond
,
BasicBlock
*
if_true
,
BasicBlock
*
if_false
,
BasicBlock
*
bb
)
{
BranchInst
*
BranchInst
::
create_cond_br
(
Value
*
cond
,
BasicBlock
*
if_true
,
BasicBlock
*
if_false
,
BasicBlock
*
bb
)
{
if_true
->
add_pre_basic_block
(
bb
);
if_true
->
add_pre_basic_block
(
bb
);
if_false
->
add_pre_basic_block
(
bb
);
if_false
->
add_pre_basic_block
(
bb
);
bb
->
add_succ_basic_block
(
if_false
);
bb
->
add_succ_basic_block
(
if_false
);
...
@@ -225,16 +288,21 @@ BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBl
...
@@ -225,16 +288,21 @@ BranchInst *BranchInst::create_cond_br(Value *cond, BasicBlock *if_true, BasicBl
return
create
(
cond
,
if_true
,
if_false
,
bb
);
return
create
(
cond
,
if_true
,
if_false
,
bb
);
}
}
BranchInst
*
BranchInst
::
create_br
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
)
{
BranchInst
*
BranchInst
::
create_br
(
BasicBlock
*
if_true
,
BasicBlock
*
bb
)
{
if_true
->
add_pre_basic_block
(
bb
);
if_true
->
add_pre_basic_block
(
bb
);
bb
->
add_succ_basic_block
(
if_true
);
bb
->
add_succ_basic_block
(
if_true
);
return
create
(
if_true
,
bb
);
return
create
(
if_true
,
bb
);
}
}
bool
BranchInst
::
is_cond_br
()
const
{
return
get_num_operand
()
==
3
;
}
bool
BranchInst
::
is_cond_br
()
const
{
return
get_num_operand
()
==
3
;
}
std
::
string
BranchInst
::
print
()
{
std
::
string
BranchInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
" "
;
instr_ir
+=
" "
;
...
@@ -250,20 +318,36 @@ std::string BranchInst::print() {
...
@@ -250,20 +318,36 @@ std::string BranchInst::print() {
}
}
ReturnInst
::
ReturnInst
(
Value
*
val
,
BasicBlock
*
bb
)
ReturnInst
::
ReturnInst
(
Value
*
val
,
BasicBlock
*
bb
)
:
BaseInst
<
ReturnInst
>
(
Type
::
get_void_type
(
bb
->
get_module
()),
Instruction
::
ret
,
1
,
bb
)
{
:
BaseInst
<
ReturnInst
>
(
Type
::
get_void_type
(
bb
->
get_module
()),
Instruction
::
ret
,
1
,
bb
)
{
set_operand
(
0
,
val
);
set_operand
(
0
,
val
);
}
}
ReturnInst
::
ReturnInst
(
BasicBlock
*
bb
)
ReturnInst
::
ReturnInst
(
BasicBlock
*
bb
)
:
BaseInst
<
ReturnInst
>
(
Type
::
get_void_type
(
bb
->
get_module
()),
Instruction
::
ret
,
0
,
bb
)
{}
:
BaseInst
<
ReturnInst
>
(
Type
::
get_void_type
(
bb
->
get_module
()),
Instruction
::
ret
,
0
,
bb
)
{}
ReturnInst
*
ReturnInst
::
create_ret
(
Value
*
val
,
BasicBlock
*
bb
)
{
return
create
(
val
,
bb
);
}
ReturnInst
*
ReturnInst
::
create_ret
(
Value
*
val
,
BasicBlock
*
bb
)
{
return
create
(
val
,
bb
);
}
ReturnInst
*
ReturnInst
::
create_void_ret
(
BasicBlock
*
bb
)
{
return
create
(
bb
);
}
ReturnInst
*
ReturnInst
::
create_void_ret
(
BasicBlock
*
bb
)
{
return
create
(
bb
);
}
bool
ReturnInst
::
is_void_ret
()
const
{
return
get_num_operand
()
==
0
;
}
bool
ReturnInst
::
is_void_ret
()
const
{
return
get_num_operand
()
==
0
;
}
std
::
string
ReturnInst
::
print
()
{
std
::
string
ReturnInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
" "
;
instr_ir
+=
" "
;
...
@@ -278,7 +362,9 @@ std::string ReturnInst::print() {
...
@@ -278,7 +362,9 @@ std::string ReturnInst::print() {
return
instr_ir
;
return
instr_ir
;
}
}
GetElementPtrInst
::
GetElementPtrInst
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
,
BasicBlock
*
bb
)
GetElementPtrInst
::
GetElementPtrInst
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
,
BasicBlock
*
bb
)
:
BaseInst
<
GetElementPtrInst
>
(
PointerType
::
get
(
get_element_type
(
ptr
,
idxs
)),
:
BaseInst
<
GetElementPtrInst
>
(
PointerType
::
get
(
get_element_type
(
ptr
,
idxs
)),
Instruction
::
getelementptr
,
Instruction
::
getelementptr
,
1
+
idxs
.
size
(),
1
+
idxs
.
size
(),
...
@@ -290,10 +376,12 @@ GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, Basi
...
@@ -290,10 +376,12 @@ GetElementPtrInst::GetElementPtrInst(Value *ptr, std::vector<Value *> idxs, Basi
element_ty_
=
get_element_type
(
ptr
,
idxs
);
element_ty_
=
get_element_type
(
ptr
,
idxs
);
}
}
Type
*
GetElementPtrInst
::
get_element_type
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
)
{
Type
*
GetElementPtrInst
::
get_element_type
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
)
{
Type
*
ty
=
ptr
->
get_type
()
->
get_pointer_element_type
();
Type
*
ty
=
ptr
->
get_type
()
->
get_pointer_element_type
();
assert
(
"GetElementPtrInst ptr is wrong type"
&&
assert
(
(
ty
->
is_array_type
()
||
ty
->
is_integer_type
()
||
ty
->
is_float_type
()));
"GetElementPtrInst ptr is wrong type"
&&
(
ty
->
is_array_type
()
||
ty
->
is_integer_type
()
||
ty
->
is_float_type
()));
if
(
ty
->
is_array_type
())
{
if
(
ty
->
is_array_type
())
{
ArrayType
*
arr_ty
=
static_cast
<
ArrayType
*>
(
ty
);
ArrayType
*
arr_ty
=
static_cast
<
ArrayType
*>
(
ty
);
for
(
int
i
=
1
;
i
<
idxs
.
size
();
i
++
)
{
for
(
int
i
=
1
;
i
<
idxs
.
size
();
i
++
)
{
...
@@ -309,13 +397,20 @@ Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs)
...
@@ -309,13 +397,20 @@ Type *GetElementPtrInst::get_element_type(Value *ptr, std::vector<Value *> idxs)
return
ty
;
return
ty
;
}
}
Type
*
GetElementPtrInst
::
get_element_type
()
const
{
return
element_ty_
;
}
Type
*
GetElementPtrInst
::
get_element_type
()
const
{
return
element_ty_
;
}
GetElementPtrInst
*
GetElementPtrInst
::
create_gep
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
,
BasicBlock
*
bb
)
{
GetElementPtrInst
*
GetElementPtrInst
::
create_gep
(
Value
*
ptr
,
std
::
vector
<
Value
*>
idxs
,
BasicBlock
*
bb
)
{
return
create
(
ptr
,
idxs
,
bb
);
return
create
(
ptr
,
idxs
,
bb
);
}
}
std
::
string
GetElementPtrInst
::
print
()
{
std
::
string
GetElementPtrInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -323,7 +418,8 @@ std::string GetElementPtrInst::print() {
...
@@ -323,7 +418,8 @@ std::string GetElementPtrInst::print() {
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
" "
;
instr_ir
+=
" "
;
assert
(
this
->
get_operand
(
0
)
->
get_type
()
->
is_pointer_type
());
assert
(
this
->
get_operand
(
0
)
->
get_type
()
->
is_pointer_type
());
instr_ir
+=
this
->
get_operand
(
0
)
->
get_type
()
->
get_pointer_element_type
()
->
print
();
instr_ir
+=
this
->
get_operand
(
0
)
->
get_type
()
->
get_pointer_element_type
()
->
print
();
instr_ir
+=
", "
;
instr_ir
+=
", "
;
for
(
int
i
=
0
;
i
<
this
->
get_num_operand
();
i
++
)
{
for
(
int
i
=
0
;
i
<
this
->
get_num_operand
();
i
++
)
{
if
(
i
>
0
)
if
(
i
>
0
)
...
@@ -336,14 +432,21 @@ std::string GetElementPtrInst::print() {
...
@@ -336,14 +432,21 @@ std::string GetElementPtrInst::print() {
}
}
StoreInst
::
StoreInst
(
Value
*
val
,
Value
*
ptr
,
BasicBlock
*
bb
)
StoreInst
::
StoreInst
(
Value
*
val
,
Value
*
ptr
,
BasicBlock
*
bb
)
:
BaseInst
<
StoreInst
>
(
Type
::
get_void_type
(
bb
->
get_module
()),
Instruction
::
store
,
2
,
bb
)
{
:
BaseInst
<
StoreInst
>
(
Type
::
get_void_type
(
bb
->
get_module
()),
Instruction
::
store
,
2
,
bb
)
{
set_operand
(
0
,
val
);
set_operand
(
0
,
val
);
set_operand
(
1
,
ptr
);
set_operand
(
1
,
ptr
);
}
}
StoreInst
*
StoreInst
::
create_store
(
Value
*
val
,
Value
*
ptr
,
BasicBlock
*
bb
)
{
return
create
(
val
,
ptr
,
bb
);
}
StoreInst
*
StoreInst
::
create_store
(
Value
*
val
,
Value
*
ptr
,
BasicBlock
*
bb
)
{
return
create
(
val
,
ptr
,
bb
);
}
std
::
string
StoreInst
::
print
()
{
std
::
string
StoreInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
" "
;
instr_ir
+=
" "
;
...
@@ -355,19 +458,27 @@ std::string StoreInst::print() {
...
@@ -355,19 +458,27 @@ std::string StoreInst::print() {
return
instr_ir
;
return
instr_ir
;
}
}
LoadInst
::
LoadInst
(
Type
*
ty
,
Value
*
ptr
,
BasicBlock
*
bb
)
:
BaseInst
<
LoadInst
>
(
ty
,
Instruction
::
load
,
1
,
bb
)
{
LoadInst
::
LoadInst
(
Type
*
ty
,
Value
*
ptr
,
BasicBlock
*
bb
)
:
BaseInst
<
LoadInst
>
(
ty
,
Instruction
::
load
,
1
,
bb
)
{
assert
(
ptr
->
get_type
()
->
is_pointer_type
());
assert
(
ptr
->
get_type
()
->
is_pointer_type
());
assert
(
ty
==
static_cast
<
PointerType
*>
(
ptr
->
get_type
())
->
get_element_type
());
assert
(
ty
==
static_cast
<
PointerType
*>
(
ptr
->
get_type
())
->
get_element_type
());
set_operand
(
0
,
ptr
);
set_operand
(
0
,
ptr
);
}
}
LoadInst
*
LoadInst
::
create_load
(
Type
*
ty
,
Value
*
ptr
,
BasicBlock
*
bb
)
{
return
create
(
ty
,
ptr
,
bb
);
}
LoadInst
*
LoadInst
::
create_load
(
Type
*
ty
,
Value
*
ptr
,
BasicBlock
*
bb
)
{
return
create
(
ty
,
ptr
,
bb
);
}
Type
*
LoadInst
::
get_load_type
()
const
{
Type
*
return
static_cast
<
PointerType
*>
(
get_operand
(
0
)
->
get_type
())
->
get_element_type
();
LoadInst
::
get_load_type
()
const
{
return
static_cast
<
PointerType
*>
(
get_operand
(
0
)
->
get_type
())
->
get_element_type
();
}
}
std
::
string
LoadInst
::
print
()
{
std
::
string
LoadInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -375,7 +486,8 @@ std::string LoadInst::print() {
...
@@ -375,7 +486,8 @@ std::string LoadInst::print() {
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
this
->
get_module
()
->
get_instr_op_name
(
this
->
get_instr_type
());
instr_ir
+=
" "
;
instr_ir
+=
" "
;
assert
(
this
->
get_operand
(
0
)
->
get_type
()
->
is_pointer_type
());
assert
(
this
->
get_operand
(
0
)
->
get_type
()
->
is_pointer_type
());
instr_ir
+=
this
->
get_operand
(
0
)
->
get_type
()
->
get_pointer_element_type
()
->
print
();
instr_ir
+=
this
->
get_operand
(
0
)
->
get_type
()
->
get_pointer_element_type
()
->
print
();
instr_ir
+=
","
;
instr_ir
+=
","
;
instr_ir
+=
" "
;
instr_ir
+=
" "
;
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
true
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
0
),
true
);
...
@@ -383,13 +495,21 @@ std::string LoadInst::print() {
...
@@ -383,13 +495,21 @@ std::string LoadInst::print() {
}
}
AllocaInst
::
AllocaInst
(
Type
*
ty
,
BasicBlock
*
bb
)
AllocaInst
::
AllocaInst
(
Type
*
ty
,
BasicBlock
*
bb
)
:
BaseInst
<
AllocaInst
>
(
PointerType
::
get
(
ty
),
Instruction
::
alloca
,
0
,
bb
),
alloca_ty_
(
ty
)
{}
:
BaseInst
<
AllocaInst
>
(
PointerType
::
get
(
ty
),
Instruction
::
alloca
,
0
,
bb
)
,
alloca_ty_
(
ty
)
{}
AllocaInst
*
AllocaInst
::
create_alloca
(
Type
*
ty
,
BasicBlock
*
bb
)
{
return
create
(
ty
,
bb
);
}
AllocaInst
*
AllocaInst
::
create_alloca
(
Type
*
ty
,
BasicBlock
*
bb
)
{
return
create
(
ty
,
bb
);
}
Type
*
AllocaInst
::
get_alloca_type
()
const
{
return
alloca_ty_
;
}
Type
*
AllocaInst
::
get_alloca_type
()
const
{
return
alloca_ty_
;
}
std
::
string
AllocaInst
::
print
()
{
std
::
string
AllocaInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -400,15 +520,23 @@ std::string AllocaInst::print() {
...
@@ -400,15 +520,23 @@ std::string AllocaInst::print() {
return
instr_ir
;
return
instr_ir
;
}
}
ZextInst
::
ZextInst
(
OpID
op
,
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
:
BaseInst
<
ZextInst
>
(
ty
,
op
,
1
,
bb
),
dest_ty_
(
ty
)
{
ZextInst
::
ZextInst
(
OpID
op
,
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
:
BaseInst
<
ZextInst
>
(
ty
,
op
,
1
,
bb
),
dest_ty_
(
ty
)
{
set_operand
(
0
,
val
);
set_operand
(
0
,
val
);
}
}
ZextInst
*
ZextInst
::
create_zext
(
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
{
return
create
(
Instruction
::
zext
,
val
,
ty
,
bb
);
}
ZextInst
*
ZextInst
::
create_zext
(
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
{
return
create
(
Instruction
::
zext
,
val
,
ty
,
bb
);
}
Type
*
ZextInst
::
get_dest_type
()
const
{
return
dest_ty_
;
}
Type
*
ZextInst
::
get_dest_type
()
const
{
return
dest_ty_
;
}
std
::
string
ZextInst
::
print
()
{
std
::
string
ZextInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -428,13 +556,18 @@ FpToSiInst::FpToSiInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
...
@@ -428,13 +556,18 @@ FpToSiInst::FpToSiInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
set_operand
(
0
,
val
);
set_operand
(
0
,
val
);
}
}
FpToSiInst
*
FpToSiInst
::
create_fptosi
(
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
{
FpToSiInst
*
FpToSiInst
::
create_fptosi
(
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
{
return
create
(
Instruction
::
fptosi
,
val
,
ty
,
bb
);
return
create
(
Instruction
::
fptosi
,
val
,
ty
,
bb
);
}
}
Type
*
FpToSiInst
::
get_dest_type
()
const
{
return
dest_ty_
;
}
Type
*
FpToSiInst
::
get_dest_type
()
const
{
return
dest_ty_
;
}
std
::
string
FpToSiInst
::
print
()
{
std
::
string
FpToSiInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -454,13 +587,18 @@ SiToFpInst::SiToFpInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
...
@@ -454,13 +587,18 @@ SiToFpInst::SiToFpInst(OpID op, Value *val, Type *ty, BasicBlock *bb)
set_operand
(
0
,
val
);
set_operand
(
0
,
val
);
}
}
SiToFpInst
*
SiToFpInst
::
create_sitofp
(
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
{
SiToFpInst
*
SiToFpInst
::
create_sitofp
(
Value
*
val
,
Type
*
ty
,
BasicBlock
*
bb
)
{
return
create
(
Instruction
::
sitofp
,
val
,
ty
,
bb
);
return
create
(
Instruction
::
sitofp
,
val
,
ty
,
bb
);
}
}
Type
*
SiToFpInst
::
get_dest_type
()
const
{
return
dest_ty_
;
}
Type
*
SiToFpInst
::
get_dest_type
()
const
{
return
dest_ty_
;
}
std
::
string
SiToFpInst
::
print
()
{
std
::
string
SiToFpInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -475,7 +613,11 @@ std::string SiToFpInst::print() {
...
@@ -475,7 +613,11 @@ std::string SiToFpInst::print() {
return
instr_ir
;
return
instr_ir
;
}
}
PhiInst
::
PhiInst
(
OpID
op
,
std
::
vector
<
Value
*>
vals
,
std
::
vector
<
BasicBlock
*>
val_bbs
,
Type
*
ty
,
BasicBlock
*
bb
)
PhiInst
::
PhiInst
(
OpID
op
,
std
::
vector
<
Value
*>
vals
,
std
::
vector
<
BasicBlock
*>
val_bbs
,
Type
*
ty
,
BasicBlock
*
bb
)
:
BaseInst
<
PhiInst
>
(
ty
,
op
,
2
*
vals
.
size
())
{
:
BaseInst
<
PhiInst
>
(
ty
,
op
,
2
*
vals
.
size
())
{
for
(
int
i
=
0
;
i
<
vals
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
vals
.
size
();
i
++
)
{
set_operand
(
2
*
i
,
vals
[
i
]);
set_operand
(
2
*
i
,
vals
[
i
]);
...
@@ -484,13 +626,15 @@ PhiInst::PhiInst(OpID op, std::vector<Value *> vals, std::vector<BasicBlock *> v
...
@@ -484,13 +626,15 @@ PhiInst::PhiInst(OpID op, std::vector<Value *> vals, std::vector<BasicBlock *> v
this
->
set_parent
(
bb
);
this
->
set_parent
(
bb
);
}
}
PhiInst
*
PhiInst
::
create_phi
(
Type
*
ty
,
BasicBlock
*
bb
)
{
PhiInst
*
PhiInst
::
create_phi
(
Type
*
ty
,
BasicBlock
*
bb
)
{
std
::
vector
<
Value
*>
vals
;
std
::
vector
<
Value
*>
vals
;
std
::
vector
<
BasicBlock
*>
val_bbs
;
std
::
vector
<
BasicBlock
*>
val_bbs
;
return
create
(
Instruction
::
phi
,
vals
,
val_bbs
,
ty
,
bb
);
return
create
(
Instruction
::
phi
,
vals
,
val_bbs
,
ty
,
bb
);
}
}
std
::
string
PhiInst
::
print
()
{
std
::
string
PhiInst
::
print
()
{
std
::
string
instr_ir
;
std
::
string
instr_ir
;
instr_ir
+=
"%"
;
instr_ir
+=
"%"
;
instr_ir
+=
this
->
get_name
();
instr_ir
+=
this
->
get_name
();
...
@@ -508,9 +652,12 @@ std::string PhiInst::print() {
...
@@ -508,9 +652,12 @@ std::string PhiInst::print() {
instr_ir
+=
print_as_op
(
this
->
get_operand
(
2
*
i
+
1
),
false
);
instr_ir
+=
print_as_op
(
this
->
get_operand
(
2
*
i
+
1
),
false
);
instr_ir
+=
" ]"
;
instr_ir
+=
" ]"
;
}
}
if
(
this
->
get_num_operand
()
/
2
<
this
->
get_parent
()
->
get_pre_basic_blocks
().
size
())
{
if
(
this
->
get_num_operand
()
/
2
<
this
->
get_parent
()
->
get_pre_basic_blocks
().
size
())
{
for
(
auto
pre_bb
:
this
->
get_parent
()
->
get_pre_basic_blocks
())
{
for
(
auto
pre_bb
:
this
->
get_parent
()
->
get_pre_basic_blocks
())
{
if
(
std
::
find
(
this
->
get_operands
().
begin
(),
this
->
get_operands
().
end
(),
static_cast
<
Value
*>
(
pre_bb
))
==
if
(
std
::
find
(
this
->
get_operands
().
begin
(),
this
->
get_operands
().
end
(),
static_cast
<
Value
*>
(
pre_bb
))
==
this
->
get_operands
().
end
())
{
this
->
get_operands
().
end
())
{
// find a pre_bb is not in phi
// find a pre_bb is not in phi
instr_ir
+=
", [ undef, "
+
print_as_op
(
pre_bb
,
false
)
+
" ]"
;
instr_ir
+=
", [ undef, "
+
print_as_op
(
pre_bb
,
false
)
+
" ]"
;
...
...
src/optimization/GVN.cpp
View file @
f26d91aa
...
@@ -209,6 +209,17 @@ print_partitions(const GVN::partitions &p) {
...
@@ -209,6 +209,17 @@ print_partitions(const GVN::partitions &p) {
}
}
}
// namespace utils
}
// namespace utils
GVN
::
partitions
GVN
::
join_helper
(
BasicBlock
*
pre1
,
BasicBlock
*
pre2
)
{
assert
(
not
_TOP
[
pre1
]
or
not
_TOP
[
pre2
]
&&
"should flow here, not jump"
);
if
(
_TOP
[
pre1
])
return
pout_
[
pre2
];
else
if
(
_TOP
[
pre2
])
return
pout_
[
pre1
];
return
join
(
pout_
[
pre1
],
pout_
[
pre2
]);
}
GVN
::
partitions
GVN
::
partitions
GVN
::
join
(
const
partitions
&
P1
,
const
partitions
&
P2
)
{
GVN
::
join
(
const
partitions
&
P1
,
const
partitions
&
P2
)
{
// TODO: do intersection pair-wise
// TODO: do intersection pair-wise
...
@@ -228,7 +239,6 @@ std::shared_ptr<CongruenceClass>
...
@@ -228,7 +239,6 @@ std::shared_ptr<CongruenceClass>
GVN
::
intersect
(
std
::
shared_ptr
<
CongruenceClass
>
ci
,
GVN
::
intersect
(
std
::
shared_ptr
<
CongruenceClass
>
ci
,
std
::
shared_ptr
<
CongruenceClass
>
cj
)
{
std
::
shared_ptr
<
CongruenceClass
>
cj
)
{
// TODO
// TODO
// If no common members, return null
auto
c
=
createCongruenceClass
();
auto
c
=
createCongruenceClass
();
std
::
set
<
Value
*>
intersection
;
std
::
set
<
Value
*>
intersection
;
...
@@ -240,10 +250,21 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
...
@@ -240,10 +250,21 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
c
->
members_
=
intersection
;
c
->
members_
=
intersection
;
if
(
ci
->
index_
==
cj
->
index_
)
if
(
ci
->
index_
==
cj
->
index_
)
c
->
index_
=
ci
->
index_
;
c
->
index_
=
ci
->
index_
;
if
(
ci
->
leader_
==
cj
->
leader_
)
if
(
ci
->
value_expr_
==
cj
->
value_expr_
)
c
->
leader_
=
cj
->
leader_
;
c
->
value_expr_
=
ci
->
value_expr_
;
/* if (*ci == *cj)
if
(
ci
->
value_phi_
and
cj
->
value_phi_
and
* return ci; */
*
ci
->
value_phi_
==
*
cj
->
value_phi_
)
c
->
value_phi_
=
ci
->
value_phi_
;
// if (c->members_.size() or c->value_expr_ or c->value_phi_) // not empty
// ??
// What if the ve is nullptr?
if
(
c
->
members_
.
size
())
// not empty
if
(
c
->
index_
==
0
)
{
c
->
index_
=
new_number
();
c
->
value_phi_
=
PhiExpression
::
create
(
ci
->
value_expr_
,
cj
->
value_expr_
);
}
return
c
;
return
c
;
}
}
...
@@ -251,30 +272,31 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
...
@@ -251,30 +272,31 @@ GVN::intersect(std::shared_ptr<CongruenceClass> ci,
void
void
GVN
::
detectEquivalences
()
{
GVN
::
detectEquivalences
()
{
bool
changed
;
bool
changed
;
std
::
cout
<<
"all the instruction address:"
<<
std
::
endl
;
for
(
auto
&
bb
:
func_
->
get_basic_blocks
())
{
for
(
auto
&
instr
:
bb
.
get_instructions
())
std
::
cout
<<
&
instr
<<
"
\t
"
<<
instr
.
print
()
<<
std
::
endl
;
}
// initialize pout with top
// initialize pout with top
for
(
auto
&
bb
:
func_
->
get_basic_blocks
())
{
for
(
auto
&
bb
:
func_
->
get_basic_blocks
())
{
// pin_[&bb].clear();
_TOP
[
&
bb
]
=
true
;
// pout_[&bb].clear();
for
(
auto
&
instr
:
bb
.
get_instructions
())
_TOP
[
&
instr
]
=
true
;
}
}
// modify entry block
auto
Entry
=
func_
->
get_entry_block
();
auto
Entry
=
func_
->
get_entry_block
();
_TOP
[
&*
Entry
->
get_instructions
().
begin
()]
=
false
;
_TOP
[
Entry
]
=
false
;
pin_
[
Entry
].
clear
();
pin_
[
Entry
].
clear
();
pout_
[
Entry
]
.
clear
();
// pout_[Entry]
= transferFunction(Entry);
pout_
[
Entry
]
=
transferFunction
(
Entry
);
// iterate until converge
// iterate until converge
do
{
do
{
changed
=
false
;
changed
=
false
;
// see the pseudo code in documentation
for
(
auto
&
_bb
:
func_
->
get_basic_blocks
())
{
for
(
auto
&
_bb
:
func_
->
get_basic_blocks
())
{
// you might need to visit the
// blocks in depth-first order
auto
bb
=
&
_bb
;
auto
bb
=
&
_bb
;
// get PIN of bb by predecessor(s)
if
(
bb
==
Entry
)
continue
;
// get PIN of bb from predecessor(s)
auto
pre_bbs_
=
bb
->
get_pre_basic_blocks
();
auto
pre_bbs_
=
bb
->
get_pre_basic_blocks
();
if
(
bb
!=
Entry
)
{
if
(
bb
!=
Entry
)
{
// only update PIN for blocks that are not Entry
// only update PIN for blocks that are not Entry
...
@@ -283,12 +305,12 @@ GVN::detectEquivalences() {
...
@@ -283,12 +305,12 @@ GVN::detectEquivalences() {
case
2
:
{
case
2
:
{
auto
pre_1
=
*
pre_bbs_
.
begin
();
auto
pre_1
=
*
pre_bbs_
.
begin
();
auto
pre_2
=
*
(
++
pre_bbs_
.
begin
());
auto
pre_2
=
*
(
++
pre_bbs_
.
begin
());
pin_
[
bb
]
=
join
(
pin_
[
pre_1
],
pin_
[
pre_2
]
);
pin_
[
bb
]
=
join
_helper
(
pre_1
,
pre_2
);
break
;
break
;
}
}
case
1
:
{
case
1
:
{
auto
pre
=
*
(
pre_bbs_
.
begin
());
auto
pre
=
*
(
pre_bbs_
.
begin
());
pin_
[
bb
]
=
clone
(
pin_
[
pre
])
;
pin_
[
bb
]
=
pout_
[
pre
]
;
break
;
break
;
}
}
default:
default:
...
@@ -297,82 +319,246 @@ GVN::detectEquivalences() {
...
@@ -297,82 +319,246 @@ GVN::detectEquivalences() {
abort
();
abort
();
}
}
}
}
auto
part
=
pin_
[
bb
];
auto
part
=
transferFunction
(
bb
);
// iterate through all instructions in the block
for
(
auto
&
instr
:
bb
->
get_instructions
())
{
// ??
if
(
not
instr
.
is_phi
())
part
=
transferFunction
(
&
instr
,
instr
.
get_operand
(
1
),
part
);
}
// and the phi instruction in all the successors
for
(
auto
succ
:
bb
->
get_succ_basic_blocks
())
{
for
(
auto
&
instr
:
succ
->
get_instructions
())
if
(
instr
.
is_phi
())
{
Instruction
*
pretend
;
// ??
part
=
transferFunction
(
pretend
,
instr
.
get_operand
(
1
),
part
);
}
}
// check changes in pout
// check changes in pout
changed
|=
not
(
part
==
pout_
[
bb
]);
changed
|=
not
(
part
==
pout_
[
bb
]);
pout_
[
bb
]
=
part
;
pout_
[
bb
]
=
part
;
_TOP
[
bb
]
=
false
;
}
}
}
while
(
changed
);
}
while
(
changed
);
}
}
shared_ptr
<
Expression
>
shared_ptr
<
Expression
>
GVN
::
valueExpr
(
Instruction
*
instr
)
{
GVN
::
valueExpr
(
Instruction
*
instr
,
partitions
*
part
)
{
// TODO
// TODO
return
{};
// ?? should use part?
std
::
string
err
{
"Undefined"
};
std
::
cout
<<
instr
->
print
()
<<
std
::
endl
;
if
(
instr
->
isBinary
()
or
instr
->
is_cmp
()
or
instr
->
is_fcmp
())
{
auto
op1
=
instr
->
get_operand
(
0
);
auto
op2
=
instr
->
get_operand
(
1
);
auto
op1_const
=
dynamic_cast
<
Constant
*>
(
op1
);
auto
op2_const
=
dynamic_cast
<
Constant
*>
(
op2
);
if
(
op1_const
and
op2_const
)
{
// both are constant number, so:
// constant fold!
return
ConstantExpression
::
create
(
folder_
->
compute
(
instr
,
op1_const
,
op2_const
));
}
else
{
// both none constant
auto
op1_instr
=
dynamic_cast
<
Instruction
*>
(
op1
);
auto
op2_instr
=
dynamic_cast
<
Instruction
*>
(
op2
);
assert
((
op1_instr
or
op1_const
)
and
(
op2_instr
or
op2_const
)
&&
"must be this case"
);
return
BinaryExpression
::
create
(
instr
->
get_instr_type
(),
(
op1_const
?
ConstantExpression
::
create
(
op1_const
)
:
valueExpr
(
op1_instr
)),
(
op2_const
?
ConstantExpression
::
create
(
op2_const
)
:
valueExpr
(
op2_instr
)));
}
}
else
if
(
instr
->
is_phi
())
{
err
=
"phi"
;
}
else
if
(
instr
->
is_fp2si
()
or
instr
->
is_si2fp
()
or
instr
->
is_zext
())
{
auto
op
=
instr
->
get_operand
(
0
);
auto
op_const
=
dynamic_cast
<
Constant
*>
(
op
);
auto
op_instr
=
dynamic_cast
<
Instruction
*>
(
op
);
assert
(
op_instr
or
op_const
);
// get dest type
auto
instr_fp2si
=
dynamic_cast
<
FpToSiInst
*>
(
instr
);
auto
instr_si2fp
=
dynamic_cast
<
SiToFpInst
*>
(
instr
);
auto
instr_zext
=
dynamic_cast
<
ZextInst
*>
(
instr
);
Type
*
dest_type
=
nullptr
;
if
(
instr_fp2si
)
dest_type
=
instr_fp2si
->
get_dest_type
();
else
if
(
instr_si2fp
)
dest_type
=
instr_si2fp
->
get_dest_type
();
else
if
(
instr_zext
)
dest_type
=
instr_zext
->
get_dest_type
();
else
err
=
"cast"
;
if
(
dest_type
)
{
if
(
op_const
)
return
ConstantExpression
::
create
(
folder_
->
compute
(
instr
,
op_const
));
else
return
CastExpression
::
create
(
instr
->
get_instr_type
(),
valueExpr
(
op_instr
),
dest_type
);
}
}
else
if
(
instr
->
is_gep
())
{
auto
operands
=
instr
->
get_operands
();
auto
ptr
=
operands
[
0
];
std
::
vector
<
std
::
shared_ptr
<
Expression
>>
idxs
;
// check for base address
assert
(
not
dynamic_cast
<
Constant
*>
(
ptr
)
and
dynamic_cast
<
Instruction
*>
(
ptr
)
&&
"base address should only be from instruction"
);
// set idxes
for
(
int
i
=
1
;
i
<
operands
.
size
();
i
++
)
{
if
(
dynamic_cast
<
Constant
*>
(
operands
[
i
]))
idxs
.
push_back
(
ConstantExpression
::
create
(
dynamic_cast
<
Constant
*>
(
operands
[
i
])));
else
{
assert
(
dynamic_cast
<
Instruction
*>
(
operands
[
i
]));
idxs
.
push_back
(
valueExpr
(
dynamic_cast
<
Instruction
*>
(
operands
[
i
])));
}
}
return
GEPExpression
::
create
(
valueExpr
(
dynamic_cast
<
Instruction
*>
(
ptr
)),
idxs
);
}
else
if
(
instr
->
is_load
()
or
instr
->
is_alloca
()
or
instr
->
is_call
())
{
return
UniqueExpression
::
create
(
instr
);
}
std
::
cerr
<<
"Undefined case: "
<<
err
<<
std
::
endl
;
abort
();
}
}
// instruction of the form `x = e`, mostly x is just e (SSA),
// instruction of the form `x = e`, mostly x is just e (SSA),
// but for copy stmt x is a phi instruction in the successor.
// but for copy stmt x is a phi instruction in the successor.
// Phi values (not copy stmt) should be handled in detectEquiv
// Phi values (not copy stmt) should be handled in detectEquiv
//
// assert the x is an instruction that can generate a new value
//
/// \param bb basic block in which the transfer function is called
/// \param bb basic block in which the transfer function is called
GVN
::
partitions
GVN
::
partitions
GVN
::
transferFunction
(
Instruction
*
x
,
Value
*
e
,
partitions
pin
)
{
GVN
::
transferFunction
(
Instruction
*
x
,
Value
*
e
,
partitions
pin
)
{
partitions
pout
=
clone
(
pin
);
partitions
pout
=
pin
;
// TODO: deal with copy-stmt case
// ?? deal with copy statement
auto
e_instr
=
dynamic_cast
<
Instruction
*>
(
e
);
auto
e_const
=
dynamic_cast
<
Constant
*>
(
e
);
assert
((
not
e
or
e_instr
or
e_const
)
&&
"A value must be from an instruction or constant"
);
// erase the old record for x
std
::
set
<
Value
*>::
iterator
it
;
for
(
auto
c
:
pin
)
if
((
it
=
std
::
find
(
c
->
members_
.
begin
(),
c
->
members_
.
end
(),
x
))
!=
c
->
members_
.
end
())
{
c
->
members_
.
erase
(
it
);
}
// TODO: get different ValueExpr by Instruction::OpID, modify pout
// TODO: get different ValueExpr by Instruction::OpID, modify pout
std
::
set
<
Value
*>::
iterator
iter
;
// ??
// get ve and vpf
shared_ptr
<
Expression
>
ve
;
if
(
e
)
{
if
(
e_const
)
ve
=
ConstantExpression
::
create
(
e_const
);
else
ve
=
valueExpr
(
e_instr
,
&
pin
);
}
else
ve
=
valueExpr
(
x
,
&
pin
);
auto
vpf
=
valuePhiFunc
(
ve
,
curr_bb
);
for
(
auto
c
:
pout
)
{
for
(
auto
c
:
pout
)
{
if
((
iter
=
std
::
find
(
c
->
members_
.
begin
(),
c
->
members_
.
end
(),
x
))
!=
if
(
ve
==
c
->
value_expr_
or
(
vpf
and
vpf
==
c
->
value_phi_
))
{
c
->
members_
.
end
())
{
c
->
value_expr_
=
ve
;
// static_cast<Value *>(x))) != c->members_.end()) {
c
->
members_
.
insert
(
x
);
c
->
members_
.
erase
(
iter
);
}
else
{
auto
c
=
createCongruenceClass
(
new_number
());
c
->
members_
.
insert
(
x
);
c
->
value_expr_
=
ve
;
c
->
value_phi_
=
vpf
;
pout
.
insert
(
c
);
}
}
}
}
auto
ve
=
valueExpr
(
x
);
auto
vpf
=
valuePhiFunc
(
ve
,
pin
);
/* pout.insert({});
/* // first version: ignore ve and vpf
* auto c = CongruenceClass(new_number());
* // and only update index, leader and members
* c.leader_ = e; */
* auto c = createCongruenceClass(new_number());
* c->leader_ = x;
* c->members_.insert(x);
* pout.insert(c); */
return
pout
;
return
pout
;
}
}
/*
* read the pin for the block and then execute transferFunction() for all
* instructions inside.
*/
GVN
::
partitions
GVN
::
partitions
GVN
::
transferFunction
(
BasicBlock
*
bb
)
{
GVN
::
transferFunction
(
BasicBlock
*
bb
)
{
partitions
pout
=
clone
(
pin_
[
bb
]);
curr_bb
=
bb
;
// ??
int
res
;
return
pout
;
auto
part
=
pin_
[
bb
];
/* LOG_INFO << "transferFunction(bb=" << bb->get_name() << ")\n";
* LOG_INFO << "pin:\n";
* utils::print_partitions(pin_[bb]);
* LOG_INFO << "pout before:\n";
* utils::print_partitions(pout_[bb]); */
// iterate through all instructions in the block
for
(
auto
&
instr
:
bb
->
get_instructions
())
{
// ?? what about orther instructions? Are they all ok?
if
(
not
instr
.
is_phi
()
and
not
instr
.
is_void
())
part
=
transferFunction
(
&
instr
,
nullptr
,
part
);
}
// and the phi instruction in all the successors
for
(
auto
succ
:
bb
->
get_succ_basic_blocks
())
{
for
(
auto
&
instr
:
succ
->
get_instructions
())
{
if
(
instr
.
is_phi
())
{
if
((
res
=
pretend_copy_stmt
(
&
instr
,
bb
)
==
-
1
))
continue
;
part
=
transferFunction
(
&
instr
,
instr
.
get_operand
(
res
),
part
);
}
}
}
/* LOG_INFO << "pout after:\n";
* utils::print_partitions(part);
* std::cout << std::endl; */
return
part
;
}
}
shared_ptr
<
PhiExpression
>
shared_ptr
<
PhiExpression
>
GVN
::
valuePhiFunc
(
shared_ptr
<
Expression
>
ve
,
const
partitions
&
P
)
{
GVN
::
valuePhiFunc
(
shared_ptr
<
Expression
>
ve
,
BasicBlock
*
bb
)
{
// TODO
// TODO
return
{};
if
(
ve
->
get_expr_type
()
!=
Expression
::
e_bin
)
return
nullptr
;
auto
ve_bin
=
static_cast
<
BinaryExpression
*>
(
ve
.
get
());
if
(
ve_bin
->
lhs_
->
get_expr_type
()
!=
Expression
::
e_phi
or
ve_bin
->
rhs_
->
get_expr_type
()
!=
Expression
::
e_phi
)
return
nullptr
;
// get 2 phi expressions
auto
lhs
=
static_cast
<
PhiExpression
*>
(
ve_bin
->
lhs_
.
get
());
auto
rhs
=
static_cast
<
PhiExpression
*>
(
ve_bin
->
rhs_
.
get
());
// get 2 predecessors
auto
pre_bbs_
=
bb
->
get_pre_basic_blocks
();
auto
pre_1
=
*
pre_bbs_
.
begin
();
auto
pre_2
=
*
(
++
pre_bbs_
.
begin
());
// try to get the merged value expression
auto
vl_merge
=
BinaryExpression
::
create
(
ve_bin
->
op_
,
lhs
->
lhs_
,
rhs
->
lhs_
);
auto
vr_merge
=
BinaryExpression
::
create
(
ve_bin
->
op_
,
lhs
->
rhs_
,
rhs
->
rhs_
);
auto
vi
=
getVN
(
pout_
[
pre_1
],
vl_merge
);
auto
vj
=
getVN
(
pout_
[
pre_2
],
vr_merge
);
if
(
vi
==
nullptr
)
vi
=
valuePhiFunc
(
vl_merge
,
pre_1
);
if
(
vj
==
nullptr
)
vj
=
valuePhiFunc
(
vr_merge
,
pre_2
);
if
(
vi
and
vj
)
return
PhiExpression
::
create
(
vi
,
vj
);
else
return
nullptr
;
}
}
shared_ptr
<
Expression
>
shared_ptr
<
Expression
>
GVN
::
getVN
(
const
partitions
&
pout
,
shared_ptr
<
Expression
>
ve
)
{
GVN
::
getVN
(
const
partitions
&
pout
,
shared_ptr
<
Expression
>
ve
)
{
// TODO: return what?
// TODO: return what?
/* for (auto c : pout) {
* if (c->value_expr_ == ve)
* return ve;
* } */
for
(
auto
it
=
pout
.
begin
();
it
!=
pout
.
end
();
it
++
)
for
(
auto
it
=
pout
.
begin
();
it
!=
pout
.
end
();
it
++
)
if
((
*
it
)
->
value_expr_
and
*
(
*
it
)
->
value_expr_
==
*
ve
)
if
((
*
it
)
->
value_expr_
and
*
(
*
it
)
->
value_expr_
==
*
ve
)
return
{}
;
return
ve
;
return
nullptr
;
return
nullptr
;
}
}
...
@@ -490,6 +676,12 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
...
@@ -490,6 +676,12 @@ GVNExpression::operator==(const Expression &lhs, const Expression &rhs) {
return
equiv_as
<
BinaryExpression
>
(
lhs
,
rhs
);
return
equiv_as
<
BinaryExpression
>
(
lhs
,
rhs
);
case
Expression
::
e_phi
:
case
Expression
::
e_phi
:
return
equiv_as
<
PhiExpression
>
(
lhs
,
rhs
);
return
equiv_as
<
PhiExpression
>
(
lhs
,
rhs
);
case
Expression
::
e_cast
:
return
equiv_as
<
CastExpression
>
(
lhs
,
rhs
);
case
Expression
::
e_gep
:
return
equiv_as
<
GEPExpression
>
(
lhs
,
rhs
);
case
Expression
::
e_unique
:
return
equiv_as
<
UniqueExpression
>
(
lhs
,
rhs
);
}
}
}
}
...
@@ -517,12 +709,35 @@ operator==(const GVN::partitions &p1, const GVN::partitions &p2) {
...
@@ -517,12 +709,35 @@ operator==(const GVN::partitions &p1, const GVN::partitions &p2) {
// cannot direct compare???
// cannot direct compare???
if
(
p1
.
size
()
!=
p2
.
size
())
if
(
p1
.
size
()
!=
p2
.
size
())
return
false
;
return
false
;
return
std
::
equal
(
p1
.
begin
(),
p1
.
end
(),
p2
.
begin
(),
p2
.
end
());
auto
it1
=
p1
.
begin
();
auto
it2
=
p2
.
begin
();
for
(;
it1
!=
p1
.
end
();
++
it1
,
++
it2
)
if
(
not
(
**
it1
==
**
it2
))
return
false
;
return
true
;
}
}
// only compare
index
// only compare
members
bool
bool
CongruenceClass
::
operator
==
(
const
CongruenceClass
&
other
)
const
{
CongruenceClass
::
operator
==
(
const
CongruenceClass
&
other
)
const
{
// TODO: which fields need to be compared?
// TODO: which fields need to be compared?
return
index_
==
other
.
index_
;
if
(
members_
.
size
()
!=
other
.
members_
.
size
())
return
false
;
return
members_
==
other
.
members_
;
}
int
GVN
::
pretend_copy_stmt
(
Instruction
*
instr
,
BasicBlock
*
bb
)
{
auto
phi
=
static_cast
<
PhiInst
*>
(
instr
);
// res = phi [op1, name1], [op2, name2]
// ^0 ^1 ^2 ^3
if
(
phi
->
get_operand
(
1
)
->
get_name
()
==
bb
->
get_name
())
{
// pretend copy statement:
// `res = op1`
return
0
;
}
else
if
(
phi
->
get_operand
(
3
)
->
get_name
()
==
bb
->
get_name
())
{
// `res = op2`
return
2
;
}
return
-
1
;
}
}
tests/4-ir-opt/testcases/GVN/functional/.gitignore
0 → 100644
View file @
f26d91aa
*.ll
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment