Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9c17e45e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9c17e45e
编写于
7月 03, 2023
作者:
Y
Yuanle Liu
提交者:
GitHub
7月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix conv+bn pattern (#55073)
上级
3f5c2b5f
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
25 addition
and
36 deletion
+25
-36
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+25
-36
未找到文件。
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
浏览文件 @
9c17e45e
...
@@ -183,7 +183,7 @@ TEST(PatternRewrite, FrozenRewritePatternSet) {
...
@@ -183,7 +183,7 @@ TEST(PatternRewrite, FrozenRewritePatternSet) {
2U
);
2U
);
}
}
class
TransposePatternRewrite
class
RedundantTransposeFusePattern
:
public
ir
::
OpRewritePattern
<
paddle
::
dialect
::
TransposeOp
>
{
:
public
ir
::
OpRewritePattern
<
paddle
::
dialect
::
TransposeOp
>
{
public:
public:
using
ir
::
OpRewritePattern
<
paddle
::
dialect
::
TransposeOp
>::
OpRewritePattern
;
using
ir
::
OpRewritePattern
<
paddle
::
dialect
::
TransposeOp
>::
OpRewritePattern
;
...
@@ -265,37 +265,26 @@ class Conv2dBnFusePattern
...
@@ -265,37 +265,26 @@ class Conv2dBnFusePattern
ir
::
Value
bn_scale
=
op
.
scale
();
ir
::
Value
bn_scale
=
op
.
scale
();
ir
::
Value
bn_bias
=
op
.
bias
();
ir
::
Value
bn_bias
=
op
.
bias
();
ir
::
OpResult
bn_mean_result
=
bn_mean
.
dyn_cast
<
ir
::
OpResult
>
();
IR_ENFORCE
(
bn_mean_result
);
ir
::
OpResult
bn_variance_result
=
bn_variance
.
dyn_cast
<
ir
::
OpResult
>
();
IR_ENFORCE
(
bn_variance_result
);
ir
::
OpResult
bn_scale_result
=
bn_scale
.
dyn_cast
<
ir
::
OpResult
>
();
IR_ENFORCE
(
bn_scale_result
);
ir
::
OpResult
bn_bias_result
=
bn_bias
.
dyn_cast
<
ir
::
OpResult
>
();
IR_ENFORCE
(
bn_bias_result
);
// --- deal with filter ---
// --- deal with filter ---
rewriter
.
SetInsertionPoint
(
conv2d_
op
);
rewriter
.
SetInsertionPoint
(
op
);
phi
::
DDim
bn_variance_shape
=
phi
::
DDim
bn_variance_shape
=
bn_variance
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
dims
();
bn_variance
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
dims
();
float
epsilon
=
op
.
attribute
<
ir
::
FloatAttribute
>
(
"epsilon"
).
data
();
float
epsilon
=
op
.
attribute
<
ir
::
FloatAttribute
>
(
"epsilon"
).
data
();
paddle
::
dialect
::
FullOp
full_op
=
rewriter
.
Build
<
paddle
::
dialect
::
FullOp
>
(
paddle
::
dialect
::
FullOp
full_op
=
rewriter
.
Build
<
paddle
::
dialect
::
FullOp
>
(
phi
::
vectorize
(
bn_variance_shape
),
epsilon
);
phi
::
vectorize
(
bn_variance_shape
),
epsilon
);
paddle
::
dialect
::
AddOp
add_op
=
rewriter
.
Build
<
paddle
::
dialect
::
AddOp
>
(
paddle
::
dialect
::
AddOp
add_op
=
rewriter
.
Build
<
paddle
::
dialect
::
AddOp
>
(
bn_variance
_result
,
full_op
.
out
());
bn_variance
.
dyn_cast
<
ir
::
OpResult
>
()
,
full_op
.
out
());
paddle
::
dialect
::
SqrtOp
sqrt_op
=
paddle
::
dialect
::
SqrtOp
sqrt_op
=
rewriter
.
Build
<
paddle
::
dialect
::
SqrtOp
>
(
add_op
.
out
());
rewriter
.
Build
<
paddle
::
dialect
::
SqrtOp
>
(
add_op
.
out
());
paddle
::
dialect
::
DivideOp
div_op
=
paddle
::
dialect
::
DivideOp
div_op
=
rewriter
.
Build
<
paddle
::
dialect
::
DivideOp
>
(
bn_scale_result
,
rewriter
.
Build
<
paddle
::
dialect
::
DivideOp
>
(
sqrt_op
.
out
());
bn_scale
.
dyn_cast
<
ir
::
OpResult
>
(),
sqrt_op
.
out
());
// reshape scale
// reshape scale
phi
::
DDim
conv2d_filter_shape
=
ir
::
GetShapeFromValue
(
conv2d_filter
);
phi
::
DDim
conv2d_filter_shape
=
ir
::
GetShapeFromValue
(
conv2d_filter
);
phi
::
DDim
bn_scale_shape
=
phi
::
DDim
bn_scale_shape
=
bn_scale
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
dims
();
bn_scale
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
dims
();
std
::
vector
<
int64_t
>
bn_scale_new_shape
(
conv2d_filter_shape
.
size
(),
1
);
std
::
vector
<
int64_t
>
bn_scale_new_shape
(
conv2d_filter_shape
.
size
(),
1
);
bn_scale_new_shape
[
0
]
=
bn_scale_shape
[
0
];
bn_scale_new_shape
[
0
]
=
bn_scale_shape
[
0
];
paddle
::
dialect
::
ReshapeOp
reshape_scale_op
=
paddle
::
dialect
::
ReshapeOp
reshape_scale_op
=
rewriter
.
Build
<
paddle
::
dialect
::
ReshapeOp
>
(
div_op
.
out
(),
rewriter
.
Build
<
paddle
::
dialect
::
ReshapeOp
>
(
div_op
.
out
(),
bn_scale_new_shape
);
bn_scale_new_shape
);
...
@@ -303,39 +292,39 @@ class Conv2dBnFusePattern
...
@@ -303,39 +292,39 @@ class Conv2dBnFusePattern
paddle
::
dialect
::
MultiplyOp
mul_op
=
paddle
::
dialect
::
MultiplyOp
mul_op
=
rewriter
.
Build
<
paddle
::
dialect
::
MultiplyOp
>
(
conv2d_filter_result
,
rewriter
.
Build
<
paddle
::
dialect
::
MultiplyOp
>
(
conv2d_filter_result
,
reshape_scale_op
.
out
());
reshape_scale_op
.
out
());
// TODO(liuyuanle): Use rewriter.
conv2d_op
->
op_operand
(
1
).
set_source
(
mul_op
.
out
());
auto
conv2d_attributes
=
conv2d_op
->
attributes
();
auto
new_conv2d_op
=
rewriter
.
Build
<
paddle
::
dialect
::
Conv2dOp
>
(
conv2d_op
.
input
().
dyn_cast
<
ir
::
OpResult
>
(),
mul_op
.
out
(),
conv2d_attributes
);
// --- deal with bias ---
// --- deal with bias ---
rewriter
.
SetInsertionPoint
(
op
);
paddle
::
dialect
::
MultiplyOp
mul_bias_op
=
paddle
::
dialect
::
MultiplyOp
mul_bias_op
=
rewriter
.
Build
<
paddle
::
dialect
::
MultiplyOp
>
(
bn_mean_result
,
rewriter
.
Build
<
paddle
::
dialect
::
MultiplyOp
>
(
div_op
.
out
());
bn_mean
.
dyn_cast
<
ir
::
OpResult
>
(),
div_op
.
out
());
// new bias --> sub_op.out()
// new bias --> sub_op.out()
paddle
::
dialect
::
SubtractOp
sub_op
=
paddle
::
dialect
::
SubtractOp
sub_op
=
rewriter
.
Build
<
paddle
::
dialect
::
SubtractOp
>
(
bn_bias_result
,
rewriter
.
Build
<
paddle
::
dialect
::
SubtractOp
>
(
mul_bias_op
.
out
());
bn_bias
.
dyn_cast
<
ir
::
OpResult
>
(),
mul_bias_op
.
out
());
// reshape new bias
// reshape new bias
phi
::
DDim
conv2d_out_shape
=
ir
::
GetShapeFromValue
(
conv2d_out
);
phi
::
DDim
new_conv2d_out_shape
=
ir
::
GetShapeFromValue
(
new_conv2d_op
.
out
()
);
std
::
vector
<
int64_t
>
new_bias_new_shape
(
conv2d_out_shape
.
size
(),
1
);
std
::
vector
<
int64_t
>
new_bias_new_shape
(
new_
conv2d_out_shape
.
size
(),
1
);
std
::
string
data_format
=
std
::
string
data_format
=
conv2d_op
.
attribute
<
ir
::
StrAttribute
>
(
"data_format"
).
data
();
new_conv2d_op
.
attribute
<
ir
::
StrAttribute
>
(
"data_format"
).
data
();
IR_ENFORCE
(
data_format
==
"NCHW"
,
"Only support NCHW now."
);
IR_ENFORCE
(
data_format
==
"NCHW"
,
"Only support NCHW now."
);
new_bias_new_shape
[
0
]
=
conv2d_out_shape
[
0
];
new_bias_new_shape
[
0
]
=
new_conv2d_out_shape
[
0
];
new_bias_new_shape
[
1
]
=
conv2d_out_shape
[
1
];
new_bias_new_shape
[
1
]
=
new_conv2d_out_shape
[
1
];
paddle
::
dialect
::
ReshapeOp
reshape_bias_op
=
paddle
::
dialect
::
ReshapeOp
reshape_bias_op
=
rewriter
.
Build
<
paddle
::
dialect
::
ReshapeOp
>
(
sub_op
.
out
(),
rewriter
.
Build
<
paddle
::
dialect
::
ReshapeOp
>
(
sub_op
.
out
(),
new_bias_new_shape
);
new_bias_new_shape
);
paddle
::
dialect
::
AddOp
add_bias_op
=
rewriter
.
Build
<
paddle
::
dialect
::
AddOp
>
(
paddle
::
dialect
::
AddOp
add_bias_op
=
rewriter
.
Build
<
paddle
::
dialect
::
AddOp
>
(
conv2d_out
,
reshape_bias_op
.
out
());
new_conv2d_op
.
out
()
,
reshape_bias_op
.
out
());
auto
next_op
=
ir
::
GetFirstUseOperationForOutput
<
0
>
(
op
);
rewriter
.
ReplaceAllUsesWith
(
next_op
->
operand
(
0
),
add_bias_op
.
out
());
rewriter
.
ReplaceAllUsesWith
(
op
.
out
(
),
add_bias_op
.
out
());
rewriter
.
EraseOp
(
op
);
rewriter
.
EraseOp
(
op
);
rewriter
.
EraseOp
(
conv2d_op
);
return
true
;
return
true
;
}
}
};
};
...
@@ -345,7 +334,7 @@ class TestPass : public ir::Pass {
...
@@ -345,7 +334,7 @@ class TestPass : public ir::Pass {
TestPass
()
:
ir
::
Pass
(
"TestPass"
,
1
)
{}
TestPass
()
:
ir
::
Pass
(
"TestPass"
,
1
)
{}
void
Run
(
ir
::
Operation
*
op
)
override
{
void
Run
(
ir
::
Operation
*
op
)
override
{
ir
::
RewritePatternSet
ps
(
op
->
ir_context
());
ir
::
RewritePatternSet
ps
(
op
->
ir_context
());
ps
.
Add
<
TransposePatternRewrite
>
(
op
->
ir_context
());
ps
.
Add
<
RedundantTransposeFusePattern
>
(
op
->
ir_context
());
ps
.
Add
<
Conv2dBnFusePattern
>
(
op
->
ir_context
());
ps
.
Add
<
Conv2dBnFusePattern
>
(
op
->
ir_context
());
ir
::
FrozenRewritePatternSet
frozen_ps
(
std
::
move
(
ps
));
ir
::
FrozenRewritePatternSet
frozen_ps
(
std
::
move
(
ps
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录