Unverified commit fdc06f21
Authored Oct 27, 2020 by Zhang Ting; committed via GitHub on Oct 27, 2020.
add Fuse bn add act pass (#28196)
* add fuse_bn_add_act pass
Parent: 813b2ade
Showing 14 changed files with 771 additions and 16 deletions (+771, -16).
paddle/fluid/framework/details/CMakeLists.txt (+1, -1)
paddle/fluid/framework/details/build_strategy.cc (+8, -0)
paddle/fluid/framework/details/build_strategy.h (+1, -0)
paddle/fluid/framework/ir/CMakeLists.txt (+1, -0)
paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc (+365, -0)
paddle/fluid/framework/ir/fuse_bn_add_act_pass.h (+75, -0)
paddle/fluid/framework/ir/graph_pattern_detector.cc (+186, -0)
paddle/fluid/framework/ir/graph_pattern_detector.h (+72, -0)
paddle/fluid/operators/fused/fused_bn_add_activation_op.cc (+0, -2)
paddle/fluid/operators/fused/fused_bn_add_activation_op.cu (+0, -1)
paddle/fluid/operators/fused/fused_bn_add_activation_op.h (+0, -1)
paddle/fluid/pybind/pybind.cc (+25, -0)
python/paddle/fluid/tests/unittests/CMakeLists.txt (+2, -0)
python/paddle/fluid/tests/unittests/test_fuse_bn_add_act_pass.py (+35, -11)
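In short, the new pass rewrites subgraphs of the form act(bn(x) + z) into a single fused_bn_add_activation op. The pattern only matches FP16 variables with NHWC batch_norm followed by elementwise_add and relu, and the pass is applied only on GPU with cuDNN 7.4.1 or later. For orientation, here is a minimal NumPy sketch of the computation being fused (illustrative only; the real kernel is the cuDNN-backed fused_bn_add_activation op, and this helper name is hypothetical):

import numpy as np

def bn_add_relu(x, z, scale, bias, eps=1e-5):
    # batch_norm over an NHWC layout, training-style per-channel statistics
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    y = scale * x_hat + bias      # bn(x)
    y = y + z                     # elementwise_add
    return np.maximum(y, 0.0)     # relu activation

x = np.random.randn(2, 4, 4, 8).astype(np.float32)
z = np.random.randn(2, 4, 4, 8).astype(np.float32)
out = bn_add_relu(x, z, scale=np.ones(8), bias=np.zeros(8))
print(out.shape)  # (2, 4, 4, 8)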
paddle/fluid/framework/details/CMakeLists.txt

@@ -107,7 +107,7 @@ cc_test(exception_holder_test SRCS exception_holder_test.cc )
 set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
         multi_devices_graph_print_pass multi_devices_graph_check_pass
-        fuse_elewise_add_act_pass fuse_bn_act_pass
+        fuse_elewise_add_act_pass fuse_bn_act_pass fuse_bn_add_act_pass
         multi_batch_merge_pass
         fuse_relu_depthwise_conv_pass
         lock_free_optimize_pass
paddle/fluid/framework/details/build_strategy.cc

@@ -164,6 +164,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_,
                         "fuse_relu_depthwise_conv_pass");
     AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
+    AppendPassWithCheck(strategy_.fuse_bn_add_act_ops_, "fuse_bn_add_act_pass");
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
     AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
 #else

@@ -390,6 +391,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                         "GPU, skipped.";
         continue;
       }
+    } else if (pass->Type() == "fuse_bn_add_act_pass") {
+      if (!use_cuda) {
+        LOG(WARNING) << "fuse_bn_add_act_pass is only supported on "
+                        "GPU, skipped.";
+        continue;
+      }
     } else if (pass->Type() == "mkldnn_placement_pass") {
       pass->Set("mkldnn_enabled_op_types",
                 new std::unordered_set<std::string>(mkldnn_enabled_op_types_));

@@ -416,6 +423,7 @@ USE_PASS(sync_batch_norm_pass);
 USE_PASS(fuse_relu_depthwise_conv_pass);
 USE_PASS(fuse_elewise_add_act_pass);
 USE_PASS(fuse_bn_act_pass);
+USE_PASS(fuse_bn_add_act_pass);
 USE_PASS(graph_viz_pass);
 USE_PASS(multi_batch_merge_pass);
 USE_PASS(reduce_mode_multi_devices_pass);
paddle/fluid/framework/details/build_strategy.h

@@ -100,6 +100,7 @@ struct BuildStrategy {
   // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
   // cycle.
   bool fuse_bn_act_ops_{false};
+  bool fuse_bn_add_act_ops_{true};
   bool fuse_elewise_add_act_ops_{false};
   bool enable_auto_fusion_{false};
   // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
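Unlike fuse_bn_act_ops_, the new flag defaults to true, so the pass is attempted automatically on GPU builds. A quick way to confirm the default from Python (a minimal sketch using the static-graph API this commit exposes in pybind.cc further down):

import paddle
import paddle.static as static

paddle.enable_static()

build_strategy = static.BuildStrategy()
print(build_strategy.fuse_bn_add_act_ops)  # True by default per this commit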
paddle/fluid/framework/ir/CMakeLists.txt

@@ -114,6 +114,7 @@ if(WITH_MKLDNN)
 endif()
 cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector)
+cc_library(fuse_bn_add_act_pass SRCS fuse_bn_add_act_pass.cc DEPS pass graph_pattern_detector)
 cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector)
 cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector)
paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc (new file, mode 100644)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fuse_bn_add_act_pass.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace framework {
namespace ir {

void FuseBatchNormAddActPass::ApplyImpl(ir::Graph *graph) const {
#ifdef PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(7, 4, 1)
  // forward
  std::unordered_set<std::string> act_types = {"relu"};
  graph = FuseBatchNormAddAct(graph, act_types);
  // backward
  std::unordered_set<std::string> act_grad_types = {"relu_grad"};
  graph = FuseBatchNormAddActGrad(graph, act_grad_types);
#endif
#endif
}

// act(bn(x) + z)
ir::Graph *FuseBatchNormAddActPass::FuseBatchNormAddAct(
    ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
  PADDLE_ENFORCE_NE(
      graph, nullptr,
      platform::errors::InvalidArgument(
          "The input graph of FuseBatchNormAddAct should not be nullptr."));
  FusePassBase::Init("bn_add_act", graph);

  GraphPatternDetector gpd;
  auto *x = gpd.mutable_pattern()
                ->NewNode("bn_add_act/x")
                ->AsInput()
                ->assert_is_op_input("batch_norm", "X")
                ->assert_var_dtype(proto::VarType::FP16);
  patterns::BatchNormAddAct bn_add_act_pattern(gpd.mutable_pattern(),
                                               "bn_add_act");

  bn_add_act_pattern(x, act_types);

  int found_bn_add_act_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "handle FuseBatchNormAddAct fuse";
    // BN inputs
    GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_add_act_pattern);
    // BN outputs
    GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out,
                              bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
                              bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean,
                              bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
                              bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_add_act_pattern);
    // Add inputs
    GET_IR_NODE_FROM_SUBGRAPH(elewise_add_in, elewise_add_in,
                              bn_add_act_pattern);
    // Add outputs
    GET_IR_NODE_FROM_SUBGRAPH(elewise_add_out, elewise_add_out,
                              bn_add_act_pattern);
    // ACT output
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_add_act_pattern);
    // ops
    GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elewise_add, elewise_add, bn_add_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_add_act_pattern);

    std::string bn_x_n = subgraph.at(x)->Name();
    std::string elewise_add_in_n = elewise_add_in->Name();
    std::string bn_scale_n = bn_scale->Name();
    std::string bn_bias_n = bn_bias->Name();
    std::string bn_mean_out_n = bn_mean_out->Name();
    std::string bn_variance_out_n = bn_variance_out->Name();
    std::string bn_saved_variance_n = bn_saved_variance->Name();
    std::string bn_saved_mean_n = bn_saved_mean->Name();
    std::string bn_reserve_space_n = bn_reserve_space->Name();
    std::string bn_out_n = bn_out->Name();
    std::string elewise_add_out_n = elewise_add_out->Name();
    std::string act_out_n = act_out->Name();

    Node *fused_bn_add_act_node = CreateFusedBatchNormAddActNode(
        g, act, elewise_add, batch_norm, bn_x_n, elewise_add_in_n, bn_scale_n,
        bn_bias_n, bn_mean_out_n, bn_variance_out_n, bn_saved_variance_n,
        bn_saved_mean_n, bn_reserve_space_n, act_out_n);

    VLOG(4) << "\n\t" << bn_x_n << ", " << bn_scale_n << ", " << bn_bias_n
            << " -> " << batch_norm->Name() << " -> " << bn_mean_out_n << ", "
            << bn_variance_out_n << ", " << bn_saved_variance_n << ", "
            << bn_saved_mean_n << ", " << bn_reserve_space_n << " and "
            << bn_out_n << "\n"
            << "\t" << bn_out_n << " and " << elewise_add_in_n << " -> "
            << elewise_add->Name() << " -> " << elewise_add_out_n << "\n"
            << "\t" << elewise_add_out_n << " -> " << act->Name() << " -> "
            << act_out_n;

    ReLinkNodes(g, batch_norm, elewise_add, act, fused_bn_add_act_node);
    found_bn_add_act_count++;
  };

  gpd(graph, handler);

  AddStatis(found_bn_add_act_count);
  return graph;
}

Node *FuseBatchNormAddActPass::CreateFusedBatchNormAddActNode(
    Graph *g, const Node *act, const Node *elewise_add, const Node *bn,
    const std::string &bn_x_n, const std::string &elewise_add_in_n,
    const std::string &bn_scale_n, const std::string &bn_bias_n,
    const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
    const std::string &bn_saved_variance_n, const std::string &bn_saved_mean_n,
    const std::string &bn_reserve_space_n,
    const std::string &act_out_n) const {
  OpDesc desc;
  desc.SetInput("X", std::vector<std::string>({bn_x_n}));
  desc.SetInput("Z", std::vector<std::string>({elewise_add_in_n}));
  desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
  desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));

  desc.SetOutput("Y", std::vector<std::string>({act_out_n}));
  desc.SetOutput("MeanOut", std::vector<std::string>({bn_mean_out_n}));
  desc.SetOutput("VarianceOut", std::vector<std::string>({bn_variance_out_n}));
  desc.SetOutput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
  desc.SetOutput("SavedVariance",
                 std::vector<std::string>({bn_saved_variance_n}));
  desc.SetOutput("ReserveSpace",
                 std::vector<std::string>({bn_reserve_space_n}));
  desc.SetType("fused_bn_add_activation");
  desc.SetAttr("act_type", act->Name());

  // Set attrs
  for (auto &n : {act->Op(), elewise_add->Op(), bn->Op()}) {
    for (auto &m : n->GetAttrMap()) {
      desc.SetAttr(m.first, m.second);
    }
  }

  auto fused_bn_add_act_node = g->CreateOpNode(&desc);
  return fused_bn_add_act_node;
}

// the backward of act(bn(x) + z)
ir::Graph *FuseBatchNormAddActPass::FuseBatchNormAddActGrad(
    ir::Graph *graph,
    const std::unordered_set<std::string> &act_grad_types) const {
  PADDLE_ENFORCE_NE(
      graph, nullptr,
      platform::errors::InvalidArgument(
          "The input graph of FuseBatchNormAddActGrad should not be nullptr."));
  FusePassBase::Init("bn_add_act_grad", graph);

  GraphPatternDetector gpd;
  auto *d_act_out =
      gpd.mutable_pattern()
          ->NewNode("bn_add_act_grad/x")
          ->AsInput()
          ->assert_is_ops_input(act_grad_types, GradVarName("Out"))
          ->assert_var_dtype(proto::VarType::FP16);
  patterns::BatchNormAddActGrad bn_add_act_grad_pattern(gpd.mutable_pattern(),
                                                        "bn_add_act_grad");
  bn_add_act_grad_pattern(d_act_out, act_grad_types);

  int found_bn_add_act_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "handle FuseBatchNormAddActGrad fuse";
    GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elewise_add_grad, elewise_add_grad,
                              bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(batch_norm_grad, batch_norm_grad,
                              bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_act_x, d_act_x, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_bn_out, d_bn_out, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_x, bn_x, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean,
                              bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
                              bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
                              bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_bn_x, d_bn_x, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_bn_scale, d_bn_scale, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_bn_bias, d_bn_bias, bn_add_act_grad_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(d_elewise_add_in, d_elewise_add_in,
                              bn_add_act_grad_pattern);

    std::string d_act_out_n = subgraph.at(d_act_out)->Name();  // Y@GRAD
    std::string act_out_n = act_out->Name();                   // Y
    std::string d_act_x_n = d_act_x->Name();
    std::string bn_x_n = bn_x->Name();
    std::string bn_scale_n = bn_scale->Name();
    std::string bn_bias_n = bn_bias->Name();
    std::string bn_saved_mean_n = bn_saved_mean->Name();
    std::string bn_saved_variance_n = bn_saved_variance->Name();
    std::string bn_reserve_space_n = bn_reserve_space->Name();
    std::string d_bn_out_n = d_bn_out->Name();
    std::string d_bn_x_n = d_bn_x->Name();
    std::string d_bn_scale_n = d_bn_scale->Name();
    std::string d_bn_bias_n = d_bn_bias->Name();
    std::string d_elewise_add_in_n = d_elewise_add_in->Name();

    OpDesc desc;
    desc.SetType("fused_bn_add_activation_grad");
    desc.SetInput("X", {bn_x_n});
    desc.SetInput("Y", std::vector<std::string>({act_out_n}));
    desc.SetInput(GradVarName("Y"), std::vector<std::string>({d_act_out_n}));
    desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
    desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));
    desc.SetInput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
    desc.SetInput("SavedVariance",
                  std::vector<std::string>({bn_saved_variance_n}));
    desc.SetInput("ReserveSpace",
                  std::vector<std::string>({bn_reserve_space_n}));
    desc.SetOutput(GradVarName("X"), std::vector<std::string>({d_bn_x_n}));
    desc.SetOutput(GradVarName("Z"),
                   std::vector<std::string>({d_elewise_add_in_n}));
    desc.SetOutput(GradVarName("Scale"),
                   std::vector<std::string>({d_bn_scale_n}));
    desc.SetOutput(GradVarName("Bias"),
                   std::vector<std::string>({d_bn_bias_n}));

    std::string act = act_grad->Name();
    act = act.substr(0, act.length() - 5);  // remove "_grad"
    desc.SetAttr("act_type", act);

    for (auto &n :
         {act_grad->Op(), elewise_add_grad->Op(), batch_norm_grad->Op()}) {
      for (auto &m : n->GetAttrMap()) {
        desc.SetAttr(m.first, m.second);
      }
    }

    auto fused_node = g->CreateOpNode(&desc);

    VLOG(4) << "\n\t" << d_act_out_n << " and " << act_out_n << " -> "
            << act_grad->Name() << " -> " << d_act_x_n << "\n\t";
    VLOG(4) << d_act_x_n << " -> " << elewise_add_grad->Name() << " -> "
            << d_elewise_add_in_n << "," << d_bn_out_n << "\n\t";
    VLOG(4) << bn_x_n << ", " << d_bn_out_n << ", " << bn_scale_n << ", "
            << bn_bias_n << ", " << bn_saved_mean_n << ", "
            << bn_saved_variance_n << " and " << bn_reserve_space_n << " -> "
            << batch_norm_grad->Name() << " -> " << d_bn_x_n << ", "
            << d_bn_scale_n << " and " << d_bn_bias_n;

    ReLinkNodes(g, act_grad, elewise_add_grad, batch_norm_grad, fused_node);
    found_bn_add_act_count++;
  };

  gpd(graph, handler);

  AddStatis(found_bn_add_act_count);
  return graph;
}

void FuseBatchNormAddActPass::ReLinkNodes(Graph *graph, Node *op_1, Node *op_2,
                                          Node *op_3,
                                          Node *fused_op) const {
  // delete act
  // link inputs of op_1 to fused_op
  for (auto &in : op_1->inputs) {
    fused_op->inputs.emplace_back(in);
    in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
  }

  std::unordered_set<const Node *> nodes2delete;

  LinkOutputsToFuseOp(op_1, op_2, fused_op, &nodes2delete);
  LinkOutputsToFuseOp(op_2, op_3, fused_op, &nodes2delete);
  LinkInputsToFuseOp(op_2, fused_op, &nodes2delete);
  LinkInputsToFuseOp(op_3, fused_op, &nodes2delete);

  for (auto &out : op_3->outputs) {
    IR_OP_VAR_LINK(fused_op, out);
  }

  nodes2delete.insert(std::move(op_1));
  nodes2delete.insert(std::move(op_2));
  nodes2delete.insert(std::move(op_3));

  GraphSafeRemoveNodes(graph, nodes2delete);
}

void FuseBatchNormAddActPass::LinkOutputsToFuseOp(
    Node *op_1, Node *op_2, Node *fused_op,
    std::unordered_set<const Node *> *nodes2delete) const {
  // if the outputs of op_1 are inputs of op_2, add the outputs to nodes2delete
  // otherwise link the outputs to fused_op
  for (auto &out : op_1->outputs) {
    auto result_iter =
        std::find_if(op_2->inputs.begin(), op_2->inputs.end(),
                     [&out](const Node *node) -> bool { return node == out; });

    if (result_iter == op_2->inputs.end()) {
      IR_OP_VAR_LINK(fused_op, out);
    } else {
      nodes2delete->emplace(out);
    }
  }
}

void FuseBatchNormAddActPass::LinkInputsToFuseOp(
    Node *op, Node *fused_op,
    std::unordered_set<const Node *> *nodes2delete) const {
  // if the inputs of the op are outputs of previous op, which means
  // these inputs have been added to nodes2delete before, skip the inputs,
  // otherwise link the inputs of the op to fused_op
  for (auto &in : op->inputs) {
    if (nodes2delete->count(in)) {
      continue;
    }
    fused_op->inputs.emplace_back(in);
    in->outputs = this->ReplaceNode(op, fused_op, in->outputs);
  }
}

std::vector<Node *> FuseBatchNormAddActPass::ReplaceNode(
    Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
  std::vector<Node *> new_list(nodes.size());
  bool has_replaced = false;
  std::transform(nodes.begin(), nodes.end(), new_list.begin(),
                 [&](Node *node) -> Node * {
                   if (node == cur_node) {
                     has_replaced = true;
                     return new_node;
                   }
                   return node;
                 });
  PADDLE_ENFORCE_EQ(has_replaced, true,
                    platform::errors::NotFound("Not found %s in the node list.",
                                               cur_node->Name()));
  return new_list;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fuse_bn_add_act_pass,
              paddle::framework::ir::FuseBatchNormAddActPass);
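ReLinkNodes above splices the fused op into the graph: inputs of the three matched ops that are not intermediate results become inputs of the fused op, variables consumed only inside the matched group are deleted, and the remaining outputs are relinked to the fused op. A toy Python sketch of that bookkeeping (simplified stand-in data structures, not Paddle's real Node/Graph types):

# Toy graph: ops consume/produce named variables.
class Op:
    def __init__(self, name, inputs, outputs):
        self.name, self.inputs, self.outputs = name, inputs, outputs

bn = Op("batch_norm", ["x", "scale", "bias"], ["bn_out", "mean", "var"])
add = Op("elementwise_add", ["bn_out", "z"], ["add_out"])
relu = Op("relu", ["add_out"], ["y"])

ops = [bn, add, relu]
intermediate = {"bn_out", "add_out"}  # consumed only inside the fused group

fused_inputs = [v for op in ops for v in op.inputs if v not in intermediate]
fused_outputs = [v for op in ops for v in op.outputs if v not in intermediate]
fused = Op("fused_bn_add_activation", fused_inputs, fused_outputs)
print(fused.inputs)   # ['x', 'scale', 'bias', 'z']
print(fused.outputs)  # ['mean', 'var', 'y']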
paddle/fluid/framework/ir/fuse_bn_add_act_pass.h (new file, mode 100644)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Fuse the BatchNorm, add and activation.
 */
class Graph;
class Node;

class FuseBatchNormAddActPass : public FusePassBase {
 public:
  virtual ~FuseBatchNormAddActPass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;

  ir::Graph *FuseBatchNormAddAct(
      ir::Graph *graph,
      const std::unordered_set<std::string> &act_types) const;
  ir::Graph *FuseBatchNormAddActGrad(
      ir::Graph *graph,
      const std::unordered_set<std::string> &act_grad_types) const;
  void LinkOutputsToFuseOp(
      Node *op_1, Node *op_2, Node *fused_op,
      std::unordered_set<const Node *> *nodes2delete) const;
  void LinkInputsToFuseOp(
      Node *op, Node *fused_op,
      std::unordered_set<const Node *> *nodes2delete) const;
  std::vector<Node *> ReplaceNode(Node *cur_node, Node *new_node,
                                  const std::vector<Node *> &nodes) const;
  void ReLinkNodes(Graph *graph, Node *op_1, Node *op_2, Node *op_3,
                   Node *fused_op) const;
  Node *CreateFusedBatchNormAddActNode(
      Graph *g, const Node *act, const Node *add, const Node *bn,
      const std::string &bn_x_n, const std::string &add_y_n,
      const std::string &bn_scale_n, const std::string &bn_bias_n,
      const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
      const std::string &bn_saved_variance_n,
      const std::string &bn_saved_mean_n,
      const std::string &bn_reserve_space_n,
      const std::string &act_out_n) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/graph_pattern_detector.cc

@@ -93,6 +93,7 @@ void GraphPatternDetector::operator()(Graph *graph,
   auto subgraphs = DetectPatterns();
   UniquePatterns(&subgraphs);
+  SortSubgraphs(&subgraphs);
   RemoveOverlappedMatch(&subgraphs);
   ValidateByNodeRole(&subgraphs);

@@ -302,6 +303,46 @@ void GraphPatternDetector::UniquePatterns(
   *subgraphs = result;
 }

+void GraphPatternDetector::SortSubgraphs(
+    std::vector<GraphPatternDetector::subgraph_t> *subgraphs) {
+  if (subgraphs->empty()) return;
+  bool has_bn_add_act = false;
+  for (auto &subgraph : *subgraphs) {
+    for (auto &item : subgraph) {
+      if (item.first->name().find("bn_add_act") != std::string::npos) {
+        has_bn_add_act = true;
+        break;
+      }
+    }
+  }
+  if (!has_bn_add_act) {
+    return;
+  }
+
+  std::sort(
+      subgraphs->begin(), subgraphs->end(),
+      [](const GraphPatternDetector::subgraph_t &a,
+         const GraphPatternDetector::subgraph_t &b) {
+        for (auto &item : a) {
+          if (item.first->name().find("bn_add_act") != std::string::npos &&
+              item.first->name().find("bn_reserve_space") !=
+                  std::string::npos) {
+            auto it_b = b.find(item.first);
+            if (it_b != b.end()) {
+              if (item.second->Name() != it_b->second->Name()) {
+                return item.second->Name() < it_b->second->Name();
+              } else {
+                return false;
+              }
+            } else {
+              return false;
+            }
+          }
+        }
+        return false;
+      });
+}
+
 void GraphPatternDetector::RemoveOverlappedMatch(
     std::vector<subgraph_t> *subgraphs) {
   std::vector<subgraph_t> result;

@@ -1208,6 +1249,151 @@ PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) {
   return act_out;
 }

+PDNode *patterns::BatchNormAddAct::operator()(
+    paddle::framework::ir::PDNode *bn_x_var,
+    std::unordered_set<std::string> act_types) {
+  bn_x_var->assert_is_op_input("batch_norm", "X")
+      ->assert_var_dtype(proto::VarType::FP16);
+  auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
+                           ->assert_is_op_input("batch_norm", "Scale");
+  auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
+                          ->assert_is_op_input("batch_norm", "Bias");
+
+  auto *bn = pattern->NewNode(batch_norm_repr())
+                 ->assert_is_op("batch_norm")
+                 ->assert_is_not_op_input("MomentumTensor")
+                 ->assert_op_attr<bool>("is_test", false)
+                 ->assert_op_attr<bool>("use_global_stats", false)
+                 ->assert_op_attr<std::string>("data_layout", "NHWC");
+
+  auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
+                              ->assert_is_op_output("batch_norm", "MeanOut");
+  auto *bn_variance_out_var =
+      pattern->NewNode(bn_variance_out_repr())
+          ->assert_is_op_output("batch_norm", "VarianceOut");
+  auto *bn_saved_variance_var =
+      pattern->NewNode(bn_saved_variance_repr())
+          ->assert_is_op_output("batch_norm", "SavedVariance");
+  auto *bn_saved_mean_var =
+      pattern->NewNode(bn_saved_mean_repr())
+          ->assert_is_op_output("batch_norm", "SavedMean");
+  auto *bn_reserve_space =
+      pattern->NewNode(bn_reserve_space_repr())
+          ->assert_is_op_output("batch_norm", "ReserveSpace");
+  auto *bn_out_var = pattern->NewNode(bn_out_repr())
+                         ->assert_is_op_output("batch_norm", "Y")
+                         ->assert_var_dtype(proto::VarType::FP16);
+
+  bn_out_var->assert_is_op_input("elementwise_add");
+  auto *elewise_add =
+      pattern->NewNode(elewise_add_repr())->assert_is_op("elementwise_add");
+  auto *elewise_add_in_var = pattern->NewNode(elewise_add_in_repr())
+                                 ->assert_is_not_ctrl_var()
+                                 ->assert_is_op_input("elementwise_add")
+                                 ->assert_var_dtype(proto::VarType::FP16);
+  auto *elewise_add_out_var =
+      pattern->NewNode(elewise_add_out_repr())
+          ->assert_is_op_output("elementwise_add", "Out")
+          ->assert_has_n_outputs(1);
+
+  elewise_add_out_var->AsIntermediate()->assert_is_ops_input(act_types);
+
+  auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types);
+  auto *act_out_var =
+      pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
+
+  bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var})
+      .LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var,
+                bn_saved_mean_var, bn_reserve_space, bn_out_var});
+  elewise_add->LinksFrom({elewise_add_in_var, bn_out_var})
+      .LinksTo({elewise_add_out_var});
+  act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var});
+
+  return act_out_var;
+}
+
+PDNode *patterns::BatchNormAddActGrad::operator()(
+    paddle::framework::ir::PDNode *d_act_out_var,
+    std::unordered_set<std::string> act_grad_types) {
+  auto *act_grad =
+      pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types);
+  auto *elewise_add_grad = pattern->NewNode(elewise_add_grad_repr())
+                               ->assert_is_op("elementwise_add_grad");
+  auto *bn_grad = pattern->NewNode(batch_norm_grad_repr())
+                      ->assert_is_op("batch_norm_grad")
+                      ->assert_op_attr<bool>("use_global_stats", false)
+                      ->assert_op_attr<std::string>("data_layout", "NHWC");
+
+  auto *act_out_var = pattern->NewNode(act_out_repr())
+                          ->assert_is_ops_input(act_grad_types, "Out");
+  auto *d_act_x_var =
+      pattern->NewNode(d_act_x_repr())
+          ->assert_is_ops_output(act_grad_types, GradVarName("X"))
+          ->assert_has_n_outputs(1);  // d_act_x
+
+  d_act_x_var->AsIntermediate()->assert_is_op_input("elementwise_add_grad");
+
+  auto *d_elewise_add_in_var =
+      pattern->NewNode(d_elewise_add_in_repr())
+          ->assert_is_not_ctrl_var()
+          ->assert_is_op_output("elementwise_add_grad")
+          ->assert_var_dtype(proto::VarType::FP16);  // d_add_in_1
+  auto *d_bn_out_var =
+      pattern->NewNode(d_bn_out_repr())
+          ->assert_is_not_ctrl_var()
+          ->assert_is_op_output("elementwise_add_grad")
+          ->assert_var_dtype(proto::VarType::FP16);  // d_add_in_2
+
+  d_bn_out_var->assert_is_op_input("batch_norm_grad", GradVarName("Y"));
+
+  auto *bn_x_var = pattern->NewNode(bn_x_repr())
+                       ->assert_is_op_input("batch_norm_grad", "X")
+                       ->assert_var_dtype(proto::VarType::FP16);
+  auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
+                           ->assert_is_op_input("batch_norm_grad", "Scale");
+  auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
+                          ->assert_is_op_input("batch_norm_grad", "Bias");
+  auto *bn_saved_mean_var =
+      pattern->NewNode(bn_saved_mean_repr())
+          ->assert_is_op_input("batch_norm_grad", "SavedMean");
+  auto *bn_saved_variance_var =
+      pattern->NewNode(bn_saved_variance_repr())
+          ->assert_is_op_input("batch_norm_grad", "SavedVariance");
+  auto *bn_reserve_space =
+      pattern->NewNode(bn_reserve_space_repr())
+          ->assert_is_op_input("batch_norm_grad", "ReserveSpace");
+
+  auto *d_bn_x_var =
+      pattern->NewNode(d_bn_x_repr())
+          ->assert_is_not_ctrl_var()
+          ->assert_is_op_output("batch_norm_grad", GradVarName("X"))
+          ->assert_var_dtype(proto::VarType::FP16);
+  auto *d_bn_scale_var =
+      pattern->NewNode(d_bn_scale_repr())
+          ->assert_is_not_ctrl_var()
+          ->assert_is_op_output("batch_norm_grad", GradVarName("Scale"));
+  auto *d_bn_bias_var =
+      pattern->NewNode(d_bn_bias_repr())
+          ->assert_is_not_ctrl_var()
+          ->assert_is_op_output("batch_norm_grad", GradVarName("Bias"));
+
+  act_grad->LinksFrom({d_act_out_var, act_out_var}).LinksTo({d_act_x_var});
+  elewise_add_grad->LinksFrom({d_act_x_var})
+      .LinksTo({d_elewise_add_in_var, d_bn_out_var});
+  bn_grad
+      ->LinksFrom({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var,
+                   bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space})
+      .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var});
+
+  return bn_grad;
+}
+
 PDNode *patterns::ElewiseAddAct::operator()(
     paddle::framework::ir::PDNode *ele_x_var,
     std::unordered_set<std::string> act_types) {
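To see why the new SortSubgraphs step helps: UniquePatterns can leave several overlapping forward and backward matches in arbitrary order, and RemoveOverlappedMatch keeps whichever comes first. Sorting both match lists by the variable bound to the bn_reserve_space node makes the surviving forward and backward subgraphs correspond. A toy Python sketch of that idea (the data structures are simplified stand-ins, not the real subgraph_t):

# Toy matches: each is a dict from pattern-node name to bound variable name.
fwd_matches = [
    {"bn_add_act/bn_reserve_space": "reserve_1"},
    {"bn_add_act/bn_reserve_space": "reserve_0"},
]
bwd_matches = [
    {"bn_add_act_grad/bn_reserve_space": "reserve_1"},
    {"bn_add_act_grad/bn_reserve_space": "reserve_0"},
]

def sort_key(match):
    # sort by the variable bound to the reserve-space node, as SortSubgraphs does
    return next(v for k, v in match.items() if "bn_reserve_space" in k)

fwd_matches.sort(key=sort_key)
bwd_matches.sort(key=sort_key)
# After sorting, index i of fwd_matches corresponds to index i of bwd_matches.
print(fwd_matches[0], bwd_matches[0])  # both bound to reserve_0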
paddle/fluid/framework/ir/graph_pattern_detector.h

@@ -294,6 +294,12 @@ class GraphPatternDetector {
   // Remove duplicate patterns.
   void UniquePatterns(std::vector<subgraph_t>* subgraphs);

+  // Sort subgraphs, sort subgraphs by the specified node so that
+  // the removed forward and backward subgraphs are corresponding
+  // when two subgraphs are overlapped. Note: this function is
+  // currently only used for bn_add_act, refer to PR28196 for details.
+  void SortSubgraphs(std::vector<subgraph_t>* subgraphs);
+
   // Remove overlapped match subgraphs, when overlapped, keep the previous one.
   // The intermediate PDNodes will be removed, so can't shared by multiple
   // patterns.

@@ -685,6 +691,72 @@ struct BatchNormActOneDNN : public PatternBase {
   PATTERN_DECL_NODE(act_out);
 };

+// The following pattern is used to fuse batch_norm, elewise_add, and act
+// formula: act(bn(x) + z)
+// op: batch_norm + elewise_add + act
+struct BatchNormAddAct : public PatternBase {
+  BatchNormAddAct(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "bn_add_act") {}
+
+  PDNode *operator()(PDNode *x, std::unordered_set<std::string> acts);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(batch_norm);
+  PATTERN_DECL_NODE(elewise_add);
+  PATTERN_DECL_NODE(act);
+  // declare variable node's name
+  // BN inputs
+  PATTERN_DECL_NODE(bn_scale);
+  PATTERN_DECL_NODE(bn_bias);
+  // BN outputs
+  PATTERN_DECL_NODE(bn_mean_out);
+  PATTERN_DECL_NODE(bn_variance_out);
+  PATTERN_DECL_NODE(bn_saved_variance);
+  PATTERN_DECL_NODE(bn_saved_mean);
+  PATTERN_DECL_NODE(bn_reserve_space);
+  PATTERN_DECL_NODE(bn_out);
+  // Elewise_Add input
+  PATTERN_DECL_NODE(elewise_add_in);
+  // Elewise_Add output
+  PATTERN_DECL_NODE(elewise_add_out);
+  // ACT output
+  PATTERN_DECL_NODE(act_out);
+};
+
+// the backward of act(bn(x) + z)
+// op: batch_norm_grad + elewise_add_grad + act_grad
+struct BatchNormAddActGrad : public PatternBase {
+  BatchNormAddActGrad(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "bn_add_act_grad") {}
+
+  // act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
+  // elewise_add_grad: in["Out@GRAD"], out["X@GRAD", "Y@GRAD"]
+  // bn_grad: in["X", "Z", "Y@GRAD", "Scale", "Bias", "SavedMean",
+  //             "SavedVariance", "ReserveSpace"],
+  //          out["X@GRAD", "Z@GRAD", "Scale@GRAD", "Bias@GRAD"]
+  PDNode *operator()(PDNode *x,
+                     std::unordered_set<std::string> act_grad_types);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(act_grad);
+  PATTERN_DECL_NODE(elewise_add_grad);
+  PATTERN_DECL_NODE(batch_norm_grad);
+  // declare variable node's name
+  PATTERN_DECL_NODE(act_out);
+  PATTERN_DECL_NODE(d_act_x);
+  PATTERN_DECL_NODE(d_elewise_add_in);
+  PATTERN_DECL_NODE(d_bn_out);
+  PATTERN_DECL_NODE(bn_x);
+  PATTERN_DECL_NODE(bn_scale);
+  PATTERN_DECL_NODE(bn_bias);
+  PATTERN_DECL_NODE(bn_saved_mean);
+  PATTERN_DECL_NODE(bn_saved_variance);
+  PATTERN_DECL_NODE(bn_reserve_space);
+  PATTERN_DECL_NODE(d_bn_x);
+  PATTERN_DECL_NODE(d_bn_scale);
+  PATTERN_DECL_NODE(d_bn_bias);
+};
+
 // The following patterns are used to fuse elewise_add and act
 // formula: act(ele_add(x, y))
 // op: elementwise_add + act
paddle/fluid/operators/fused/fused_bn_add_activation_op.cc

@@ -186,8 +186,6 @@ void FusedBatchNormAddActGradOp::InferShape(
   // check input
   OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                  "FusedBatchNormAddActGradOp");
-  OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z",
-                 "FusedBatchNormAddActGradOp");
   OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale",
                  "FusedBatchNormAddActGradOp");
   OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean",
paddle/fluid/operators/fused/fused_bn_add_activation_op.cu

@@ -188,7 +188,6 @@ class FusedBatchNormAddActGradKernel<platform::CUDADeviceContext, T>
     std::string act_type = ctx.Attr<std::string>("act_type");
     const auto *x = ctx.Input<Tensor>("X");
-    const auto *z = ctx.Input<Tensor>("Z");
     const auto *y = ctx.Input<Tensor>("Y");
     const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
     const auto *scale = ctx.Input<Tensor>("Scale");
paddle/fluid/operators/fused/fused_bn_add_activation_op.h

@@ -61,7 +61,6 @@ class FusedBatchNormAddActGradOpMaker : public framework::SingleGradOpMaker<T> {
   void Apply(GradOpPtr<T> op) const override {
     op->SetType(this->ForwardOpType() + "_grad");
     op->SetInput("X", this->Input("X"));
-    op->SetInput("Z", this->Input("Z"));
     op->SetInput("Y", this->Output("Y"));
     op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
paddle/fluid/pybind/pybind.cc

@@ -2500,6 +2500,31 @@ All parameter, weight, gradient are variables in Paddle.
                         build_strategy = static.BuildStrategy()
                         build_strategy.fuse_bn_act_ops = True
                 )DOC")
+      .def_property(
+          "fuse_bn_add_act_ops",
+          [](const BuildStrategy &self) { return self.fuse_bn_add_act_ops_; },
+          [](BuildStrategy &self, bool b) {
+            PADDLE_ENFORCE_NE(self.IsFinalized(), true,
+                              platform::errors::PreconditionNotMet(
+                                  "BuildStrategy has been finlaized, cannot be "
+                                  "configured again."));
+            self.fuse_bn_add_act_ops_ = b;
+          },
+          R"DOC((bool, optional): fuse_bn_add_act_ops indicate whether
+                to fuse batch_norm, elementwise_add and activation_op,
+                it may make the execution faster. Default is True
+
+                Examples:
+                    .. code-block:: python
+
+                        import paddle
+                        import paddle.static as static
+
+                        paddle.enable_static()
+
+                        build_strategy = static.BuildStrategy()
+                        build_strategy.fuse_bn_add_act_ops = True
+                )DOC")
       .def_property(
           "enable_auto_fusion",
          [](const BuildStrategy &self) { return self.enable_auto_fusion_; },
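Note the setter's guard: once the BuildStrategy has been finalized, reassigning the flag fails with the PreconditionNotMet error above. A short sketch of the intended usage (mirroring the docstring example):

import paddle
import paddle.static as static

paddle.enable_static()

build_strategy = static.BuildStrategy()
build_strategy.fuse_bn_add_act_ops = True   # fine: strategy not finalized yet

# Once a compiled program using this strategy has been built and run,
# assigning the property again would raise the PreconditionNotMet error
# defined in the setter above.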
python/paddle/fluid/tests/unittests/CMakeLists.txt

@@ -331,6 +331,7 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op)
 list(REMOVE_ITEM TEST_OPS test_basic_lstm_api)
 list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
 list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass)
+list(REMOVE_ITEM TEST_OPS test_fuse_bn_add_act_pass)
 list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist)
 list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while)
 list(REMOVE_ITEM TEST_OPS test_conv3d_transpose_op)

@@ -515,6 +516,7 @@ py_test_modules(test_parallel_executor_transformer_auto_growth MODULES test_para
 py_test_modules(test_data_norm_op MODULES test_data_norm_op)
 py_test_modules(test_fuse_bn_act_pass MODULES test_fuse_bn_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)
+py_test_modules(test_fuse_bn_add_act_pass MODULES test_fuse_bn_add_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000)
 # NOTE: These unittests will appear NaN steadily in windows CI. After analysis,
 # it is found that windows CI will run all the training unittests with the ON_INFER option turned on,
python/paddle/fluid/tests/unittests/test_fused_bn_add_act.py → python/paddle/fluid/tests/unittests/test_fuse_bn_add_act_pass.py (renamed)

@@ -21,6 +21,8 @@ import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core

+paddle.enable_static()
+

 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "Paddle core is not compiled with CUDA")

@@ -163,12 +165,16 @@ class TestFusedBnAddActAPI(unittest.TestCase):
         iters = 5
         batch_size = 16

-        # build_fused_program
+        # build_fused_program: turn on fuse_bn_add_act_ops
         main_program = fluid.Program()
         startup_program = fluid.Program()
-        x, y, loss = self.build_fused_program(main_program, startup_program,
-                                              use_cuda)
+        x, y, loss = self.build_origin_program(main_program, startup_program,
+                                               use_cuda)
         feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
+        build_strategy_fused = fluid.BuildStrategy()
+        build_strategy_fused.fuse_bn_add_act_ops = True
+        binary_fused = fluid.CompiledProgram(main_program).with_data_parallel(
+            loss_name=loss.name, build_strategy=build_strategy_fused)
         train_reader = paddle.batch(
             paddle.dataset.mnist.train(), batch_size=batch_size)
         exe = fluid.Executor(place)

@@ -178,17 +184,16 @@ class TestFusedBnAddActAPI(unittest.TestCase):
             exe.run(startup_program)
             for _ in range(iters):
                 data = next(train_reader())
-                loss_v = exe.run(main_program,
+                loss_v = exe.run(binary_fused,
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
                 loss_vals_fused.append(loss_v[0][0])

-        # build_origin_program
-        main_program = fluid.Program()
-        startup_program = fluid.Program()
-        x, y, loss = self.build_origin_program(main_program, startup_program,
-                                               use_cuda)
-        feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
+        # build_origin_program: turn off fused_bn_act_ops
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.fuse_bn_add_act_ops = False
+        binary = fluid.CompiledProgram(main_program).with_data_parallel(
+            loss_name=loss.name, build_strategy=build_strategy)
         train_reader = paddle.batch(
             paddle.dataset.mnist.train(), batch_size=batch_size)
         loss_vals = []

@@ -197,7 +202,7 @@ class TestFusedBnAddActAPI(unittest.TestCase):
             exe.run(startup_program)
             for _ in range(iters):
                 data = next(train_reader())
-                loss_v = exe.run(main_program,
+                loss_v = exe.run(binary,
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
                 loss_vals.append(loss_v[0][0])

@@ -210,6 +215,25 @@ class TestFusedBnAddActAPI(unittest.TestCase):
         place = fluid.CUDAPlace(0)
         self.check(place, use_cuda=True)

+    def test_fuse_bn_add_act_API(self):
+        # build_fused_program: use fused_bn_add_act python API
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        place = fluid.CUDAPlace(0)
+        x, y, loss = self.build_fused_program(
+            main_program, startup_program, use_cuda=True)
+        feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=16)
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            exe.run(startup_program)
+            for _ in range(5):
+                data = next(train_reader())
+                loss_v = exe.run(main_program,
+                                 feed=feeder.feed(data),
+                                 fetch_list=[loss])
+
 if __name__ == '__main__':
     unittest.main()