Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
04527ee3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
04527ee3
编写于
12月 27, 2021
作者:
B
baoachun
提交者:
GitHub
12月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add attr check for infer in batch_norm_act mkldnn fuse pass (#38443)
上级
37022482
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
19 addition
and
27 deletion
+19
-27
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc
+7
-15
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
...id/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
+12
-12
未找到文件。
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc
浏览文件 @
04527ee3
...
@@ -67,12 +67,6 @@ FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() {
...
@@ -67,12 +67,6 @@ FuseBatchNormActOneDNNPass::FuseBatchNormActOneDNNPass() {
.
AddAttr
(
"epsilon"
)
.
AddAttr
(
"epsilon"
)
.
IsNumGE
(
0.0
f
)
.
IsNumGE
(
0.0
f
)
.
IsNumLE
(
0.001
f
)
.
IsNumLE
(
0.001
f
)
.
End
()
.
AddAttr
(
"trainable_statistics"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"is_test"
)
.
IsBoolEQ
(
true
)
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"relu"
))
AddOpCompat
(
OpCompat
(
"relu"
))
...
@@ -114,21 +108,19 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
...
@@ -114,21 +108,19 @@ void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
GET_IR_NODE_FROM_SUBGRAPH
(
act
,
act
,
bn_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act
,
act
,
bn_act_pattern
);
auto
*
bn_op
=
batch_norm
->
Op
();
auto
*
bn_op
=
batch_norm
->
Op
();
if
(
bn_op
->
HasAttr
(
"
use_mkldnn
"
))
{
if
(
bn_op
->
HasAttr
(
"
trainable_statistics
"
))
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
BOOST_GET_CONST
(
bool
,
bn_op
->
GetAttr
(
"use_mkldnn
"
)),
!
BOOST_GET_CONST
(
bool
,
bn_op
->
GetAttr
(
"trainable_statistics
"
)),
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"The BatchNorm+Act fusion may happen only when
oneDNN library
"
"The BatchNorm+Act fusion may happen only when
mean and variance
"
"
is used
."
));
"
are not calculated by current batch statistics
."
));
}
}
auto
*
act_op
=
act
->
Op
();
if
(
bn_op
->
HasAttr
(
"is_test"
))
{
if
(
act_op
->
HasAttr
(
"use_mkldnn"
))
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
BOOST_GET_CONST
(
bool
,
bn_op
->
GetAttr
(
"
use_mkldnn
"
)),
BOOST_GET_CONST
(
bool
,
bn_op
->
GetAttr
(
"
is_test
"
)),
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"The BatchNorm+Act fusion may happen only when oneDNN library "
"The BatchNorm+Act fusion may happen only during inference."
));
"is used."
));
}
}
bn_op
->
SetAttr
(
"use_mkldnn"
,
true
);
bn_op
->
SetAttr
(
"use_mkldnn"
,
true
);
...
...
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
浏览文件 @
04527ee3
...
@@ -65,9 +65,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
...
@@ -65,9 +65,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
// No fusion in this attribute configuration
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
constexpr
int
removed_nodes_count
=
0
;
EXPECT_T
RUE
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
EXPECT_T
HROW
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
"act_y"
,
removed_nodes_count
));
"act_y"
,
removed_nodes_count
),
EXPECT_TRUE
(
test
::
AssertOpsCount
(
graph
,
{{
"batch_norm"
,
1
},
{
"relu"
,
1
}})
);
paddle
::
platform
::
EnforceNotMet
);
}
}
TEST
(
FuseBatchNormActOneDNNPass
,
FuseIsTest
)
{
TEST
(
FuseBatchNormActOneDNNPass
,
FuseIsTest
)
{
...
@@ -123,9 +123,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
...
@@ -123,9 +123,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
// No fusion in this attribute configuration
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
constexpr
int
removed_nodes_count
=
0
;
EXPECT_T
RUE
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
EXPECT_T
HROW
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
"act_y"
,
removed_nodes_count
));
"act_y"
,
removed_nodes_count
),
EXPECT_TRUE
(
test
::
AssertOpsCount
(
graph
,
{{
"batch_norm"
,
1
},
{
"relu"
,
1
}})
);
paddle
::
platform
::
EnforceNotMet
);
}
}
TEST
(
FuseBatchNormActOneDNNPass
,
AllAttrsFalse
)
{
TEST
(
FuseBatchNormActOneDNNPass
,
AllAttrsFalse
)
{
...
@@ -149,9 +149,9 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
...
@@ -149,9 +149,9 @@ TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
// No fusion in this attribute configuration
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
constexpr
int
removed_nodes_count
=
0
;
EXPECT_T
RUE
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
EXPECT_T
HROW
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
"act_y"
,
removed_nodes_count
));
"act_y"
,
removed_nodes_count
),
EXPECT_TRUE
(
test
::
AssertOpsCount
(
graph
,
{{
"batch_norm"
,
1
},
{
"relu"
,
1
}})
);
paddle
::
platform
::
EnforceNotMet
);
}
}
TEST
(
FuseBatchNormActOneDNNPass
,
ThrowUseMkldnn
)
{
TEST
(
FuseBatchNormActOneDNNPass
,
ThrowUseMkldnn
)
{
...
@@ -176,9 +176,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
...
@@ -176,9 +176,9 @@ TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
// No fusion in this attribute configuration
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
constexpr
int
removed_nodes_count
=
0
;
EXPECT_T
RUE
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
EXPECT_T
HROW
(
test
::
RunPassAndAssert
(
&
graph
,
"batch_norm_act_fuse_pass"
,
"x"
,
"act_y"
,
removed_nodes_count
));
"act_y"
,
removed_nodes_count
),
EXPECT_TRUE
(
test
::
AssertOpsCount
(
graph
,
{{
"batch_norm"
,
1
},
{
"relu"
,
1
}})
);
paddle
::
platform
::
EnforceNotMet
);
}
}
TEST
(
FuseBatchNormActOneDNNPass
,
pass_op_version_check
)
{
TEST
(
FuseBatchNormActOneDNNPass
,
pass_op_version_check
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录