Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
akg
提交
fae7628a
A
akg
项目概览
MindSpore
/
akg
通知
59
Star
7
Fork
7
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
akg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fae7628a
编写于
7月 17, 2020
作者:
Z
zhaiyukun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Optimization for three address
1.Arithmetic priority adjustment 2.Instruction selection
上级
6335daad
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
481 addition
and
293 deletion
+481
-293
src/pass/to_three_address.cc
src/pass/to_three_address.cc
+481
-293
未找到文件。
src/pass/to_three_address.cc
浏览文件 @
fae7628a
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License.
* limitations under the License.
*/
*/
#include <arithmetic/pattern_match.h>
#include <dmlc/common.h>
#include <dmlc/common.h>
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/tensor.h>
#include <tvm/tensor.h>
...
@@ -36,12 +35,6 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>;
...
@@ -36,12 +35,6 @@ using VarSet = std::unordered_set<Var, air::NodeHash, air::NodeEqual>;
// forward declaration
// forward declaration
class
ThreeAddressExprMutator
;
class
ThreeAddressExprMutator
;
struct
ExpressionPattern
{
int
min_level
;
// minimal level
std
::
function
<
int
(
Expr
)
>
score_func
;
// assign score to a subtree
std
::
function
<
Expr
(
Expr
,
ThreeAddressExprMutator
&
)
>
replace_func
;
// replace a subtree with this instruction
};
class
ThreeAddressFilter
:
public
IRVisitor
{
class
ThreeAddressFilter
:
public
IRVisitor
{
public:
public:
bool
Find
(
const
Stmt
&
s
)
{
bool
Find
(
const
Stmt
&
s
)
{
...
@@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator {
...
@@ -324,14 +317,6 @@ class ThreeAddressExprMutator : public IRMutator {
// forward declaration
// forward declaration
Expr
Mutate
(
Expr
expr
)
override
;
Expr
Mutate
(
Expr
expr
)
override
;
// do naive three address translation without instruction selection
Expr
MutateWithoutSelection
(
const
Expr
expr
)
{
disable_selection_
=
true
;
Expr
ret
=
Mutate
(
expr
);
disable_selection_
=
false
;
return
ret
;
}
template
<
typename
T
>
template
<
typename
T
>
Expr
MutateBinaryOp
(
const
T
*
op
,
const
Expr
&
e
)
{
Expr
MutateBinaryOp
(
const
T
*
op
,
const
Expr
&
e
)
{
in_call_
++
;
in_call_
++
;
...
@@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator {
...
@@ -536,6 +521,15 @@ class ThreeAddressExprMutator : public IRMutator {
Expr
Mutate_
(
const
FloatImm
*
op
,
const
Expr
&
e
)
final
{
return
MutateConstOp
(
op
,
e
);
}
Expr
Mutate_
(
const
FloatImm
*
op
,
const
Expr
&
e
)
final
{
return
MutateConstOp
(
op
,
e
);
}
Expr
Mutate_
(
const
IntImm
*
op
,
const
Expr
&
e
)
final
{
return
MutateConstOp
(
op
,
e
);
}
Expr
Mutate_
(
const
IntImm
*
op
,
const
Expr
&
e
)
final
{
return
MutateConstOp
(
op
,
e
);
}
void
AddBroadCastCallIfNeed
(
const
Call
*
op
,
const
Expr
&
e
)
{
if
(
broadcast_
.
find
(
op
)
==
broadcast_
.
end
())
{
return
;
}
const
Call
*
new_call
=
e
.
as
<
Call
>
();
CHECK_NOTNULL
(
new_call
);
broadcast_
.
insert
(
new_call
);
}
std
::
vector
<
Stmt
>
assign_stmt
;
std
::
vector
<
Stmt
>
assign_stmt
;
std
::
vector
<
Tensor
>
imm_tensors
;
std
::
vector
<
Tensor
>
imm_tensors
;
std
::
unordered_set
<
FunctionRef
,
air
::
NodeHash
,
air
::
NodeEqual
>
imm_ops
;
std
::
unordered_set
<
FunctionRef
,
air
::
NodeHash
,
air
::
NodeEqual
>
imm_ops
;
...
@@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator {
...
@@ -557,8 +551,8 @@ class ThreeAddressExprMutator : public IRMutator {
Array
<
Expr
>
shape_
;
Array
<
Expr
>
shape_
;
std
::
unordered_map
<
size_t
,
std
::
pair
<
Expr
,
Expr
>>
common_exprs_
;
// hash value -> <match expr, replace expr>
std
::
unordered_map
<
size_t
,
std
::
pair
<
Expr
,
Expr
>>
common_exprs_
;
// hash value -> <match expr, replace expr>
std
::
unordered_map
<
FunctionRef
,
size_t
,
air
::
NodeHash
,
air
::
NodeEqual
>
// imm tensor -> hash value of the expr in the tensor
imm2hash_
;
// imm tensor -> hash value of the expr in the tensor
std
::
unordered_map
<
FunctionRef
,
size_t
,
air
::
NodeHash
,
air
::
NodeEqual
>
imm2hash_
;
int
level_
{
0
};
int
level_
{
0
};
int
in_call_
{
0
};
int
in_call_
{
0
};
...
@@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator {
...
@@ -574,300 +568,493 @@ class ThreeAddressExprMutator : public IRMutator {
ExprHasher
hasher_
;
ExprHasher
hasher_
;
};
};
Expr
CallPureIntrinsic
(
const
std
::
string
&
name
,
const
Array
<
Expr
>
&
args
,
const
Type
type
)
{
Expr
ThreeAddressExprMutator
::
Mutate
(
Expr
expr
)
{
return
Call
::
make
(
type
,
name
,
args
,
Call
::
CallType
::
PureIntrinsic
);
level_
++
;
expr_stack
.
push_back
(
expr
);
Expr
ret
=
IRMutator
::
Mutate
(
expr
);
expr_stack
.
pop_back
();
level_
--
;
return
ret
;
}
}
// Match instructions by dynamic programming on the tree
int
ThreeAddressExprMutator
::
ct_
=
0
;
class
InstructionMatcher
{
class
InstructionSelector
{
public:
public:
void
Match
(
const
Expr
value
)
{
InstructionSelector
(
ThreeAddressExprMutator
&
mutator
,
std
::
list
<
Expr
>
&
exprs
,
int
max_score
=
-
1
;
std
::
unordered_map
<
const
Object
*
,
std
::
string
>
&
notation_map
,
int
max_i
=
-
1
;
std
::
unordered_map
<
const
Object
*
,
bool
>
&
sign_map
)
:
mutator_
(
mutator
),
exprs_
(
exprs
),
notation_map_
(
notation_map
),
sign_map_
(
sign_map
)
{}
~
InstructionSelector
()
=
default
;
// try patterns
Expr
Mutate
(
Expr
expr
)
{
for
(
size_t
i
=
0
;
i
<
ins_pattern
.
size
();
++
i
)
{
if
(
const
Mul
*
op
=
expr
.
as
<
Mul
>
())
{
int
score_
=
ins_pattern
[
i
].
score_func
(
value
);
return
Mutate_
(
op
,
expr
);
if
(
score_
>
max_score
)
{
max_score
=
score_
;
max_i
=
static_cast
<
int
>
(
i
);
}
}
if
(
const
Cast
*
op
=
expr
.
as
<
Cast
>
())
{
return
Mutate_
(
op
,
expr
);
}
}
if
(
const
Select
*
op
=
expr
.
as
<
Select
>
())
{
score
=
max_score
;
return
Mutate_
(
op
,
expr
);
choice
=
max_i
;
}
return
expr
;
}
}
int
score
;
int
choice
;
const
int
NORMAL
=
20
;
const
int
PRIOR
=
50
;
const
int
UNMATCH
=
-
1
;
air
::
arith
::
PVar
<
Expr
>
x
,
y
,
z
,
w
;
air
::
arith
::
PVar
<
Type
>
pt
;
air
::
arith
::
PVar
<
Floating
>
c1
,
c2
;
std
::
vector
<
ExpressionPattern
>
ins_pattern
{
// vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vmadd [Xd] = [Xn] * [Xd] + [Xm]
// vmla [Xd] = [Xn] * [Xm] + [Xd]
// vaxpy [Xd] = Xm * [Xn] + [Xd]
ExpressionPattern
{
Expr
Mutate_
(
const
Mul
*
op
,
const
Expr
&
e
)
{
2
,
std
::
string
root
=
notation_map_
.
at
(
e
.
get
());
[
&
,
this
](
const
Expr
&
expr
)
->
int
{
if
(
root
!=
Add
::
_type_key
&&
root
!=
Sub
::
_type_key
)
{
if
(((
x
*
y
+
z
).
Match
(
expr
)
||
(
z
+
x
*
y
).
Match
(
expr
))
&&
return
e
;
(
!
is_constant
(
x
.
Eval
())
&&
!
is_constant
(
y
.
Eval
())
&&
!
is_constant
(
z
.
Eval
())))
{
}
return
PRIOR
;
bool
is_left_constant
=
is_constant
(
op
->
a
);
}
bool
is_right_constant
=
is_constant
(
op
->
b
);
return
UNMATCH
;
if
(
is_left_constant
&&
is_right_constant
)
{
},
return
e
;
[
&
,
this
](
const
Expr
&
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
}
CHECK
(((
x
*
y
+
z
)).
Match
(
expr
)
||
(
z
+
x
*
y
).
Match
(
expr
));
Expr
expr
=
GetIndexOfPairExprForMul
(
e
);
if
(
expr
.
same_as
(
e
))
{
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
return
e
;
Expr
y_eval
=
mutator
.
Mutate
(
y
.
Eval
());
}
Expr
z_eval
=
mutator
.
Mutate
(
z
.
Eval
());
Array
<
Expr
>
args
;
// make sure elemwise inside
if
(
!
is_left_constant
)
{
if
(
CountVars
(
x_eval
)
!=
CountVars
(
y_eval
)
||
CountVars
(
x_eval
)
!=
CountVars
(
z_eval
))
{
args
.
push_back
(
op
->
a
);
return
mutator
.
MutateWithoutSelection
(
x_eval
*
y_eval
+
z_eval
);
}
if
(
mutator
.
IsTmpTensor
(
x_eval
))
{
return
mutator
.
AssignTmp
(
x_eval
,
CallPureIntrinsic
(
"vmadd"
,
{
y_eval
,
z_eval
,
x_eval
},
x_eval
.
type
()));
}
else
if
(
mutator
.
IsTmpTensor
(
y_eval
))
{
return
mutator
.
AssignTmp
(
y_eval
,
CallPureIntrinsic
(
"vmadd"
,
{
x_eval
,
z_eval
,
y_eval
},
y_eval
.
type
()));
}
else
if
(
mutator
.
IsTmpTensor
(
z_eval
))
{
return
mutator
.
AssignTmp
(
z_eval
,
CallPureIntrinsic
(
"vmla"
,
{
x_eval
,
y_eval
,
z_eval
},
z_eval
.
type
()));
}
else
{
}
else
{
return
mutator
.
MutateWithoutSelection
(
x_eval
*
y_eval
+
z_eval
);
args
.
push_back
(
op
->
b
);
}
}
}},
args
.
push_back
(
expr
);
if
(
!
is_right_constant
)
{
// vmaddrelu [Xd] = max([Xn] * [Xd] + [Xm], 0)
args
.
push_back
(
op
->
b
);
ExpressionPattern
{
2
,
[
&
,
this
](
const
Expr
expr
)
->
int
{
if
(((
max
(
x
*
y
+
z
,
c1
)).
Match
(
expr
)
||
(
max
(
z
+
x
*
y
,
c1
)).
Match
(
expr
)
||
(
max
(
c1
,
x
*
y
+
z
)).
Match
(
expr
)
||
(
max
(
c1
,
z
+
x
*
y
)).
Match
(
expr
))
&&
c1
.
Eval
()
->
value
==
0.0
&&
(
!
is_constant
(
x
.
Eval
())
&&
!
is_constant
(
y
.
Eval
())
&&
!
is_constant
(
z
.
Eval
())))
{
return
PRIOR
;
}
return
UNMATCH
;
},
[
&
,
this
](
const
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
CHECK
((
max
(
x
*
y
+
z
,
c1
)).
Match
(
expr
)
||
(
max
(
z
+
x
*
y
,
c1
)).
Match
(
expr
)
||
(
max
(
c1
,
x
*
y
+
z
)).
Match
(
expr
)
||
(
max
(
c1
,
z
+
x
*
y
)).
Match
(
expr
));
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
Expr
y_eval
=
mutator
.
Mutate
(
y
.
Eval
());
Expr
z_eval
=
mutator
.
Mutate
(
z
.
Eval
());
// check elemwise
if
(
CountVars
(
x_eval
)
!=
CountVars
(
y_eval
)
||
CountVars
(
x_eval
)
!=
CountVars
(
z_eval
))
{
return
mutator
.
MutateWithoutSelection
(
x_eval
*
y_eval
+
z_eval
);
}
if
(
mutator
.
IsTmpTensor
(
x_eval
)
||
x_eval
.
same_as
(
x
.
Eval
()))
{
return
mutator
.
AssignTmp
(
x_eval
,
CallPureIntrinsic
(
"vmaddrelu"
,
{
y_eval
,
z_eval
,
x_eval
},
x_eval
.
type
()));
}
else
if
(
mutator
.
IsTmpTensor
(
y_eval
)
||
y_eval
.
same_as
(
y
.
Eval
()))
{
return
mutator
.
AssignTmp
(
y_eval
,
CallPureIntrinsic
(
"vmaddrelu"
,
{
x_eval
,
z_eval
,
y_eval
},
y_eval
.
type
()));
}
else
{
}
else
{
return
mutator
.
MutateWithoutSelection
(
max
(
x_eval
*
y_eval
+
z_eval
,
c1
.
Eval
()));
args
.
push_back
(
op
->
a
);
}
return
Call
::
make
(
op
->
type
,
!
is_left_constant
&&
!
is_right_constant
?
"vmadd"
:
"vaxpy"
,
args
,
Call
::
CallType
::
PureIntrinsic
);
}
}
}},
// vaxpy [Xd] = Xm * [Xn] + [Xd]
// vrelu [Xd] = max([Xn], 0)
ExpressionPattern
{
// vmaddrelu [Xd] = max(vmadd [Xd], 0)
2
,
Expr
Mutate_
(
const
Max
*
op
,
const
Expr
&
e
)
{
[
&
,
this
](
const
Expr
expr
)
->
int
{
bool
is_left_zero
=
isZero
(
op
->
a
);
if
(((
c1
*
x
+
y
).
Match
(
expr
)
||
(
x
*
c1
+
y
).
Match
(
expr
)
||
(
y
+
c1
*
x
).
Match
(
expr
)
||
bool
is_right_zero
=
IsZero
(
op
->
b
);
(
y
+
c1
*
x
).
Match
(
expr
))
&&
if
(
!
is_left_zero
&&
!
is_right_zero
)
{
(
!
is_constant
(
x
.
Eval
())
&&
!
is_constant
(
y
.
Eval
())))
{
return
e
;
return
PRIOR
;
}
}
Expr
expr
=
op
->
a
;
return
UNMATCH
;
if
(
is_left_zero
)
{
},
expr
=
op
->
b
;
[
&
,
this
](
const
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
}
CHECK
((
c1
*
x
+
y
).
Match
(
expr
)
||
(
x
*
c1
+
y
).
Match
(
expr
)
||
(
y
+
c1
*
x
).
Match
(
expr
)
||
(
y
+
c1
*
x
).
Match
(
expr
));
if
(
const
Call
*
call
=
expr
.
as
<
Call
>
())
{
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
if
(
call
->
call_type
==
Call
::
CallType
::
PureIntrinsic
&&
call
->
name
==
"vmadd"
)
{
Expr
y_eval
=
mutator
.
Mutate
(
y
.
Eval
());
return
Call
::
make
(
op
->
type
,
"vmaddrelu"
,
call
->
args
,
Call
::
CallType
::
PureIntrinsic
);
// check elemwise
}
if
(
CountVars
(
x_eval
)
!=
CountVars
(
y_eval
))
{
}
return
mutator
.
MutateWithoutSelection
(
c1
.
Eval
()
*
x_eval
+
y_eval
);
return
Call
::
make
(
op
->
type
,
"relu"
,
{
expr
},
Call
::
CallType
::
PureIntrinsic
);
}
}
if
(
mutator
.
IsTmpTensor
(
y_eval
)
||
y_eval
.
same_as
(
y
.
Eval
()))
{
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
return
mutator
.
AssignTmp
(
y_eval
,
CallPureIntrinsic
(
"vaxpy"
,
{
x_eval
,
y_eval
,
c1
.
Eval
()},
y_eval
.
type
()));
// float(cc1) -> a[i] = cc1; cast(a[i])
Expr
Mutate_
(
const
Cast
*
op
,
const
Expr
&
e
)
{
if
(
op
->
type
.
is_int
()
&&
op
->
value
->
IsInstance
<
Call
>
())
{
const
Call
*
call
=
op
->
value
.
as
<
Call
>
();
if
(
call
->
name
!=
"floor"
&&
call
->
name
!=
"ceil"
&&
call
->
name
!=
"round"
&&
call
->
name
!=
"trunc"
)
{
return
e
;
}
if
(
op
->
type
==
call
->
type
)
{
return
op
->
value
;
}
else
{
}
else
{
return
mutator
.
MutateWithoutSelection
(
c1
.
Eval
()
*
x_eval
+
y_eval
);
return
Call
::
make
(
op
->
type
,
call
->
name
,
call
->
args
,
call
->
call_type
,
call
->
func
,
call
->
value_index
);
}
}
if
(
op
->
type
.
is_float
()
&&
op
->
value
->
IsInstance
<
Variable
>
())
{
return
Cast
::
make
(
op
->
type
,
mutator_
.
AllocateTmp
(
op
->
value
));
}
return
e
;
}
}
}},
// vrelu [Xd] = max([Xn], 0)
Expr
Mutate_
(
const
Select
*
op
,
const
Expr
&
e
)
{
ExpressionPattern
{
1
,
if
(
const
Not
*
notCond
=
op
->
condition
.
as
<
Not
>
())
{
[
&
,
this
](
const
Expr
expr
)
->
int
{
return
Select
::
make
(
notCond
->
a
,
op
->
false_value
,
op
->
true_value
);
if
(((
max
(
x
,
c1
)).
Match
(
expr
)
||
(
max
(
c1
,
x
)).
Match
(
expr
))
&&
c1
.
Eval
()
->
value
==
0.0
&&
}
!
is_constant
(
x
.
Eval
())
&&
x
.
Eval
().
type
()
==
Float
(
16
,
1
))
{
if
(
const
And
*
andCond
=
op
->
condition
.
as
<
And
>
())
{
return
NORMAL
;
Expr
tmpExpr
=
Select
::
make
(
andCond
->
a
,
op
->
true_value
,
op
->
false_value
);
}
return
Select
::
make
(
andCond
->
b
,
tmpExpr
,
op
->
false_value
);
return
UNMATCH
;
}
},
if
(
const
Or
*
orCond
=
op
->
condition
.
as
<
Or
>
())
{
Expr
tmpExpr
=
Select
::
make
(
orCond
->
a
,
op
->
true_value
,
op
->
false_value
);
[
&
,
this
](
const
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
return
Select
::
make
(
orCond
->
b
,
op
->
true_value
,
tmpExpr
);
CHECK
(((
max
(
x
,
c1
)).
Match
(
expr
)
||
(
max
(
c1
,
x
)).
Match
(
expr
)));
}
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
return
e
;
return
mutator
.
Mutate
(
CallPureIntrinsic
(
"relu"
,
{
x_eval
},
x_eval
.
type
()));
}
}},
private:
// adds [Xd] = ([Xn] + [Yn]) + imm -> [Xn] + ([Yn] + imm)
Expr
GetIndexOfPairExprForMul
(
const
Expr
&
expr
)
{
ExpressionPattern
{
1
,
Expr
ret_expr
=
expr
;
[
&
,
this
](
const
Expr
expr
)
->
int
{
bool
pos
=
sign_map_
.
at
(
expr
.
get
());
if
((((
x
-
y
)
+
c1
).
Match
(
expr
)
||
(
c1
+
(
x
-
y
)).
Match
(
expr
)
||
((
x
+
y
)
+
c1
).
Match
(
expr
)
||
int
dim
=
CountVars
(
expr
);
(
c1
+
(
x
+
y
)).
Match
(
expr
))
&&
for
(
auto
iter
=
exprs_
.
rbegin
();
iter
!=
exprs_
.
rend
();
++
iter
)
{
!
is_constant
(
x
.
Eval
())
&&
!
is_constant
(
y
.
Eval
()))
{
if
((
sign_map_
.
at
((
*
iter
).
get
())
!=
pos
)
||
is_constant
(
*
iter
)
||
(
iter
->
same_as
(
expr
)))
{
return
NORMAL
;
continue
;
}
}
return
UNMATCH
;
if
(
CountVars
(
*
iter
)
>
dim
)
{
},
continue
;
}
[
&
,
this
](
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
ret_expr
=
*
iter
;
if
(((
x
-
y
)
+
c1
).
Match
(
expr
)
||
(
c1
+
(
x
-
y
)).
Match
(
expr
))
{
exprs_
.
remove_if
([
&
ret_expr
](
Expr
e
)
{
return
e
.
same_as
(
ret_expr
);
});
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
break
;
Expr
y_eval
=
mutator
.
Mutate
(
y
.
Eval
());
}
return
mutator
.
Mutate
(
x_eval
+
(
c1
.
Eval
()
-
y_eval
));
return
ret_expr
;
}
}
if
(((
x
+
y
)
+
c1
).
Match
(
expr
)
||
(
c1
+
(
x
+
y
)).
Match
(
expr
))
{
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
ThreeAddressExprMutator
&
mutator_
;
Expr
y_eval
=
mutator
.
Mutate
(
y
.
Eval
());
std
::
list
<
Expr
>
&
exprs_
;
return
mutator
.
Mutate
(
x_eval
+
(
y_eval
+
c1
.
Eval
()));
std
::
unordered_map
<
const
Object
*
,
std
::
string
>
&
notation_map_
;
std
::
unordered_map
<
const
Object
*
,
bool
>
&
sign_map_
;
};
class
ExprOptMutator
:
public
IRMutator
{
public:
ExprOptMutator
(
ThreeAddressExprMutator
&
mutator
)
:
mutator_
(
mutator
)
{}
~
ExprOptMutator
()
override
=
default
;
Expr
Mutate
(
Expr
expr
)
{
expr
=
IRMutator
::
Mutate
(
expr
);
exprs_
.
sort
([](
Expr
&
e1
,
Expr
&
e2
)
->
bool
{
int
dim1
=
CountVars
(
e1
);
int
dim2
=
CountVars
(
e2
);
if
(
dim1
==
dim2
)
{
return
!
e1
->
IsInstance
<
Mul
>
();
}
return
dim1
<
dim2
;
});
InstructionSelector
selector
(
mutator_
,
exprs_
,
notation_map_
,
sign_map_
);
for
(
auto
iter
=
exprs_
.
rbegin
();
iter
!=
exprs_
.
rend
();
++
iter
)
{
*
iter
=
selector
.
Mutate
(
*
iter
);
}
}
expr
=
RebuildExpr
();
return
expr
;
return
expr
;
}},
}
// int32 floor/ceil/round/trunc() --> floor/ceil/round/trunc()
Expr
Mutate_
(
const
Select
*
op
,
const
Expr
&
e
)
{
ExpressionPattern
{
InitExprStatusIfNeed
(
e
);
1
,
Expr
expr
=
Select
::
make
(
op
->
condition
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
true_value
),
[
&
,
this
](
const
Expr
expr
)
->
int
{
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
false_value
));
if
(((
cast
(
pt
,
call_floor
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
())
||
exprs_
.
push_back
(
expr
);
((
cast
(
pt
,
call_ceil
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
())
||
return
expr
;
((
cast
(
pt
,
call_round
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
())
||
}
((
cast
(
pt
,
call_trunc
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
()))
{
return
NORMAL
;
Expr
Mutate_
(
const
Add
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
}
return
UNMATCH
;
Expr
Mutate_
(
const
Sub
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
},
Expr
Mutate_
(
const
Mul
*
op
,
const
Expr
&
e
)
{
[
&
,
this
](
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
bool
is_left_constant
=
is_constant
(
op
->
a
);
if
((
cast
(
pt
,
call_floor
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
())
{
bool
is_right_constant
=
is_constant
(
op
->
b
);
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
if
((
is_left_constant
&&
is_left_constant
)
||
(
!
is_left_constant
&&
!
is_right_constant
))
{
return
mutator
.
Mutate
(
Call
::
make
(
expr
.
type
(),
"floor"
,
{
x_eval
},
Call
::
CallType
::
PureIntrinsic
));
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
}
if
((
cast
(
pt
,
call_ceil
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
())
{
Expr
non_constant_expr
=
is_left_constant
?
op
->
b
:
op
->
a
;
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
Expr
constant_expr
=
is_left_constant
?
op
->
a
:
op
->
b
;
return
mutator
.
Mutate
(
Call
::
make
(
expr
.
type
(),
"ceil"
,
{
x_eval
},
Call
::
CallType
::
PureIntrinsic
));
}
if
(
non_constant_expr
->
IsInstance
<
Add
>
())
{
if
((
cast
(
pt
,
call_round
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
())
{
const
Add
*
add
=
non_constant_expr
.
as
<
Add
>
();
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
if
(
is_constant
(
add
->
a
)
||
is_constant
(
add
->
b
))
{
return
mutator
.
Mutate
(
Call
::
make
(
expr
.
type
(),
"round"
,
{
x_eval
},
Call
::
CallType
::
PureIntrinsic
));
Expr
expr
=
Add
::
make
(
Mul
::
make
(
constant_expr
,
add
->
a
),
Mul
::
make
(
constant_expr
,
add
->
b
));
}
if
(
notation_map_
.
find
(
e
.
get
())
==
notation_map_
.
end
())
{
if
((
cast
(
pt
,
call_trunc
(
x
))).
Match
(
expr
)
&&
pt
.
Eval
().
is_int
())
{
notation_map_
[
expr
.
get
()]
=
notation_map_
[
e
.
get
()];
Expr
x_eval
=
mutator
.
Mutate
(
x
.
Eval
());
}
return
mutator
.
Mutate
(
Call
::
make
(
expr
.
type
(),
"trunc"
,
{
x_eval
},
Call
::
CallType
::
PureIntrinsic
));
if
(
sign_map_
.
find
(
e
.
get
())
!=
sign_map_
.
end
())
{
sign_map_
[
expr
.
get
()]
=
sign_map_
[
e
.
get
()];
}
return
IRMutator
::
Mutate
(
expr
);
}
}
if
(
non_constant_expr
->
IsInstance
<
Sub
>
())
{
const
Sub
*
sub
=
non_constant_expr
.
as
<
Sub
>
();
if
(
is_constant
(
sub
->
a
)
||
is_constant
(
sub
->
b
))
{
Expr
expr
=
Sub
::
make
(
Mul
::
make
(
constant_expr
,
sub
->
a
),
Mul
::
make
(
constant_expr
,
sub
->
b
));
if
(
notation_map_
.
find
(
e
.
get
())
==
notation_map_
.
end
())
{
notation_map_
[
expr
.
get
()]
=
notation_map_
[
e
.
get
()];
}
if
(
sign_map_
.
find
(
e
.
get
())
!=
sign_map_
.
end
())
{
sign_map_
[
expr
.
get
()]
=
sign_map_
[
e
.
get
()];
}
}
return
IRMutator
::
Mutate
(
expr
);
}
}
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
// Imm / x -> y = Imm; y/x
Expr
Mutate_
(
const
Div
*
op
,
const
Expr
&
e
)
{
if
(
is_constant
(
op
->
a
)
&&
!
is_constant
(
op
->
b
))
{
Expr
expr
=
Div
::
make
(
mutator_
.
AllocateTmp
(
op
->
a
),
op
->
b
);
const
Div
*
div
=
expr
.
as
<
Div
>
();
return
AnalyzeBinaryOpExpr
(
div
,
expr
);
}
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Mod
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
FloorDiv
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
FloorMod
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Min
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Max
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
EQ
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
NE
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
LT
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
LE
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
GT
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
GE
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
And
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Or
*
op
,
const
Expr
&
e
)
{
return
AnalyzeBinaryOpExpr
(
op
,
e
);
}
Expr
Mutate_
(
const
Let
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Let
::
make
(
op
->
var
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
value
),
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
body
));
exprs_
.
push_back
(
expr
);
return
expr
;
return
expr
;
}},
}
// float(cc1) -> a[i] = cc1; cast(a[i])
Expr
Mutate_
(
const
Cast
*
op
,
const
Expr
&
e
)
{
ExpressionPattern
{
1
,
InitExprStatusIfNeed
(
e
);
[
&
,
this
](
const
Expr
expr
)
->
int
{
Expr
expr
=
Cast
::
make
(
op
->
type
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
value
));
if
((
cast
(
pt
,
x
)).
Match
(
expr
)
&&
pt
.
Eval
().
is_float
()
&&
x
.
Eval
().
as
<
Variable
>
())
{
exprs_
.
push_back
(
expr
);
return
NORMAL
;
return
expr
;
}
Expr
Mutate_
(
const
Not
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Expr
expr
=
Not
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
a
));
exprs_
.
push_back
(
expr
);
return
expr
;
}
}
return
UNMATCH
;
},
[
&
,
this
](
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
{
if
((
cast
(
pt
,
x
)).
Match
(
expr
)
&&
pt
.
Eval
().
is_float
()
&&
x
.
Eval
().
as
<
Variable
>
())
{
InitExprStatusIfNeed
(
e
);
Expr
tmp
=
mutator
.
AllocateTmp
(
x
.
Eval
());
Expr
expr
=
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
index
),
return
mutator
.
Mutate
(
Cast
::
make
(
expr
.
type
(),
tmp
));
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
predicate
));
exprs_
.
push_back
(
expr
);
return
expr
;
}
Expr
Mutate_
(
const
Reduce
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Array
<
Expr
>
source
;
for
(
Expr
src
:
op
->
source
)
{
source
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
src
));
}
}
Expr
expr
=
Reduce
::
make
(
op
->
combiner
,
source
,
op
->
axis
,
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
condition
),
op
->
value_index
);
exprs_
.
push_back
(
expr
);
return
expr
;
return
expr
;
}},
}
// Imm / x -> y = Imm; y/x
Expr
Mutate_
(
const
Shuffle
*
op
,
const
Expr
&
e
)
{
ExpressionPattern
{
1
,
InitExprStatusIfNeed
(
e
);
[
&
,
this
](
const
Expr
expr
)
->
int
{
Array
<
Expr
>
vectors
;
if
(
div
(
c1
,
y
).
Match
(
expr
)
&&
is_constant
(
c1
.
Eval
())
&&
!
is_constant
(
y
.
Eval
()))
{
for
(
Expr
v
:
op
->
vectors
)
{
return
NORMAL
;
vectors
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
v
));
}
Array
<
Expr
>
indices
;
for
(
Expr
indic
:
op
->
indices
)
{
indices
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
indic
));
}
Expr
expr
=
Shuffle
::
make
(
vectors
,
indices
);
exprs_
.
push_back
(
expr
);
return
expr
;
}
Expr
Mutate_
(
const
Call
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
Array
<
Expr
>
args
;
for
(
Expr
arg
:
op
->
args
)
{
args
.
push_back
(
ExprOptMutator
(
mutator_
).
Mutate
(
arg
));
}
Expr
expr
=
Call
::
make
(
op
->
type
,
op
->
name
,
args
,
op
->
call_type
,
op
->
func
,
op
->
value_index
);
mutator_
.
AddBroadCastCallIfNeed
(
op
,
expr
);
exprs_
.
push_back
(
expr
);
return
expr
;
}
}
return
UNMATCH
;
},
[
&
,
this
](
const
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
Expr
Mutate_
(
const
Ramp
*
op
,
const
Expr
&
e
)
{
CHECK
(
div
(
c1
,
y
).
Match
(
expr
)
&&
is_constant
(
c1
.
Eval
())
&&
!
is_constant
(
y
.
Eval
()));
InitExprStatusIfNeed
(
e
);
Expr
x_eval
=
mutator
.
AllocateTmp
(
c1
.
Eval
());
Expr
expr
=
return
mutator
.
Mutate
(
Div
::
make
(
x_eval
,
y
.
Eval
()));
Ramp
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
base
),
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
stride
),
op
->
lanes
);
}},
exprs_
.
push_back
(
expr
);
return
expr
;
}
ExpressionPattern
{
1
,
Expr
Mutate_
(
const
Broadcast
*
op
,
const
Expr
&
e
)
{
[
&
,
this
](
const
Expr
expr
)
->
int
{
InitExprStatusIfNeed
(
e
);
if
((
c1
*
(
c2
+
x
)).
Match
(
expr
)
||
(
c1
*
(
c2
-
x
)).
Match
(
expr
))
{
Expr
expr
=
Broadcast
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
value
),
op
->
lanes
);
return
NORMAL
;
exprs_
.
push_back
(
expr
);
return
expr
;
}
}
return
UNMATCH
;
},
[
&
,
this
](
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
Expr
Mutate_
(
const
IntImm
*
op
,
const
Expr
&
e
)
{
return
SaveAutomicExpr
(
e
);
}
if
((
c1
*
(
c2
+
x
)).
Match
(
expr
))
{
return
mutator
.
Mutate
(
Simplify_cce
(
x
.
Eval
()
*
c1
.
Eval
()
+
c1
.
Eval
()
*
c2
.
Eval
()));
Expr
Mutate_
(
const
UIntImm
*
op
,
const
Expr
&
e
)
{
return
SaveAutomicExpr
(
e
);
}
Expr
Mutate_
(
const
FloatImm
*
op
,
const
Expr
&
e
)
{
return
SaveAutomicExpr
(
e
);
}
Expr
Mutate_
(
const
StringImm
*
op
,
const
Expr
&
e
)
{
return
SaveAutomicExpr
(
e
);
}
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
{
return
SaveAutomicExpr
(
e
);
}
private:
void
InitExprStatusIfNeed
(
const
Expr
&
e
)
{
const
Object
*
object_e
=
e
.
get
();
if
(
notation_map_
.
find
(
object_e
)
==
notation_map_
.
end
())
{
notation_map_
[
object_e
]
=
e
->
GetTypeKey
();
}
}
if
((
c1
*
(
c2
-
x
)).
Match
(
expr
))
{
if
(
sign_map_
.
find
(
object_e
)
==
sign_map_
.
end
())
{
return
mutator
.
Mutate
(
Simplify_cce
(
c1
.
Eval
()
*
c2
.
Eval
()
-
x
.
Eval
()
*
c1
.
Eval
()));
sign_map_
[
object_e
]
=
true
;
}
}
bool
IsNewRoot
(
const
Expr
&
e
)
{
CHECK
(
notation_map_
.
find
(
e
.
get
())
!=
notation_map_
.
end
());
std
::
string
root
=
notation_map_
[
e
.
get
()];
std
::
string
type_key
=
e
->
GetTypeKey
();
return
!
((
root
==
Add
::
_type_key
||
root
==
Sub
::
_type_key
)
&&
(
type_key
==
Add
::
_type_key
||
type_key
==
Sub
::
_type_key
))
||
!
((
root
==
Mul
::
_type_key
||
root
==
Div
::
_type_key
)
&&
(
type_key
==
Mul
::
_type_key
||
type_key
==
Div
::
_type_key
));
}
template
<
typename
T
>
Expr
AnalyzeBinaryOpExpr
(
const
T
*
op
,
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
const
Object
*
object_e
=
e
.
get
();
std
::
string
root_of_e
=
notation_map_
[
object_e
];
bool
pos_of_e
=
sign_map_
[
object_e
];
std
::
string
type_key
=
e
->
GetTypeKey
();
Expr
expr
=
e
;
if
(
IsNewRoot
(
e
))
{
expr
=
T
::
make
(
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
a
),
ExprOptMutator
(
mutator_
).
Mutate
(
op
->
b
));
notation_map_
[
expr
.
get
()]
=
root_of_e
;
sign_map_
[
expr
.
get
()]
=
pos_of_e
;
exprs_
.
push_back
(
expr
);
}
else
{
notation_map_
[
op
->
a
.
get
()]
=
root_of_e
;
notation_map_
[
op
->
b
.
get
()]
=
root_of_e
;
sign_map_
[
op
->
a
.
get
()]
=
pos_of_e
;
sign_map_
[
op
->
b
.
get
()]
=
(
type_key
==
Sub
::
_type_key
||
type_key
==
Div
::
_type_key
)
?
!
pos_of_e
:
pos_of_e
;
expr
=
T
::
make
(
IRMutator
::
Mutate
(
op
->
a
),
IRMutator
::
Mutate
(
op
->
b
));
}
return
expr
;
}
Expr
SaveAutomicExpr
(
const
Expr
&
e
)
{
InitExprStatusIfNeed
(
e
);
exprs_
.
push_back
(
e
);
return
e
;
}
Expr
RebuildExpr
()
{
CHECK
(
!
exprs_
.
empty
());
Expr
expr
=
exprs_
.
front
();
exprs_
.
pop_front
();
while
(
!
exprs_
.
empty
())
{
expr
=
RebuildExpr
(
expr
,
exprs_
.
front
());
exprs_
.
pop_front
();
}
}
return
expr
;
return
expr
;
}},
ExpressionPattern
{
1
,
[
&
,
this
](
const
Expr
expr
)
->
int
{
if
((
select
((
z
||
w
),
x
,
y
)).
Match
(
expr
)
||
(
select
((
z
&&
w
),
x
,
y
)).
Match
(
expr
)
||
(
select
((
!
z
),
x
,
y
)).
Match
(
expr
))
{
return
NORMAL
;
}
}
return
UNMATCH
;
},
[
&
,
this
](
Expr
expr
,
ThreeAddressExprMutator
&
mutator
)
->
Expr
{
Expr
RebuildExpr
(
const
Expr
&
expr1
,
const
Expr
&
expr2
)
{
if
((
select
((
z
||
w
),
x
,
y
)).
Match
(
expr
))
{
Expr
expr
=
expr1
;
Expr
temp_eval
=
mutator
.
Mutate
(
Select
::
make
(
z
.
Eval
(),
x
.
Eval
(),
y
.
Eval
()));
Expr
opnd
=
expr2
;
return
mutator
.
Mutate
(
Select
::
make
(
w
.
Eval
(),
x
.
Eval
(),
temp_eval
));
if
(
sign_map_
[
expr2
.
get
()]
&&
!
sign_map_
[
expr1
.
get
()])
{
expr
=
expr2
;
opnd
=
expr1
;
}
}
if
((
select
((
z
&&
w
),
x
,
y
)).
Match
(
expr
))
{
Expr
temp_eval
=
mutator
.
Mutate
(
Select
::
make
(
z
.
Eval
(),
x
.
Eval
(),
y
.
Eval
()));
if
((
sign_map_
[
expr1
.
get
()]
&&
sign_map_
[
expr2
.
get
()])
||
(
!
sign_map_
[
expr1
.
get
()]
&&
!
sign_map_
[
expr2
.
get
()]))
{
return
mutator
.
Mutate
(
Select
::
make
(
w
.
Eval
(),
temp_eval
,
y
.
Eval
()));
if
(
notation_map_
[
expr1
.
get
()]
==
Add
::
_type_key
||
notation_map_
[
expr1
.
get
()]
==
Sub
::
_type_key
)
{
expr
=
Add
::
make
(
expr
,
opnd
);
}
else
{
expr
=
Mul
::
make
(
expr
,
opnd
);
}
}
else
{
if
(
notation_map_
[
expr1
.
get
()]
==
Add
::
_type_key
||
notation_map_
[
expr1
.
get
()]
==
Sub
::
_type_key
)
{
expr
=
Sub
::
make
(
expr
,
opnd
);
}
else
{
expr
=
Div
::
make
(
expr
,
opnd
);
}
}
if
((
select
((
!
z
),
x
,
y
)).
Match
(
expr
))
{
return
mutator
.
Mutate
(
Select
::
make
(
z
.
Eval
(),
y
.
Eval
(),
x
.
Eval
()));
}
}
notation_map_
[
expr
.
get
()]
=
notation_map_
[
expr1
.
get
()];
sign_map_
[
expr
.
get
()]
=
sign_map_
[
expr1
.
get
()]
||
sign_map_
[
expr2
.
get
()];
return
expr
;
return
expr
;
}}};
}
ThreeAddressExprMutator
&
mutator_
;
std
::
list
<
Expr
>
exprs_
;
std
::
unordered_map
<
const
Object
*
,
std
::
string
>
notation_map_
;
std
::
unordered_map
<
const
Object
*
,
bool
>
sign_map_
;
};
};
Expr
ThreeAddressExprMutator
::
Mutate
(
Expr
expr
)
{
class
LoopMutator
:
public
IRMutator
{
// select instructions
public:
InstructionMatcher
matcher
;
LoopMutator
()
:
loop_level_
(
0
)
{}
matcher
.
Match
(
expr
);
~
LoopMutator
()
override
=
default
;
int
idx
=
matcher
.
choice
;
Expr
ret
;
Stmt
Mutate_
(
const
For
*
op
,
const
Stmt
&
s
)
final
{
level_
++
;
loop_level_
++
;
if
(
idx
<
0
||
disable_selection_
||
level_
<
matcher
.
ins_pattern
[
idx
].
min_level
)
{
loop_vars_
.
push_front
(
op
);
expr_stack
.
push_back
(
expr
);
Stmt
stmt
=
IRMutator
::
Mutate
(
op
->
body
);
ret
=
IRMutator
::
Mutate
(
expr
);
if
(
provides_
.
size
()
==
1
||
provides_
.
front
()
->
args
.
size
()
==
provides_
.
front
()
->
args
.
size
())
{
expr_stack
.
pop_back
();
return
s
;
}
else
{
// match an intrinsic
}
ret
=
matcher
.
ins_pattern
[
idx
].
replace_func
(
expr
,
*
this
);
if
(
!
stmt
->
IsInstance
<
For
>
())
{
provides_
.
sort
([](
const
Provide
*
s1
,
const
Provide
*
s2
)
->
bool
{
return
s1
->
args
.
size
()
<
s2
->
args
.
size
();
});
const
Provide
*
provide
=
provides_
.
back
();
stmt
=
Provide
::
make
(
provide
->
func
,
provide
->
value_index
,
provide
->
value
,
provide
->
args
);
provides_
.
pop_back
();
}
stmt
=
RebuildForStmt
(
loop_vars_
.
back
(),
stmt
);
loop_vars_
.
pop_back
();
loop_level_
--
;
return
stmt
;
}
}
level_
--
;
return
ret
;
}
int
ThreeAddressExprMutator
::
ct_
=
0
;
Stmt
Mutate_
(
const
Provide
*
op
,
const
Stmt
&
s
)
final
{
if
(
!
loop_vars_
.
empty
())
{
provides_
.
push_back
(
op
);
}
return
IRMutator
::
Mutate_
(
op
,
s
);
}
private:
Stmt
RebuildForStmt
(
const
For
*
op
,
Stmt
&
body
)
{
Stmt
stmt
=
body
;
while
(
!
provides_
.
empty
()
&&
op
->
loop_var
.
same_as
(
provides_
.
back
()
->
args
[
0
]))
{
const
Provide
*
second
=
provides_
.
back
();
stmt
=
Block
::
make
(
Provide
::
make
(
second
->
func
,
second
->
value_index
,
second
->
value
,
second
->
args
),
stmt
);
provides_
.
pop_back
();
}
return
For
::
make
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
for_type
,
op
->
device_api
,
stmt
);
}
size_t
loop_level_
;
std
::
list
<
const
Provide
*>
provides_
;
std
::
list
<
const
For
*>
loop_vars_
;
};
class
InferUpperBound
{
class
InferUpperBound
{
private:
private:
...
@@ -1202,6 +1389,7 @@ class ThreeAddressStmtMutator : public IRMutator {
...
@@ -1202,6 +1389,7 @@ class ThreeAddressStmtMutator : public IRMutator {
// Bring over the common exprs from previous stage
// Bring over the common exprs from previous stage
mutator
.
SetCommonExpr
(
global_common_expr_
);
mutator
.
SetCommonExpr
(
global_common_expr_
);
}
}
value
=
ExprOptMutator
(
mutator
).
Mutate
(
value
);
value
=
mutator
.
Mutate
(
value
);
value
=
mutator
.
Mutate
(
value
);
if
(
cross_stmt_simplify_
)
{
if
(
cross_stmt_simplify_
)
{
// Take back the common exprs for next stages
// Take back the common exprs for next stages
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录