Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
05313741
A
akg
项目概览
MindSpore
/
akg
通知
59
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
05313741
编写于
8月 06, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
!102 Add Loop Mutator & Relu to pass three address
Merge pull request !102 from ConnZhai/conn
上级
bf3a6c3a
05a96ace
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
477 addition
and
309 deletion
+477
-309
src/pass/to_three_address.cc
src/pass/to_three_address.cc
+446
-267
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
+31
-42
未找到文件。
src/pass/to_three_address.cc
浏览文件 @
05313741
...
...
@@ -35,22 +35,71 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>;
// forward declaration
class
ThreeAddressExprMutator
;
class
ThreeAddressFilt
er
:
public
IRVisitor
{
class
ExprArgsFetch
er
:
public
IRVisitor
{
public:
bool
Find
(
const
Stmt
&
s
)
{
Visit
(
s
);
return
need_
;
explicit
ExprArgsFetcher
(
Array
<
Expr
>
args
)
:
args_
(
args
),
index_
(
args_
.
size
()
-
1
)
{}
~
ExprArgsFetcher
()
override
=
default
;
Array
<
Expr
>
GetArgs
(
const
Expr
&
e
)
{
Visit
(
e
);
if
(
max_dim
>=
args_
.
size
())
{
return
args_
;
}
Array
<
Expr
>
args
;
while
(
index_
<
args_
.
size
())
{
args
.
push_back
(
args_
[
index_
]);
index_
++
;
}
if
(
CountVars
(
args
)
==
CountVars
(
args_
))
{
return
args_
;
}
return
args
;
}
bool
MustBroadcast
(
const
Expr
&
e
)
{
if
(
is_constant
(
e
)
||
CountVars
(
e
)
==
0
)
{
return
false
;
}
size_t
size
=
GetArgs
(
e
).
size
();
return
size
>
max_dim
;
}
void
Visit_
(
const
Call
*
op
)
override
{
if
(
op
->
name
==
"load3d_l1_ub"
)
{
need_
=
false
;
if
(
op
->
call_type
==
Call
::
CallType
::
Halide
)
{
max_dim
=
max_dim
<
op
->
args
.
size
()
?
op
->
args
.
size
()
:
max_dim
;
for
(
Expr
arg
:
op
->
args
)
{
size_t
index
=
GetIndex
(
arg
);
index_
=
index_
>
index
?
index
:
index_
;
}
}
else
{
CHECK
(
op
->
call_type
==
Call
::
CallType
::
PureIntrinsic
);
for
(
Expr
e
:
op
->
args
)
{
Array
<
Expr
>
args
=
GetArgs
(
e
);
max_dim
=
max_dim
<
args
.
size
()
?
args
.
size
()
:
max_dim
;
for
(
Expr
arg
:
args
)
{
size_t
index
=
GetIndex
(
arg
);
index_
=
index_
>
index
?
index
:
index_
;
}
}
}
IRVisitor
::
Visit_
(
op
);
}
private:
bool
need_
{
true
};
size_t
GetIndex
(
const
Expr
&
arg
)
{
if
(
is_constant
(
arg
))
{
return
index_
;
}
for
(
size_t
i
=
0
;
i
<
args_
.
size
();
++
i
)
{
if
(
args_
[
i
].
same_as
(
arg
))
{
return
i
;
}
}
return
index_
;
}
Array
<
Expr
>
args_
;
size_t
index_
;
size_t
max_dim
{
0
};
};
class
ScalarOperandFinder
:
public
IRVisitor
{
...
...
@@ -188,50 +237,21 @@ std::unordered_set<Tensor> GetExprTensors(const Expr expr) {
return
tensors
;
}
// Replace all instances of a Tensor "from" in an Expr with a new one "to"
class
ReplaceProvideTensors
:
public
IRMutator
{
public:
ReplaceProvideTensors
(
const
Tensor
&
from
,
const
Operation
&
to
)
:
from_
(
from
->
op
),
to_
(
to
)
{}
~
ReplaceProvideTensors
()
override
=
default
;
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
)
final
{
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Provide
>
();
CHECK
(
op
);
if
(
op
->
func
==
from_
)
{
stmt
=
Provide
::
make
(
to_
,
op
->
value_index
,
op
->
value
,
op
->
args
);
}
return
stmt
;
}
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
override
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
const
Call
*
n
=
expr
.
as
<
Call
>
();
CHECK
(
n
);
if
(
n
->
func
==
from_
)
{
expr
=
Call
::
make
(
n
->
type
,
to_
->
name
,
n
->
args
,
n
->
call_type
,
to_
,
n
->
value_index
);
}
return
expr
;
}
private:
const
Operation
from_
;
const
Operation
to_
;
};
// Mutate expression according to selection choices
class
ThreeAddressExprMutator
:
public
IRMutator
{
public:
ThreeAddressExprMutator
(
const
Tensor
output
,
const
Array
<
Expr
>
&
args
,
const
Array
<
Expr
>
&
shape
,
const
std
::
unordered_set
<
const
Call
*>
&
broadcast
,
bool
IsReductionOp
,
bool
cross_stmt_simplify
)
ThreeAddressExprMutator
(
const
Tensor
output
,
const
Array
<
Expr
>
&
args
,
const
Array
<
Expr
>
&
out_args
,
const
Array
<
Expr
>
&
shape
,
const
std
::
unordered_set
<
const
Call
*>
&
broadcast
,
bool
IsReductionOp
,
bool
cross_stmt_simplify
,
bool
is_simple
=
false
)
:
output_
(
output
),
args_
(
args
),
out_args_
(
out_args
),
shape_
(
shape
),
broadcast_
(
broadcast
),
IsReductionOp_
(
IsReductionOp
),
cross_simplify_
(
cross_stmt_simplify
),
hasher_
(
cross_stmt_simplify
)
{
hasher_
(
cross_stmt_simplify
),
is_simple_
(
is_simple
)
{
CHECK_EQ
(
args_
.
size
(),
shape_
.
size
());
if
(
shape_
.
empty
())
{
// scalar values should have at least one dimension and contains one element
shape_
.
push_back
(
1
);
...
...
@@ -246,7 +266,7 @@ class ThreeAddressExprMutator : public IRMutator {
common_exprs_
.
insert
(
global_common_expr
.
begin
(),
global_common_expr
.
end
());
}
Expr
AllocateTmp
(
Expr
value
)
{
Expr
AllocateTmp
(
Expr
value
,
Array
<
Expr
>
args
=
{}
)
{
// detect common expression
size_t
hash_value
=
hasher_
(
value
);
auto
x
=
common_exprs_
[
hash_value
];
...
...
@@ -263,13 +283,17 @@ class ThreeAddressExprMutator : public IRMutator {
// allocate new immediate tensor
Tensor
imm
;
imm
=
PlaceholderOpNode
::
make
(
output_
->
op
->
name
+
"_"
+
std
::
to_string
(
ct_
++
),
shape_
,
value
.
type
()).
output
(
0
);
if
(
args
.
empty
())
{
args
=
args_
;
}
std
::
string
name
=
output_
->
op
->
name
+
"_"
+
std
::
to_string
(
ct_
++
);
imm
=
PlaceholderOpNode
::
make
(
name
,
GetShape
(
args
),
value
.
type
()).
output
(
0
);
imm_tensors
.
push_back
(
imm
);
imm_ops
.
insert
(
imm
->
op
);
// update common expr
assign_stmt
.
push_back
(
Provide
::
make
(
imm
->
op
,
imm
->
value_index
,
value
,
args
_
));
Expr
ret
=
Call
::
make
(
value
.
type
(),
imm
->
op
->
name
,
args
_
,
Call
::
CallType
::
Halide
,
imm
->
op
,
imm
->
value_index
);
assign_stmt
.
push_back
(
Provide
::
make
(
imm
->
op
,
imm
->
value_index
,
value
,
args
));
Expr
ret
=
Call
::
make
(
value
.
type
(),
imm
->
op
->
name
,
args
,
Call
::
CallType
::
Halide
,
imm
->
op
,
imm
->
value_index
);
common_exprs_
[
hash_value
]
=
std
::
make_pair
(
value
,
ret
);
imm2hash_
[
imm
->
op
]
=
hash_value
;
return
ret
;
...
...
@@ -283,9 +307,13 @@ class ThreeAddressExprMutator : public IRMutator {
common_exprs_
.
erase
(
old_hash
);
// update new common expr
assign_stmt
.
push_back
(
Provide
::
make
(
imm
->
op
,
imm
->
value_index
,
value
,
args_
));
Array
<
Expr
>
args
=
args_
;
if
(
is_simple_
)
{
args
=
ExprArgsFetcher
(
args_
).
GetArgs
(
value
);
}
assign_stmt
.
push_back
(
Provide
::
make
(
imm
->
op
,
imm
->
value_index
,
value
,
args
));
size_t
hash_value
=
hasher_
(
value
);
Expr
ret
=
Call
::
make
(
value
.
type
(),
imm
->
op
->
name
,
args
_
,
Call
::
CallType
::
Halide
,
imm
->
op
,
imm
->
value_index
);
Expr
ret
=
Call
::
make
(
value
.
type
(),
imm
->
op
->
name
,
args
,
Call
::
CallType
::
Halide
,
imm
->
op
,
imm
->
value_index
);
common_exprs_
[
hash_value
]
=
std
::
make_pair
(
value
,
ret
);
imm2hash_
[
imm
->
op
]
=
hash_value
;
return
ret
;
...
...
@@ -324,20 +352,25 @@ class ThreeAddressExprMutator : public IRMutator {
Expr
r
=
Mutate
(
op
->
b
);
in_call_
--
;
bool
broadcast_l
=
!
IsReductionOp_
&&
!
is_constant
(
l
)
&&
CountVars
(
args_
)
>
CountVars
(
l
);
bool
broadcast_r
=
!
IsReductionOp_
&&
!
is_constant
(
r
)
&&
CountVars
(
args_
)
>
CountVars
(
r
);
Array
<
Expr
>
args
=
args_
;
if
(
is_simple_
)
{
args
=
ExprArgsFetcher
(
args_
).
GetArgs
(
T
::
make
(
l
,
r
));
}
bool
broadcast_l
=
!
IsReductionOp_
&&
!
is_constant
(
l
)
&&
CountVars
(
args
)
>
CountVars
(
l
);
bool
broadcast_r
=
!
IsReductionOp_
&&
!
is_constant
(
r
)
&&
CountVars
(
args
)
>
CountVars
(
r
);
if
(
op
->
template
IsInstance
<
Add
>()
||
op
->
template
IsInstance
<
Mul
>())
{
if
(
broadcast_l
&&
broadcast_r
)
{
l
=
AllocateTmp
(
l
);
}
else
if
(
is_constant
(
r
)
&&
broadcast_l
)
{
l
=
AllocateTmp
(
l
);
}
else
if
(
is_constant
(
l
)
&&
broadcast_r
)
{
r
=
AllocateTmp
(
r
);
if
(
broadcast_l
&&
(
broadcast_r
||
is_constant
(
r
)))
{
l
=
AllocateTmp
(
l
,
args
);
}
else
if
((
broadcast_r
&&
is_constant
(
l
)))
{
r
=
AllocateTmp
(
r
,
args
);
}
if
(
CountVars
(
args
)
>
CountVars
(
r
)
&&
ExprArgsFetcher
(
out_args_
).
MustBroadcast
(
r
))
{
r
=
AllocateTmp
(
r
,
args
);
}
}
return
AllocateTmp
(
T
::
make
(
Mutate
(
l
),
Mutate
(
r
)));
return
AllocateTmp
(
T
::
make
(
Mutate
(
l
),
Mutate
(
r
))
,
args
);
}
Expr
Mutate_
(
const
Add
*
op
,
const
Expr
&
e
)
final
{
return
MutateBinaryOp
<
Add
>
(
op
,
e
);
}
...
...
@@ -443,8 +476,15 @@ class ThreeAddressExprMutator : public IRMutator {
// broadcast when need
if
(
broadcast_
.
count
(
op
)
&&
broadcast
)
{
if
(
expr_stack
.
size
()
>=
2
&&
expr_stack
[
expr_stack
.
size
()
-
2
]
->
IsInstance
<
Div
>
())
{
Array
<
Expr
>
args
=
ExprArgsFetcher
(
args_
).
GetArgs
(
expr_stack
[
expr_stack
.
size
()
-
2
]);
if
(
CountVars
(
e
)
<
CountVars
(
args
))
{
return
AllocateTmp
(
e
,
args
);
}
}
else
{
return
AllocateTmp
(
e
);
}
}
// this is going to generate a tensor of tensor expr, like A(B(i))
return
e
;
}
else
if
(
op
->
call_type
==
Call
::
CallType
::
PureIntrinsic
&&
op
->
name
==
air
::
ir
::
intrinsic
::
tvm_if_then_else
)
{
...
...
@@ -546,8 +586,25 @@ class ThreeAddressExprMutator : public IRMutator {
}
}
Array
<
Expr
>
GetShape
(
const
Array
<
Expr
>
&
args
)
{
if
(
CountVars
(
args
)
==
CountVars
(
args_
))
{
return
shape_
;
}
const
size_t
dim
=
args
.
size
();
const
size_t
maxDim
=
output_
->
shape
.
size
();
CHECK_LE
(
dim
,
maxDim
);
Array
<
Expr
>
shape
;
size_t
index
=
maxDim
-
dim
;
while
(
index
<
maxDim
)
{
shape
.
push_back
(
output_
->
shape
[
index
]);
index
++
;
}
return
shape
;
}
Tensor
output_
;
Array
<
Expr
>
args_
;
Array
<
Expr
>
out_args_
;
Array
<
Expr
>
shape_
;
std
::
unordered_map
<
size_t
,
std
::
pair
<
Expr
,
Expr
>>
common_exprs_
;
// hash value -> <match expr, replace expr>
...
...
@@ -566,6 +623,7 @@ class ThreeAddressExprMutator : public IRMutator {
bool
IsReductionOp_
{
false
};
bool
cross_simplify_
;
ExprHasher
hasher_
;
bool
is_simple_
;
};
Expr
ThreeAddressExprMutator
::
Mutate
(
Expr
expr
)
{
...
...
@@ -579,66 +637,78 @@ Expr ThreeAddressExprMutator::Mutate(Expr expr) {
int
ThreeAddressExprMutator
::
ct_
=
0
;
class
Instruction
Selec
tor
{
class
Instruction
Mutator
:
IRMuta
tor
{
public:
InstructionSelector
(
ThreeAddressExprMutator
&
mutator
,
std
::
list
<
Expr
>
&
exprs
,
std
::
unordered_map
<
const
Object
*
,
std
::
string
>
&
notation_map
,
std
::
unordered_map
<
const
Object
*
,
bool
>
&
sign_map
)
:
mutator_
(
mutator
),
exprs_
(
exprs
),
notation_map_
(
notation_map
),
sign_map_
(
sign_map
)
{}
~
InstructionSelector
()
=
default
;
explicit
InstructionMutator
(
ThreeAddressExprMutator
&
mutator
,
Array
<
Expr
>
&
args
)
:
mutator_
(
mutator
),
args_
(
args
)
{}
~
InstructionMutator
()
=
default
;
Expr
Mutate
(
Expr
expr
)
{
if
(
const
Mul
*
op
=
expr
.
as
<
Mul
>
())
{
return
Mutate_
(
op
,
expr
);
Expr
Mutate
(
Expr
value
)
{
return
IRMutator
::
Mutate
(
value
);
}
// VMADD.type {f16, f32} [Xd], [Xn], [Xm], Xt, MASK
// [Xd] = [Xn] * [Xd] + [Xm]
// VAXPY.type {f16, f32, fmix} [Xd], [Xn], Xm, Xt, MASK
// [Xd] = Xm * [Xn] + [Xd]
Expr
Mutate_
(
const
Add
*
op
,
const
Expr
&
e
)
{
Expr
l
=
Mutate
(
op
->
a
);
Expr
r
=
Mutate
(
op
->
b
);
if
(
is_constant
(
l
)
&&
is_constant
(
r
))
{
return
ConstantFold
<
Add
>
(
l
,
r
);
}
if
(
const
Cast
*
op
=
expr
.
as
<
Cast
>
())
{
return
Mutate_
(
op
,
expr
);
return
Add
::
make
(
l
,
r
);
}
if
(
const
Select
*
op
=
expr
.
as
<
Select
>
())
{
return
Mutate_
(
op
,
expr
);
Expr
Mutate_
(
const
Sub
*
op
,
const
Expr
&
e
)
{
Expr
l
=
Mutate
(
op
->
a
);
Expr
r
=
Mutate
(
op
->
b
);
if
(
is_constant
(
l
)
&&
is_constant
(
r
))
{
return
ConstantFold
<
Sub
>
(
l
,
r
);
}
return
expr
;
return
Sub
::
make
(
l
,
r
)
;
}
// vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vaxpy [Xd] = Xm * [Xn] + [Xd]
Expr
Mutate_
(
const
Mul
*
op
,
const
Expr
&
e
)
{
std
::
string
root
=
notation_map_
.
at
(
e
.
get
());
if
(
root
!=
Add
::
_type_key
&&
root
!=
Sub
::
_type_key
)
{
return
e
;
Expr
l
=
Mutate
(
op
->
a
);
Expr
r
=
Mutate
(
op
->
b
);
bool
is_left_constant
=
is_constant
(
l
);
bool
is_right_constant
=
is_constant
(
r
);
if
(
!
is_left_constant
&&
!
is_right_constant
)
{
return
Mul
::
make
(
l
,
r
);
}
bool
is_left_constant
=
is_constant
(
op
->
a
);
bool
is_right_constant
=
is_constant
(
op
->
b
);
if
(
is_left_constant
&&
is_right_constant
)
{
return
e
;
return
ConstantFold
<
Mul
>
(
l
,
r
)
;
}
Expr
expr
=
GetIndexOfPairExprForMul
(
e
);
if
(
expr
.
same_as
(
e
))
{
return
e
;
Expr
constant
=
is_left_constant
?
l
:
r
;
Expr
nonconstant
=
is_left_constant
?
r
:
l
;
if
(
const
Add
*
add
=
nonconstant
.
as
<
Add
>
())
{
return
MulExprMutator
<
Add
>
(
constant
,
add
);
}
else
if
(
const
Sub
*
sub
=
nonconstant
.
as
<
Sub
>
())
{
return
MulExprMutator
<
Sub
>
(
constant
,
sub
);
}
Array
<
Expr
>
args
;
if
(
!
is_left_constant
)
{
args
.
push_back
(
op
->
a
);
}
else
{
args
.
push_back
(
op
->
b
);
return
Mul
::
make
(
l
,
r
);
}
args
.
push_back
(
expr
);
if
(
!
is_right_constant
)
{
args
.
push_back
(
op
->
b
);
}
else
{
args
.
push_back
(
op
->
a
);
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
e
)
{
Expr
l
=
Mutate
(
op
->
a
);
Expr
r
=
Mutate
(
op
->
b
);
if
(
is_constant
(
l
)
&&
is_constant
(
r
))
{
return
ConstantFold
<
Div
>
(
l
,
r
);
}
else
if
(
is_constant
(
l
))
{
l
=
mutator_
.
AllocateTmp
(
l
,
ExprArgsFetcher
(
args_
).
GetArgs
(
Div
::
make
(
l
,
r
)));
}
return
Call
::
make
(
op
->
type
,
!
is_left_constant
&&
!
is_right_constant
?
"vmadd"
:
"vaxpy"
,
args
,
Call
::
CallType
::
PureIntrinsic
);
return
Div
::
make
(
l
,
r
);
}
// vrelu [Xd] = max([Xn], 0)
// vmaddrelu [Xd] = max(vmadd [Xd], 0)
Expr
Mutate_
(
const
Max
*
op
,
const
Expr
&
e
)
{
// relu only support fp16
if
(
!
op
->
type
.
is_float
()
||
op
->
type
.
bits
()
!=
16
)
{
return
Max
::
make
(
Mutate
(
op
->
a
),
Mutate
(
op
->
b
));
}
bool
is_left_zero
=
isZero
(
op
->
a
);
bool
is_right_zero
=
IsZero
(
op
->
b
);
if
(
!
is_left_zero
&&
!
is_right_zero
)
{
return
e
;
return
Max
::
make
(
Mutate
(
op
->
a
),
Mutate
(
op
->
b
))
;
}
Expr
expr
=
op
->
a
;
if
(
is_left_zero
)
{
...
...
@@ -656,91 +726,154 @@ class InstructionSelector {
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
// float(cc1) -> a[i] = cc1; cast(a[i])
Expr
Mutate_
(
const
Cast
*
op
,
const
Expr
&
e
)
{
if
(
op
->
type
.
is_int
()
&&
op
->
value
->
IsInstance
<
Call
>
())
{
const
Call
*
call
=
op
->
value
.
as
<
Call
>
();
Expr
value
=
Mutate
(
op
->
value
);
if
(
op
->
type
.
is_int
()
&&
value
->
IsInstance
<
Call
>
())
{
const
Call
*
call
=
value
.
as
<
Call
>
();
if
(
call
->
name
!=
"floor"
&&
call
->
name
!=
"ceil"
&&
call
->
name
!=
"round"
&&
call
->
name
!=
"trunc"
)
{
return
e
;
return
Cast
::
make
(
op
->
type
,
value
)
;
}
if
(
op
->
type
==
call
->
type
)
{
return
op
->
value
;
return
value
;
}
else
{
return
Call
::
make
(
op
->
type
,
call
->
name
,
call
->
args
,
call
->
call_type
,
call
->
func
,
call
->
value_index
);
}
}
if
(
op
->
type
.
is_float
()
&&
op
->
value
->
IsInstance
<
Variable
>
())
{
return
Cast
::
make
(
op
->
type
,
mutator_
.
AllocateTmp
(
op
->
value
));
if
(
op
->
type
.
is_float
()
&&
value
->
IsInstance
<
Variable
>
())
{
return
Cast
::
make
(
op
->
type
,
mutator_
.
AllocateTmp
(
value
));
}
return
e
;
return
Cast
::
make
(
op
->
type
,
value
)
;
}
Expr
Mutate_
(
const
Select
*
op
,
const
Expr
&
e
)
{
if
(
const
Not
*
notCond
=
op
->
condition
.
as
<
Not
>
())
{
return
Select
::
make
(
notCond
->
a
,
op
->
false_value
,
op
->
true_value
);
Expr
condition
=
Mutate
(
op
->
condition
);
Expr
true_value
=
Mutate
(
op
->
true_value
);
Expr
false_value
=
Mutate
(
op
->
false_value
);
if
(
const
Not
*
notCond
=
condition
.
as
<
Not
>
())
{
return
Select
::
make
(
notCond
->
a
,
false_value
,
true_value
);
}
if
(
const
And
*
andCond
=
op
->
condition
.
as
<
And
>
())
{
Expr
tmpExpr
=
Select
::
make
(
andCond
->
a
,
op
->
true_value
,
op
->
false_value
);
return
Select
::
make
(
andCond
->
b
,
tmpExpr
,
op
->
false_value
);
if
(
const
And
*
andCond
=
condition
.
as
<
And
>
())
{
Expr
tmpExpr
=
Select
::
make
(
andCond
->
a
,
true_value
,
false_value
);
return
Select
::
make
(
andCond
->
b
,
tmpExpr
,
false_value
);
}
if
(
const
Or
*
orCond
=
op
->
condition
.
as
<
Or
>
())
{
Expr
tmpExpr
=
Select
::
make
(
orCond
->
a
,
op
->
true_value
,
op
->
false_value
);
return
Select
::
make
(
orCond
->
b
,
op
->
true_value
,
tmpExpr
);
if
(
const
Or
*
orCond
=
condition
.
as
<
Or
>
())
{
Expr
tmpExpr
=
Select
::
make
(
orCond
->
a
,
true_value
,
false_value
);
return
Select
::
make
(
orCond
->
b
,
true_value
,
tmpExpr
);
}
return
e
;
return
Select
::
make
(
condition
,
true_value
,
false_value
)
;
}
private:
Expr
GetIndexOfPairExprForMul
(
const
Expr
&
expr
)
{
Expr
ret_expr
=
expr
;
bool
pos
=
sign_map_
.
at
(
expr
.
get
());
int
dim
=
CountVars
(
expr
);
for
(
auto
iter
=
exprs_
.
rbegin
();
iter
!=
exprs_
.
rend
();
++
iter
)
{
if
((
sign_map_
.
at
((
*
iter
).
get
())
!=
pos
)
||
is_constant
(
*
iter
)
||
(
iter
->
same_as
(
expr
)))
{
continue
;
template
<
typename
T
>
Expr
MulExprMutator
(
Expr
&
imm
,
const
T
*
op
)
{
Expr
l
=
Mutate
(
op
->
a
);
Expr
r
=
Mutate
(
op
->
b
);
if
(
is_constant
(
l
))
{
return
Mutate
(
T
::
make
(
ConstantFold
<
Mul
>
(
imm
,
l
),
Mul
::
make
(
r
,
imm
)));
}
else
if
(
is_constant
(
r
))
{
return
Mutate
(
T
::
make
(
ConstantFold
<
Mul
>
(
imm
,
r
),
Mul
::
make
(
l
,
imm
)));
}
if
(
CountVars
(
*
iter
)
>
dim
)
{
continue
;
return
Mul
::
make
(
T
::
make
(
l
,
r
),
imm
);
}
ret_expr
=
*
iter
;
exprs_
.
remove_if
([
&
ret_expr
](
Expr
e
)
{
return
e
.
same_as
(
ret_expr
);
});
break
;
template
<
typename
T
>
Expr
ConstantFold
(
const
Expr
&
a
,
const
Expr
&
b
)
{
CHECK
(
a
.
type
().
is_int
()
||
a
.
type
().
is_uint
()
||
a
.
type
().
is_float
());
if
(
a
.
type
()
!=
b
.
type
())
{
CHECK
(
a
.
type
()
==
b
.
type
());
}
CHECK
(
a
.
type
()
==
b
.
type
());
if
(
const
IntImm
*
int_a
=
a
.
as
<
IntImm
>
())
{
const
IntImm
*
int_b
=
b
.
as
<
IntImm
>
();
return
IntImm
::
make
(
a
.
type
(),
ComputeConstant
<
int64_t
,
T
>
(
int_a
->
value
,
int_b
->
value
));
}
if
(
const
UIntImm
*
uint_a
=
a
.
as
<
UIntImm
>
())
{
const
UIntImm
*
uint_b
=
b
.
as
<
UIntImm
>
();
return
UIntImm
::
make
(
a
.
type
(),
ComputeConstant
<
uint64_t
,
T
>
(
uint_a
->
value
,
uint_b
->
value
));
}
const
FloatImm
*
float_a
=
a
.
as
<
FloatImm
>
();
const
FloatImm
*
float_b
=
b
.
as
<
FloatImm
>
();
return
FloatImm
::
make
(
a
.
type
(),
ComputeConstant
<
double
,
T
>
(
float_a
->
value
,
float_b
->
value
));
}
return
ret_expr
;
template
<
typename
Data
,
typename
Op
>
Data
ComputeConstant
(
Data
d1
,
Data
d2
)
{
if
(
Op
::
_type_key
==
Mul
::
_type_key
)
{
return
d1
*
d2
;
}
if
(
Op
::
_type_key
==
Div
::
_type_key
)
{
return
d1
/
d2
;
}
if
(
Op
::
_type_key
==
Add
::
_type_key
)
{
return
d1
+
d2
;
}
CHECK
(
Op
::
_type_key
==
Sub
::
_type_key
);
return
d1
-
d2
;
}
bool
IsCandidate
(
const
Expr
&
e
)
{
if
(
!
e
->
IsInstance
<
Mul
>
())
{
return
false
;
}
const
Mul
*
mul
=
e
.
as
<
Mul
>
();
bool
is_left_constant
=
is_constant
(
mul
->
a
);
bool
is_right_constant
=
is_constant
(
mul
->
b
);
if
(
is_left_constant
&&
is_right_constant
)
{
return
false
;
}
return
mul
->
a
.
type
().
is_float
()
&&
mul
->
a
.
type
()
==
mul
->
b
.
type
();
}
ThreeAddressExprMutator
&
mutator_
;
std
::
list
<
Expr
>
&
exprs_
;
std
::
unordered_map
<
const
Object
*
,
std
::
string
>
&
notation_map_
;
std
::
unordered_map
<
const
Object
*
,
bool
>
&
sign_map_
;
};
Array
<
Expr
>
args_
;
};
// namespace ir
class
ExprOptMutator
:
public
IRMutator
{
public:
ExprOptMutator
(
ThreeAddressExprMutator
&
mutator
)
:
mutator_
(
mutator
)
{}
explicit
ExprOptMutator
(
ThreeAddressExprMutator
&
mutator
,
const
Array
<
Expr
>
&
args
)
:
mutator_
(
mutator
),
args_
(
args
)
{}
~
ExprOptMutator
()
override
=
default
;
Expr
Mutate
(
Expr
expr
)
{
expr
=
IRMutator
::
Mutate
(
expr
);
exprs_
.
sort
([](
Expr
&
e1
,
Expr
&
e2
)
->
bool
{
int
dim1
=
CountVars
(
e1
);
int
dim2
=
CountVars
(
e2
);
if
(
dim1
==
dim2
)
{
return
!
e1
->
IsInstance
<
Mul
>
();
}
return
dim1
<
dim2
;
IRMutator
::
Mutate
(
expr
);
std
::
sort
(
exprs_
.
begin
(),
exprs_
.
end
(),
[
this
](
Expr
&
e1
,
Expr
&
e2
)
->
bool
{
bool
is_const
=
is_constant
(
e1
);
if
(
is_const
||
is_constant
(
e2
))
{
return
!
is_const
;
}
Array
<
Expr
>
args1
=
ExprArgsFetcher
(
args_
).
GetArgs
(
e1
);
Array
<
Expr
>
args2
=
ExprArgsFetcher
(
args_
).
GetArgs
(
e2
);
if
(
args1
.
size
()
!=
args2
.
size
())
{
return
args1
.
size
()
>
args2
.
size
();
}
if
(
sign_map_
[
e1
.
get
()]
!=
sign_map_
[
e2
.
get
()])
{
return
!
sign_map_
[
e1
.
get
()];
}
return
e1
->
IsInstance
<
Mul
>
();
});
InstructionSelector
selector
(
mutator_
,
exprs_
,
notation_map_
,
sign_map_
);
for
(
auto
iter
=
exprs_
.
rbegin
();
iter
!=
exprs_
.
rend
();
++
iter
)
{
*
iter
=
selector
.
Mutate
(
*
iter
);
if
(
exprs_
.
size
()
<
3
)
{
return
expr
;
}
if
(
is_constant
(
exprs_
[
exprs_
.
size
()
-
2
]))
{
return
RebuildExpr
();
}
Expr
e
=
exprs_
.
front
();
Array
<
Expr
>
args
=
ExprArgsFetcher
(
args_
).
GetArgs
(
e
);
e
=
exprs_
[
exprs_
.
size
()
-
3
];
CHECK
(
sign_map_
.
find
(
e
.
get
())
!=
sign_map_
.
end
());
if
(
sign_map_
[
e
.
get
()])
{
e
=
exprs_
[
exprs_
.
size
()
-
2
];
}
if
(
args
.
size
()
>
ExprArgsFetcher
(
args_
).
GetArgs
(
e
).
size
())
{
expr
=
RebuildExpr
();
}
return
expr
;
}
Expr
Mutate_
(
const
Select
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Select
::
make
(
op
->
condition
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
true_value
),
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
false_value
));
Expr
expr
=
Select
::
make
(
op
->
condition
,
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
true_value
),
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
false_value
));
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
...
...
@@ -748,53 +881,9 @@ class ExprOptMutator : public IRMutator {
Expr
Mutate_
(
const
Sub
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Mul
*
op
,
const
Expr
&
e
)
{
bool
is_left_constant
=
is_constant
(
op
->
a
);
bool
is_right_constant
=
is_constant
(
op
->
b
);
if
((
is_left_constant
&&
is_left_constant
)
||
(
!
is_left_constant
&&
!
is_right_constant
))
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
non_constant_expr
=
is_left_constant
?
op
->
b
:
op
->
a
;
Expr
constant_expr
=
is_left_constant
?
op
->
a
:
op
->
b
;
Expr
Mutate_
(
const
Mul
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
if
(
non_constant_expr
->
IsInstance
<
Add
>
())
{
const
Add
*
add
=
non_constant_expr
.
as
<
Add
>
();
if
(
is_constant
(
add
->
a
)
||
is_constant
(
add
->
b
))
{
Expr
expr
=
Add
::
make
(
Mul
::
make
(
constant_expr
,
add
->
a
),
Mul
::
make
(
constant_expr
,
add
->
b
));
if
(
notation_map_
.
find
(
e
.
get
())
==
notation_map_
.
end
())
{
notation_map_
[
expr
.
get
()]
=
notation_map_
[
e
.
get
()];
}
if
(
sign_map_
.
find
(
e
.
get
())
!=
sign_map_
.
end
())
{
sign_map_
[
expr
.
get
()]
=
sign_map_
[
e
.
get
()];
}
return
IRMutator
::
Mutate
(
expr
);
}
}
if
(
non_constant_expr
->
IsInstance
<
Sub
>
())
{
const
Sub
*
sub
=
non_constant_expr
.
as
<
Sub
>
();
if
(
is_constant
(
sub
->
a
)
||
is_constant
(
sub
->
b
))
{
Expr
expr
=
Sub
::
make
(
Mul
::
make
(
constant_expr
,
sub
->
a
),
Mul
::
make
(
constant_expr
,
sub
->
b
));
if
(
notation_map_
.
find
(
e
.
get
())
==
notation_map_
.
end
())
{
notation_map_
[
expr
.
get
()]
=
notation_map_
[
e
.
get
()];
}
if
(
sign_map_
.
find
(
e
.
get
())
!=
sign_map_
.
end
())
{
sign_map_
[
expr
.
get
()]
=
sign_map_
[
e
.
get
()];
}
return
IRMutator
::
Mutate
(
expr
);
}
}
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
// Imm / x -> y = Imm; y/x
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
e
)
{
if
(
is_constant
(
op
->
a
)
&&
!
is_constant
(
op
->
b
))
{
Expr
expr
=
Div
::
make
(
mutator_
.
AllocateTmp
(
op
->
a
),
op
->
b
);
const
Div
*
div
=
expr
.
as
<
Div
>
();
return
AnalyzeBinaryOpExpr
(
div
,
expr
);
}
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Mod
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
...
...
@@ -824,31 +913,35 @@ class ExprOptMutator : public IRMutator {
Expr
Mutate_
(
const
Let
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Let
::
make
(
op
->
var
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
value
),
ExprOptMutator
(
mutator
_
).
Mutate
(
op
->
body
));
Expr
expr
=
Let
::
make
(
op
->
var
,
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
value
),
ExprOptMutator
(
mutator_
,
args
_
).
Mutate
(
op
->
body
));
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
Expr
Mutate_
(
const
Cast
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Cast
::
make
(
op
->
type
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
value
));
Expr
expr
=
Cast
::
make
(
op
->
type
,
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
value
));
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
Expr
Mutate_
(
const
Not
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Not
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
a
));
Expr
expr
=
Not
::
make
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
a
));
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
index
),
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
predicate
));
Expr
expr
=
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
index
),
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
predicate
));
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
...
...
@@ -856,11 +949,12 @@ class ExprOptMutator : public IRMutator {
InitExprStatusIfNeed
(
e
);
Array
<
Expr
>
source
;
for
(
Expr
src
:
op
->
source
)
{
source
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
src
));
source
.
push_back
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
src
));
}
Expr
expr
=
Reduce
::
make
(
op
->
combiner
,
source
,
op
->
axis
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
condition
),
op
->
value_index
);
Expr
expr
=
Reduce
::
make
(
op
->
combiner
,
source
,
op
->
axis
,
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
condition
),
op
->
value_index
);
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
...
...
@@ -868,14 +962,15 @@ class ExprOptMutator : public IRMutator {
InitExprStatusIfNeed
(
e
);
Array
<
Expr
>
vectors
;
for
(
Expr
v
:
op
->
vectors
)
{
vectors
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
v
));
vectors
.
push_back
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
v
));
}
Array
<
Expr
>
indices
;
for
(
Expr
indic
:
op
->
indices
)
{
indices
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
indic
));
indices
.
push_back
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
indic
));
}
Expr
expr
=
Shuffle
::
make
(
vectors
,
indices
);
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
...
...
@@ -883,26 +978,29 @@ class ExprOptMutator : public IRMutator {
InitExprStatusIfNeed
(
e
);
Array
<
Expr
>
args
;
for
(
Expr
arg
:
op
->
args
)
{
args
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
arg
));
args
.
push_back
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
arg
));
}
Expr
expr
=
Call
::
make
(
op
->
type
,
op
->
name
,
args
,
op
->
call_type
,
op
->
func
,
op
->
value_index
);
mutator_
.
AddBroadCastCallIfNeed
(
op
,
expr
);
exprs_
.
push_back
(
expr
);
mutator_
.
AddBroadCastCallIfNeed
(
op
,
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
Expr
Mutate_
(
const
Ramp
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Ramp
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
base
),
ExprOptMutator
(
mutator
_
).
Mutate
(
op
->
stride
),
op
->
lanes
);
Expr
expr
=
Ramp
::
make
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
base
),
ExprOptMutator
(
mutator_
,
args
_
).
Mutate
(
op
->
stride
),
op
->
lanes
);
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
Expr
Mutate_
(
const
Broadcast
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Broadcast
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
value
),
op
->
lanes
);
Expr
expr
=
Broadcast
::
make
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
value
),
op
->
lanes
);
exprs_
.
push_back
(
expr
);
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
...
...
@@ -927,12 +1025,21 @@ class ExprOptMutator : public IRMutator {
}
}
void
UpdateExprStatus
(
const
Expr
&
before
,
const
Expr
&
after
)
{
const
Object
*
b
=
before
.
get
();
const
Object
*
a
=
after
.
get
();
CHECK
(
notation_map_
.
find
(
b
)
!=
notation_map_
.
end
());
notation_map_
[
a
]
=
notation_map_
[
b
];
CHECK
(
sign_map_
.
find
(
b
)
!=
sign_map_
.
end
());
sign_map_
[
a
]
=
sign_map_
[
b
];
}
bool
IsNewRoot
(
const
Expr
&
e
)
{
CHECK
(
notation_map_
.
find
(
e
.
get
())
!=
notation_map_
.
end
());
std
::
string
root
=
notation_map_
[
e
.
get
()];
std
::
string
type_key
=
e
->
GetTypeKey
();
return
!
((
root
==
Add
::
_type_key
||
root
==
Sub
::
_type_key
)
&&
(
type_key
==
Add
::
_type_key
||
type_key
==
Sub
::
_type_key
))
||
(
type_key
==
Add
::
_type_key
||
type_key
==
Sub
::
_type_key
))
&&
!
((
root
==
Mul
::
_type_key
||
root
==
Div
::
_type_key
)
&&
(
type_key
==
Mul
::
_type_key
||
type_key
==
Div
::
_type_key
));
}
...
...
@@ -946,7 +1053,7 @@ class ExprOptMutator : public IRMutator {
std
::
string
type_key
=
e
->
GetTypeKey
();
Expr
expr
=
e
;
if
(
IsNewRoot
(
e
))
{
expr
=
T
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
a
),
ExprOptMutator
(
mutator
_
).
Mutate
(
op
->
b
));
expr
=
T
::
make
(
ExprOptMutator
(
mutator_
,
args_
).
Mutate
(
op
->
a
),
ExprOptMutator
(
mutator_
,
args
_
).
Mutate
(
op
->
b
));
notation_map_
[
expr
.
get
()]
=
root_of_e
;
sign_map_
[
expr
.
get
()]
=
pos_of_e
;
exprs_
.
push_back
(
expr
);
...
...
@@ -957,6 +1064,7 @@ class ExprOptMutator : public IRMutator {
sign_map_
[
op
->
b
.
get
()]
=
(
type_key
==
Sub
::
_type_key
||
type_key
==
Div
::
_type_key
)
?
!
pos_of_e
:
pos_of_e
;
expr
=
T
::
make
(
IRMutator
::
Mutate
(
op
->
a
),
IRMutator
::
Mutate
(
op
->
b
));
}
UpdateExprStatus
(
e
,
expr
);
return
expr
;
}
...
...
@@ -968,11 +1076,11 @@ class ExprOptMutator : public IRMutator {
Expr
RebuildExpr
()
{
CHECK
(
!
exprs_
.
empty
());
Expr
expr
=
exprs_
.
front
();
exprs_
.
pop_
front
();
Expr
expr
=
exprs_
.
back
();
exprs_
.
pop_
back
();
while
(
!
exprs_
.
empty
())
{
expr
=
RebuildExpr
(
expr
,
exprs_
.
front
());
exprs_
.
pop_
front
();
expr
=
RebuildExpr
(
expr
,
exprs_
.
back
());
exprs_
.
pop_
back
();
}
return
expr
;
}
...
...
@@ -1004,58 +1112,12 @@ class ExprOptMutator : public IRMutator {
}
ThreeAddressExprMutator
&
mutator_
;
std
::
list
<
Expr
>
exprs_
;
Array
<
Expr
>
args_
;
std
::
vector
<
Expr
>
exprs_
;
std
::
unordered_map
<
const
Object
*
,
std
::
string
>
notation_map_
;
std
::
unordered_map
<
const
Object
*
,
bool
>
sign_map_
;
};
class
LoopMutator
:
public
IRMutator
{
public:
LoopMutator
()
:
loop_level_
(
0
)
{}
~
LoopMutator
()
override
=
default
;
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
loop_level_
++
;
loop_vars_
.
push_front
(
op
);
Stmt
stmt
=
IRMutator
::
Mutate
(
op
->
body
);
if
(
provides_
.
size
()
==
1
||
provides_
.
front
()
->
args
.
size
()
==
provides_
.
front
()
->
args
.
size
())
{
return
s
;
}
if
(
!
stmt
->
IsInstance
<
For
>
())
{
provides_
.
sort
([](
const
Provide
*
s1
,
const
Provide
*
s2
)
->
bool
{
return
s1
->
args
.
size
()
<
s2
->
args
.
size
();
});
const
Provide
*
provide
=
provides_
.
back
();
stmt
=
Provide
::
make
(
provide
->
func
,
provide
->
value_index
,
provide
->
value
,
provide
->
args
);
provides_
.
pop_back
();
}
stmt
=
RebuildForStmt
(
loop_vars_
.
back
(),
stmt
);
loop_vars_
.
pop_back
();
loop_level_
--
;
return
stmt
;
}
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
)
final
{
if
(
!
loop_vars_
.
empty
())
{
provides_
.
push_back
(
op
);
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
private:
Stmt
RebuildForStmt
(
const
For
*
op
,
Stmt
&
body
)
{
Stmt
stmt
=
body
;
while
(
!
provides_
.
empty
()
&&
op
->
loop_var
.
same_as
(
provides_
.
back
()
->
args
[
0
]))
{
const
Provide
*
second
=
provides_
.
back
();
stmt
=
Block
::
make
(
Provide
::
make
(
second
->
func
,
second
->
value_index
,
second
->
value
,
second
->
args
),
stmt
);
provides_
.
pop_back
();
}
return
For
::
make
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
for_type
,
op
->
device_api
,
stmt
);
}
size_t
loop_level_
;
std
::
list
<
const
Provide
*>
provides_
;
std
::
list
<
const
For
*>
loop_vars_
;
};
class
InferUpperBound
{
private:
class
Bound
{
...
...
@@ -1384,12 +1446,16 @@ class ThreeAddressStmtMutator : public IRMutator {
args_
=
args
;
static_cast
<
void
>
(
this
->
Mutate
(
op
->
value
));
// mutate according to the result of instruction selection
ThreeAddressExprMutator
mutator
(
output
,
args
,
shape
,
broadcast_
,
is_reduction
,
cross_stmt_simplify_
);
ThreeAddressExprMutator
mutator
(
output
,
args
,
op
->
args
,
shape
,
broadcast_
,
is_reduction
,
cross_stmt_simplify_
,
is_simple_
);
if
(
cross_stmt_simplify_
)
{
// Bring over the common exprs from previous stage
mutator
.
SetCommonExpr
(
global_common_expr_
);
}
value
=
ExprOptMutator
(
mutator
).
Mutate
(
value
);
if
(
is_simple_
)
{
value
=
ExprOptMutator
(
mutator
,
args_
).
Mutate
(
value
);
}
value
=
InstructionMutator
(
mutator
,
args_
).
Mutate
(
value
);
value
=
mutator
.
Mutate
(
value
);
if
(
cross_stmt_simplify_
)
{
// Take back the common exprs for next stages
...
...
@@ -1489,11 +1555,46 @@ class ThreeAddressStmtMutator : public IRMutator {
}
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
if
(
loop_level
==
0
)
{
is_simple_
=
IsSimpleFor
(
op
);
}
loop_level
++
;
dom_map
[
op
->
loop_var
]
=
Range
::
make_by_min_extent
(
op
->
min
,
op
->
extent
);
return
IRMutator
::
Mutate_
(
op
,
s
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
loop_level
--
;
if
(
loop_level
==
0
)
{
is_simple_
=
true
;
}
return
stmt
;
}
static
bool
IsSimpleFor
(
const
For
*
op
)
{
if
(
const
For
*
sub_for
=
op
->
body
.
as
<
For
>
())
{
return
IsSimpleFor
(
sub_for
);
}
if
(
const
Block
*
block
=
op
->
body
.
as
<
Block
>
())
{
return
IsSimpleBlock
(
block
);
}
return
op
->
body
->
IsInstance
<
Provide
>
();
}
private:
static
bool
IsSimpleBlock
(
const
Block
*
op
)
{
if
(
op
->
first
->
IsInstance
<
Provide
>
()
&&
op
->
rest
->
IsInstance
<
Provide
>
())
{
return
true
;
}
if
(
op
->
first
->
IsInstance
<
Block
>
()
&&
op
->
rest
->
IsInstance
<
Block
>
())
{
return
IsSimpleBlock
(
op
->
first
.
as
<
Block
>
())
&&
IsSimpleBlock
(
op
->
rest
.
as
<
Block
>
());
}
if
(
op
->
first
->
IsInstance
<
Provide
>
()
&&
op
->
rest
->
IsInstance
<
Block
>
())
{
return
IsSimpleBlock
(
op
->
rest
.
as
<
Block
>
());
}
if
(
op
->
first
->
IsInstance
<
Block
>
()
&&
op
->
rest
->
IsInstance
<
Provide
>
())
{
return
IsSimpleBlock
(
op
->
first
.
as
<
Block
>
());
}
return
false
;
}
std
::
unordered_map
<
Tensor
,
std
::
vector
<
Tensor
>>
split_to_
;
std
::
unordered_map
<
FunctionRef
,
std
::
set
<
int
>
,
air
::
NodeHash
,
air
::
NodeEqual
>
op_indices_
;
...
...
@@ -1504,6 +1605,9 @@ class ThreeAddressStmtMutator : public IRMutator {
std
::
unordered_map
<
size_t
,
std
::
pair
<
Expr
,
Expr
>>
global_common_expr_
;
int
loop_level
{
0
};
bool
is_simple_
{
true
};
// mark broadcast
Tensor
output_
;
Array
<
Expr
>
args_
;
...
...
@@ -1513,8 +1617,83 @@ class ThreeAddressStmtMutator : public IRMutator {
bool
cross_stmt_simplify_
;
};
class
LoopMutator
:
public
IRMutator
{
public:
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
if
(
loop_vars_
.
empty
()
&&
!
ThreeAddressStmtMutator
::
IsSimpleFor
(
op
))
{
return
s
;
}
loop_vars_
.
push_back
(
op
);
Stmt
stmt
=
IRMutator
::
Mutate
(
op
->
body
);
if
(
!
provides_
.
empty
())
{
provides_
.
sort
([](
const
Provide
*
s1
,
const
Provide
*
s2
)
->
bool
{
return
s1
->
args
.
size
()
<
s2
->
args
.
size
();
});
while
(
!
provides_
.
empty
())
{
SplitProides
();
}
}
for
(
size_t
index
=
0
;
index
<
stmts_
.
size
();
++
index
)
{
if
(
IsContain
(
args_
[
index
],
loop_vars_
.
back
()
->
loop_var
))
{
stmts_
[
index
]
=
For
::
make
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
for_type
,
op
->
device_api
,
stmts_
[
index
]);
}
}
loop_vars_
.
pop_back
();
if
(
loop_vars_
.
empty
())
{
stmt
=
stmts_
.
back
();
for
(
auto
iter
=
++
stmts_
.
rbegin
();
iter
!=
stmts_
.
rend
();
iter
++
)
{
stmt
=
Block
::
make
(
*
iter
,
stmt
);
}
stmts_
.
clear
();
args_
.
clear
();
}
return
stmt
;
}
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
)
final
{
if
(
!
loop_vars_
.
empty
())
{
provides_
.
push_back
(
op
);
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
private:
void
SplitProides
()
{
const
Provide
*
provide
=
provides_
.
back
();
Stmt
stmt
=
Provide
::
make
(
provide
->
func
,
provide
->
value_index
,
provide
->
value
,
provide
->
args
);
provides_
.
pop_back
();
while
(
!
provides_
.
empty
())
{
const
Provide
*
next
=
provides_
.
back
();
if
(
provide
->
args
.
size
()
!=
next
->
args
.
size
())
{
break
;
}
stmt
=
Block
::
make
(
Provide
::
make
(
next
->
func
,
next
->
value_index
,
next
->
value
,
next
->
args
),
stmt
);
provides_
.
pop_back
();
}
stmts_
.
insert
(
stmts_
.
begin
(),
stmt
);
args_
.
insert
(
args_
.
begin
(),
provide
->
args
);
}
bool
IsContain
(
const
Array
<
Expr
>
&
args
,
const
Var
&
var
)
{
VarSet
all_vars
;
for
(
Expr
e
:
args
)
{
GatherVars
(
e
,
&
all_vars
);
}
for
(
auto
v
:
all_vars
)
{
if
(
v
.
same_as
(
var
))
{
return
true
;
}
}
return
false
;
}
std
::
list
<
const
For
*>
loop_vars_
{};
std
::
list
<
const
Provide
*>
provides_
{};
std
::
vector
<
Stmt
>
stmts_
{};
std
::
vector
<
Array
<
Expr
>>
args_
{};
};
Stmt
ToThreeAddress
(
Stmt
stmt
,
bool
reuse_variable
,
int
minimum_split
,
bool
cross_stmt_simplify
)
{
stmt
=
ThreeAddressStmtMutator
(
reuse_variable
,
minimum_split
,
cross_stmt_simplify
).
Mutate
(
stmt
);
stmt
=
LoopMutator
().
Mutate
(
stmt
);
return
Simplify_cce
(
stmt
);
}
}
// namespace ir
...
...
tests/unittest_cpp/src/pass_test/to_three_address_test.cc
浏览文件 @
05313741
...
...
@@ -36,12 +36,7 @@ TEST(ToThreeAddressTest, BuildCase1) {
UTTensorElementHelper
th
({
16
,
32
,
1024
});
using
Add
=
air
::
ir
::
Add
;
// a(ax1, ax2) + b(ax2) + c(ax0, ax1, ax2) + d(ax2)
air
::
Expr
expr
=
Add
::
make
(
Add
::
make
(
Add
::
make
(
th
.
Elem
(
"a"
,
2
),
th
.
Elem
(
"b"
,
1
)),
th
.
Elem
(
"c"
,
3
)),
th
.
Elem
(
"d"
,
1
));
air
::
Expr
expr
=
Add
::
make
(
Add
::
make
(
Add
::
make
(
th
.
Elem
(
"a"
,
2
),
th
.
Elem
(
"b"
,
1
)),
th
.
Elem
(
"c"
,
3
)),
th
.
Elem
(
"d"
,
1
));
std
::
string
dump_expr
=
UTDumpHelper
::
Dump
(
expr
);
EXPECT_EQ
(
dump_expr
,
"(((a(ax1, ax2) + b(ax2)) + c(ax0, ax1, ax2)) + d(ax2))"
);
}
...
...
@@ -49,12 +44,12 @@ TEST(ToThreeAddressTest, BuildCase1) {
class
ThreeAddressExprMutatorTest
:
public
testing
::
Test
{
public:
ThreeAddressExprMutatorTest
()
:
mutator_
(
air
::
TensorNode
::
make
(
UTExprBuilder
::
CreateShape
(
shape_
),
// shape
:
mutator_
(
air
::
TensorNode
::
make
(
UTExprBuilder
::
CreateShape
(
shape_
),
// shape
dtype_
,
// dtype
UTExprBuilder
::
PlaceholderOpNode
(
"out"
,
shape_
),
// op
0
),
// index
UTExprBuilder
::
CreateVars
({
"ax0"
,
"ax1"
,
"ax2"
}),
// args
UTExprBuilder
::
CreateVars
({
"ax0"
,
"ax1"
,
"ax2"
}),
// args
UTExprBuilder
::
CreateShape
(
shape_
),
// shape
std
::
unordered_set
<
const
Call
*>
(),
// broadcast
false
,
// IsReductionOp
...
...
@@ -76,9 +71,7 @@ TEST_F(ThreeAddressExprMutatorTest, MutateBinaryOp_Add) {
class
PassTestToThreeAddress1
:
public
::
testing
::
Test
{
public:
PassTestToThreeAddress1
()
{
Construct
();
}
PassTestToThreeAddress1
()
{
Construct
();
}
~
PassTestToThreeAddress1
()
=
default
;
void
Construct
()
{
a_
=
UTExprBuilder
::
PlaceholderOpNode
(
"a"
,
{
1024
},
air
::
Float
(
16
));
...
...
@@ -88,17 +81,15 @@ class PassTestToThreeAddress1 : public ::testing::Test {
stmt
=
air
::
ir
::
AttrStmt
::
make
(
out_
,
""
,
UTExprBuilder
::
IntImm
(
1
),
UTStmtBuilder
::
CreateRealizeByPlaceholderOp
(
out_
,
air
::
ir
::
ProducerConsumer
::
make
(
out_
,
true
,
out_
,
air
::
ir
::
ProducerConsumer
::
make
(
out_
,
true
,
UTStmtBuilder
::
CreateFor
(
"i"
,
0
,
32
,
UTStmtBuilder
::
CreateFor
(
"j"
,
0
,
1024
,
UTStmtBuilder
::
CreateProvideBinary
<
air
::
ir
::
Add
>
(
out_
,
{
"i"
,
"j"
},
air
::
ir
::
Add
::
make
(
UTExprBuilder
::
ElementOf
(
a_
,
{
"j"
}),
UTExprBuilder
::
ElementOf
(
b_
,
{
"i"
,
"j"
})),
air
::
ir
::
Add
::
make
(
UTExprBuilder
::
ElementOf
(
a_
,
{
"j"
}),
UTExprBuilder
::
ElementOf
(
b_
,
{
"i"
,
"j"
})),
UTExprBuilder
::
ElementOf
(
c_
,
{
"j"
})))))));
}
...
...
@@ -110,7 +101,7 @@ class PassTestToThreeAddress1 : public ::testing::Test {
};
// class PassTestToThreeAddress1
TEST_F
(
PassTestToThreeAddress1
,
CaseCheck
)
{
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
infos_lhs
=
UTProvideCheckerForAssign
().
Find
(
stmt
,
"((a(j) + b(i, j)) + c(j))"
);
ASSERT_EQ
(
infos_lhs
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
infos_lhs
[
0
]),
"out(i, j)"
);
...
...
@@ -124,21 +115,19 @@ TEST_F(PassTestToThreeAddress1, TestPass) {
* out_3(i, j) = (a(j) + out_2(i, j))
* out(i, j) = (out_3(i, j) + c(j))
*/
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info1
=
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info1
=
UTProvideCheckerForAssign
().
Find
(
stmt_out
,
"b(i, j)"
);
ASSERT_EQ
(
info1
.
size
(),
1
);
std
::
string
dump_b_target
=
std
::
get
<
0
>
(
info1
[
0
]);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info2
=
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"a(j)"
,
dump_b_target
);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info2
=
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
"a(j)"
,
dump_b_target
);
ASSERT_EQ
(
info2
.
size
(),
1
);
std
::
string
dump_sum1_target
=
std
::
get
<
0
>
(
info2
[
0
]);
EXPECT_EQ
(
std
::
get
<
2
>
(
info2
[
0
]),
32
*
1024
);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info3
=
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
dump_sum1_target
,
"c(j)"
);
std
::
vector
<
std
::
tuple
<
std
::
string
,
const
air
::
ir
::
Provide
*
,
uint64_t
>>
info3
=
UTProvideCheckerForBinary
().
Find
(
stmt_out
,
UTProvideCheckerForBinary
::
BinaryOpType
::
kAdd
,
dump_sum1_target
,
"c(j)"
);
ASSERT_EQ
(
info3
.
size
(),
1
);
EXPECT_EQ
(
std
::
get
<
0
>
(
info3
[
0
]),
"out(i, j)"
);
EXPECT_EQ
(
std
::
get
<
2
>
(
info3
[
0
]),
32
*
1024
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录