PaddlePaddle / Paddle
Commit 7db747d9 (unverified), authored on Oct 26, 2020 by Adam Osewski; committed via GitHub on Oct 26, 2020.
oneDNN BatchNorm + Act fusion pass. (#27912)
Parent: fb7f8529

Showing 8 changed files with 657 additions and 0 deletions (+657, -0).
paddle/fluid/framework/ir/CMakeLists.txt (+2, -0)
paddle/fluid/framework/ir/graph_pattern_detector.cc (+20, -0)
paddle/fluid/framework/ir/graph_pattern_detector.h (+21, -0)
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc (+108, -0)
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h (+44, -0)
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc (+382, -0)
paddle/fluid/inference/api/paddle_pass_builder.cc (+1, -0)
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py (+79, -0)
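For intuition before the diffs: the pass targets the inference-time computation relu(batch_norm(x)) and folds the ReLU into oneDNN's batch_norm kernel as a post-op, so the intermediate batch-norm output never materializes. A minimal NumPy sketch of the math the fused kernel must reproduce (variable names and the epsilon value are illustrative only, not taken from this commit):

import numpy as np

def bn_relu_reference(x, scale, bias, mean, var, eps=1e-5):
    """Inference-mode batch_norm over NCHW input, followed by ReLU."""
    shape = (1, -1, 1, 1)  # broadcast per-channel parameters over N, H, W
    y = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    y = y * scale.reshape(shape) + bias.reshape(shape)
    return np.maximum(y, 0.0)  # the ReLU that the pass folds into batch_norm

x = np.random.rand(2, 3, 4, 4).astype("float32")
ones, zeros = np.ones(3, "float32"), np.zeros(3, "float32")
out = bn_relu_reference(x, scale=ones, bias=zeros, mean=zeros, var=ones)
assert out.shape == x.shape and out.min() >= 0.0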
paddle/fluid/framework/ir/CMakeLists.txt

@@ -110,6 +110,7 @@ if(WITH_MKLDNN)
     pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
     pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
     pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
+    pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
 endif()
 cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector)
...
@@ -151,6 +152,7 @@ if (WITH_MKLDNN)
   cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
   cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
   cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
+  cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass)
   set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
   if (WITH_GPU)
     set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)
...
paddle/fluid/framework/ir/graph_pattern_detector.cc

@@ -1188,6 +1188,26 @@ PDNode *patterns::BatchNormActGrad::operator()(
   return bn_grad;
 }
 
+PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) {
+  auto *bn_x = pattern->NewNode(bn_in_repr())
+                   ->AsInput()
+                   ->assert_is_op_input("batch_norm", "X");
+  auto *bn = pattern->NewNode(batch_norm_repr())->assert_is_op("batch_norm");
+  auto *bn_out = pattern->NewNode(bn_out_repr())
+                     ->assert_is_op_output("batch_norm", "Y")
+                     ->assert_is_op_input(act_type);
+  auto *act =
+      pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
+  auto *act_out = pattern->NewNode(act_out_repr())
+                      ->assert_is_op_output(act_type, "Out")
+                      ->AsOutput();
+
+  bn->LinksFrom({bn_x}).LinksTo({bn_out});
+  act->LinksFrom({bn_out}).LinksTo({act_out});
+
+  return act_out;
+}
+
 PDNode *patterns::ElewiseAddAct::operator()(
     paddle::framework::ir::PDNode *ele_x_var,
     std::unordered_set<std::string> act_types) {
...
paddle/fluid/framework/ir/graph_pattern_detector.h

@@ -664,6 +664,27 @@ struct BatchNormActGrad : public PatternBase {
   PATTERN_DECL_NODE(d_bn_bias);
 };
 
+//
+// \brief   Pattern looking for batch_norm and a directly following activation
+//          operator.
+//
+// \note    Currently only ReLU is supported as an activation function.
+//          Formula: act(bn(x))
+//          Op: batch_norm + act
+struct BatchNormActOneDNN : public PatternBase {
+  BatchNormActOneDNN(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "bn_act_onednn") {}
+
+  PDNode* operator()(const std::string& act_type);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(bn_in);
+  PATTERN_DECL_NODE(batch_norm);
+  PATTERN_DECL_NODE(act);
+  PATTERN_DECL_NODE(bn_out);
+  PATTERN_DECL_NODE(act_out);
+};
+
 // The following patterns are used to fuse elewise_add and act
 // formula: act(ele_add(x, y))
 // op: elementwise_add + act
...
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc (new file, mode 100644)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void FuseBatchNormActOneDNNPass::ApplyImpl(Graph *graph) const {
  std::string act_type("relu");
  FuseBatchNormAct(graph, act_type);
}

void FuseBatchNormActOneDNNPass::FuseBatchNormAct(
    Graph *graph, const std::string &act_type) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument(
                 "The input graph of "
                 "FuseBatchNormActOneDNNPass should not be nullptr."));
  FusePassBase::Init("bn_act", graph);

  GraphPatternDetector gpd;
  patterns::BatchNormActOneDNN bn_act_pattern(gpd.mutable_pattern(), "bn_act");
  bn_act_pattern(act_type);

  int found_bn_act_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "Fuse BatchNorm with ReLU activation op.";
    // BN output
    GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
    // ACT output
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_pattern);
    // ops
    GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);

    auto *bn_op = batch_norm->Op();

    if (bn_op->HasAttr("use_mkldnn")) {
      PADDLE_ENFORCE(
          BOOST_GET_CONST(bool, bn_op->GetAttr("use_mkldnn")),
          platform::errors::PreconditionNotMet(
              "The BatchNorm+Act fusion may happen only when oneDNN library "
              "is used."));
    }

    if (bn_op->HasAttr("trainable_statistics")) {
      PADDLE_ENFORCE(
          !BOOST_GET_CONST(bool, bn_op->GetAttr("trainable_statistics")),
          platform::errors::PreconditionNotMet(
              "The BatchNorm+Act fusion may happen only when mean and variance "
              "are not calculated by current batch statistics."));
    }

    if (bn_op->HasAttr("is_test")) {
      PADDLE_ENFORCE(
          BOOST_GET_CONST(bool, bn_op->GetAttr("is_test")),
          platform::errors::PreconditionNotMet(
              "The BatchNorm+Act fusion may happen only during inference."));
    }

    bn_op->SetAttr("use_mkldnn", true);
    bn_op->SetAttr("is_test", true);
    bn_op->SetAttr("fuse_with_relu", true);
    bn_op->SetAttr("trainable_statistics", false);
    bn_op->SetOutput("Y", {act_out->Name()});

    IR_OP_VAR_LINK(batch_norm, act_out);
    GraphSafeRemoveNodes(g, {act, bn_out});
    found_bn_act_count++;
  };

  gpd(graph, handler);
  AddStatis(found_bn_act_count);
  PrettyLogDetail("--- fused %d batch norm with relu activation",
                  found_bn_act_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(batch_norm_act_fuse_pass,
              paddle::framework::ir::FuseBatchNormActOneDNNPass);
REGISTER_PASS_CAPABILITY(batch_norm_act_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("batch_norm", 0)
            .EQ("relu", 0));
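Worth noting: the handler above does no numerical work; it only rewires the graph. The batch_norm op's Y output is pointed at the activation's output variable, fuse_with_relu tells the oneDNN kernel to apply ReLU as a post-op, and the now-detached relu op and intermediate bn_y variable are removed. A toy, dict-based sketch of that rewrite (the data structures are illustrative, not Paddle's IR):

# Toy graph: each op maps its input/output slots to variable names.
graph = {
    "batch_norm": {"X": "x", "Y": "bn_y"},
    "relu": {"X": "bn_y", "Out": "act_y"},
}

def fuse_bn_relu(graph):
    act = graph.pop("relu")        # GraphSafeRemoveNodes: drop the act op
    bn = graph["batch_norm"]
    bn["Y"] = act["Out"]           # SetOutput("Y", {act_out->Name()})
    bn["fuse_with_relu"] = True    # oneDNN applies ReLU as a post-op
    return graph                   # "bn_y" is now unreferenced

print(fuse_bn_relu(graph))
# {'batch_norm': {'X': 'x', 'Y': 'act_y', 'fuse_with_relu': True}}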
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h (new file, mode 100644)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * \brief   Fuse the BatchNorm and activation operators into single OneDNN's
 *          BatchNorm with post-op.
 *
 * \note    Currently only ReLU is supported as an activation function.
 */
class FuseBatchNormActOneDNNPass : public FusePassBase {
 public:
  virtual ~FuseBatchNormActOneDNNPass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;

  void FuseBatchNormAct(ir::Graph *graph, const std::string &act_types) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc (new file, mode 100644)

// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace framework {
namespace ir {

// -------------------------- helper functions --------------------------------

namespace {

using InOutVarNamePair = std::pair<std::string, std::string>;
using OpTypeCountPair = std::pair<std::string, int>;

///
/// @brief      Creates the specified operator and sets up its inputs/outputs.
///
/// @param      prog          The program descriptor to which we add new op.
/// @param[in]  op_type_name  The operator type name.
/// @param[in]  inputs        The vector of input pairs: {input_name, variable
///                           name}
/// @param[in]  outputs       The vector of output pairs {output_name, variable}
/// @param[in]  use_mkldnn    The flag deciding whether or not to set
///                           'use_mkldnn' attribute.
///
/// @return     Returns pointer to the created operator descriptor.
///
OpDesc *CreateOp(ProgramDesc *prog, const std::string &op_type_name,
                 const std::vector<InOutVarNamePair> &inputs,
                 const std::vector<InOutVarNamePair> &outputs,
                 bool use_mkldnn = true) {
  auto op = prog->MutableBlock(0)->AppendOp();
  op->SetType(op_type_name);
  op->SetAttr("use_mkldnn", use_mkldnn);

  for (const auto &input : inputs) {
    op->SetInput(input.first, {input.second});
  }
  for (const auto &output : outputs) {
    op->SetOutput(output.first, {output.second});
  }

  return op;
}

///
/// @brief      Check whether node 'to' is reachable from node 'from' in graph.
///
/// @param[in]  graph  The graph we're checking for reachability.
/// @param[in]  from   The 'from' node name.
/// @param[in]  to     The 'to' node name.
///
/// @return     True if there is connection between nodes 'from' and 'to'.
///
bool TestIsReachable(const Graph &graph, std::string from, std::string to) {
  auto hash = [](const Node *node) -> std::string {
    return node->Name() + std::to_string(node->id());
  };

  auto find_node = [&](const Graph &graph, const std::string &name) -> Node * {
    for (auto &node : GraphTraits::DFS(graph)) {
      if (name == hash(&node)) {
        return &node;
      }
    }
    return nullptr;
  };

  if (from == to) return true;

  std::map<std::string, bool> visited;
  // update the from and to strings to hashed equivs in loop from graph traits
  for (auto &node : GraphTraits::DFS(graph)) {
    auto hashed = hash(&node);
    if (node.Name() == from) {
      from = hashed;
    }
    if (node.Name() == to) {
      to = hashed;
    }
    visited[hashed] = false;
  }

  visited[from] = true;

  std::list<std::string> queue;
  queue.push_back(from);

  while (!queue.empty()) {
    auto cur = find_node(graph, queue.front());
    queue.pop_front();
    if (cur == nullptr) {
      return false;
    }

    for (auto n : cur->outputs) {
      auto hashed_name = hash(n);
      if (hashed_name == to) {
        return true;
      }

      if (!visited[hashed_name]) {
        visited[hashed_name] = true;
        queue.push_back(hashed_name);
      }
    }
  }
  return false;
}

///
/// @brief      Search through graph and counts provided operator occurences.
///
/// @param[in]  graph          The graph we search through.
/// @param[in]  op_type_count  The vector of pairs {op_type_name, op count}
///
/// @note       After going through all graph nodes this function asserts
///             whether counted number for each requested op is as expected.
///
void AssertOpsCount(const Graph &graph,
                    std::vector<OpTypeCountPair> op_type_count) {
  for (auto *node : graph.Nodes()) {
    if (!node->IsOp()) {
      continue;
    }

    const std::string op_type_name = node->Op()->Type();
    auto op_it =
        std::find_if(std::begin(op_type_count), std::end(op_type_count),
                     [op_type_name](const OpTypeCountPair &p) {
                       return op_type_name == p.first;
                     });
    if (op_it != std::end(op_type_count)) {
      op_it->second--;
    }
  }

  for (const OpTypeCountPair &p : op_type_count) {
    EXPECT_EQ(p.second, 0);
  }
}

///
/// @brief      Builds a program descriptor.
///
/// @param[in]  transient_vars   The vector of transient variables names.
/// @param[in]  persistent_vars  The vector of persistent variables names. Those
///                              will have persistable attribute set to true.
///
/// @return     The program descriptor object.
///
ProgramDesc BuildProgramDesc(const std::vector<std::string> &transient_vars,
                             const std::vector<std::string> &persistent_vars) {
  ProgramDesc prog;

  auto add_var_to_prog = [&prog](const std::string &var_name) -> VarDesc * {
    auto var = prog.MutableBlock(0)->Var(var_name);
    var->SetType(proto::VarType::LOD_TENSOR);
    return var;
  };

  for (const auto &v : transient_vars) {
    add_var_to_prog(v);
  }

  for (const auto &v : persistent_vars) {
    auto *var = add_var_to_prog(v);
    var->SetPersistable(true);
  }

  return prog;
}

///
/// @brief      Execute pass on provided graph and perform checks.
///
/// @param      graph                The graph we run pass on.
/// @param[in]  from                 The name of a 'starting' node sequence in a
///                                  graph. This would be used to test for
///                                  correct node connections.
/// @param[in]  to                   The name of a 'ending' node sequence in a
///                                  graph. This would be used to test for
///                                  correct node connections.
/// @param[in]  removed_nodes_count  The number of nodes we expect will be
///                                  removed/fused after pass execution.
/// @param[in]  added_nodes_count    The number of nodes we expect will be
///                                  added after pass execution.
///
void RunPassAndAssert(Graph *graph, const std::string &from,
                      const std::string &to, int removed_nodes_count,
                      int added_nodes_count = 0) {
  EXPECT_TRUE(TestIsReachable(*graph, from, to));
  int original_nodes_num = graph->Nodes().size();
  auto pass = PassRegistry::Instance().Get("batch_norm_act_fuse_pass");
  pass->Apply(graph);
  int current_nodes_num = graph->Nodes().size();

  EXPECT_TRUE(TestIsReachable(*graph, from, to));
  EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
            current_nodes_num);
}

void SetBatchNormAttrs(OpDesc *bn_op, bool is_test = true,
                       bool trainable_stats = true) {
  bn_op->SetAttr("is_test", is_test);
  bn_op->SetAttr("trainable_statistics", trainable_stats);
  bn_op->SetAttr("fuse_with_relu", false);
}

}  // namespace

// ------------------------------ Test cases -----------------------------------

// The below test cases are distinguished by whether following attributes have
// true or false value:
// - is_test
// - trainable_statistics
// The test case name would have only attributes with true value in its name.

TEST(FuseBatchNormActOneDNNPass, ThrowIsTestTrainableStats) {
  auto prog = BuildProgramDesc(
      {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
      {"scale", "bias"});
  auto *bn_op = CreateOp(&prog, "batch_norm",
                         {{"X", "x"},
                          {"Scale", "scale"},
                          {"Bias", "bias"},
                          {"Mean", "m"},
                          {"Variance", "v"}},
                         {{"Y", "bn_y"},
                          {"MeanOut", "m_out"},
                          {"VarianceOut", "var_out"},
                          {"SavedMean", "sm"},
                          {"SavedVariance", "sv"}});
  SetBatchNormAttrs(bn_op, true, true);
  CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // No fusion in this attribute configuration
  constexpr int removed_nodes_count = 0;

  EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
               paddle::platform::EnforceNotMet);
}

TEST(FuseBatchNormActOneDNNPass, FuseIsTest) {
  auto prog =
      BuildProgramDesc({"x", "m", "v", "bn_y", "act_y"}, {"scale", "bias"});
  auto *bn_op = CreateOp(&prog, "batch_norm",
                         {{"X", "x"},
                          {"Scale", "scale"},
                          {"Bias", "bias"},
                          {"Mean", "m"},
                          {"Variance", "v"}},
                         {{"Y", "bn_y"}});
  SetBatchNormAttrs(bn_op, true, false);
  CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  constexpr int removed_nodes_count = 2;

  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
  AssertOpsCount(graph, {{"batch_norm", 1}, {"relu", 0}});

  for (const auto *node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "batch_norm") {
      const auto *op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("fuse_with_relu"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("fuse_with_relu")));
      ASSERT_TRUE(op->HasAttr("trainable_statistics"));
      EXPECT_FALSE(BOOST_GET_CONST(bool, op->GetAttr("trainable_statistics")));
    }
  }
}

TEST(FuseBatchNormActOneDNNPass, ThrowTrainableStats) {
  auto prog = BuildProgramDesc(
      {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
      {"scale", "bias"});
  auto *bn_op = CreateOp(&prog, "batch_norm",
                         {{"X", "x"},
                          {"Scale", "scale"},
                          {"Bias", "bias"},
                          {"Mean", "m"},
                          {"Variance", "v"}},
                         {{"Y", "bn_y"},
                          {"MeanOut", "m_out"},
                          {"VarianceOut", "var_out"},
                          {"SavedMean", "sm"},
                          {"SavedVariance", "sv"}});
  SetBatchNormAttrs(bn_op, false, true);
  CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // No fusion in this attribute configuration
  constexpr int removed_nodes_count = 0;

  EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
               paddle::platform::EnforceNotMet);
}

TEST(FuseBatchNormActOneDNNPass, AllAttrsFalse) {
  auto prog = BuildProgramDesc(
      {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
      {"scale", "bias"});
  auto *bn_op = CreateOp(&prog, "batch_norm",
                         {{"X", "x"},
                          {"Scale", "scale"},
                          {"Bias", "bias"},
                          {"Mean", "m"},
                          {"Variance", "v"}},
                         {{"Y", "bn_y"},
                          {"MeanOut", "m_out"},
                          {"VarianceOut", "var_out"},
                          {"SavedMean", "sm"},
                          {"SavedVariance", "sv"}});
  SetBatchNormAttrs(bn_op, false, false);
  CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // No fusion in this attribute configuration
  constexpr int removed_nodes_count = 0;

  EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
               paddle::platform::EnforceNotMet);
}

TEST(FuseBatchNormActOneDNNPass, ThrowUseMkldnn) {
  auto prog = BuildProgramDesc(
      {"x", "m", "v", "bn_y", "act_y", "m_out", "var_out", "sm", "sv"},
      {"scale", "bias"});
  auto *bn_op = CreateOp(&prog, "batch_norm",
                         {{"X", "x"},
                          {"Scale", "scale"},
                          {"Bias", "bias"},
                          {"Mean", "m"},
                          {"Variance", "v"}},
                         {{"Y", "bn_y"},
                          {"MeanOut", "m_out"},
                          {"VarianceOut", "var_out"},
                          {"SavedMean", "sm"},
                          {"SavedVariance", "sv"}},
                         false);
  SetBatchNormAttrs(bn_op, false, false);
  CreateOp(&prog, "relu", {{"X", "bn_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // No fusion in this attribute configuration
  constexpr int removed_nodes_count = 0;

  EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
               paddle::platform::EnforceNotMet);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(batch_norm_act_fuse_pass);
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -207,6 +207,7 @@ void CpuPassStrategy::EnableMKLDNN() {
              "matmul_transpose_reshape_fuse_pass",  //
              // Disabled due to topology-dependent speed-up
              // "fc_mkldnn_pass",
+             "batch_norm_act_fuse_pass",
              "mkldnn_inplace_pass",  // This pass should be activated after
                                      // fuses
          })) {
...
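Because the pass is registered with version requirements (EQ("batch_norm", 0) and EQ("relu", 0) above), code that depends on the fusion can check compatibility at runtime. A small sketch using the same PassVersionChecker binding the Python test below relies on:

from paddle.fluid.core import PassVersionChecker

# True when the installed batch_norm and relu op versions satisfy the
# combination registered via REGISTER_PASS_CAPABILITY.
print(PassVersionChecker.IsCompatible("batch_norm_act_fuse_pass"))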
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py (new file, mode 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of batch norm and activation."""

from __future__ import print_function

import unittest
import numpy as np

import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
from paddle import enable_static
from paddle.fluid.core import PassVersionChecker

enable_static()


class BnReluOneDnnFusePassTest(InferencePassTest):
    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name="data", shape=[-1, 3, 100, 100], dtype="float32")
            bn_out = fluid.layers.batch_norm(
                input=data, is_test=True, use_global_stats=self.global_stats)
            relu_out = fluid.layers.relu(bn_out)

        self.feeds = {
            "data": np.random.random((1, 3, 100, 100)).astype("float32")
        }
        self.fetch_list = [relu_out]
        self.enable_mkldnn = True

    def set_params(self):
        self.global_stats = False
        self.pass_name = "batch_norm_act_fuse_pass"

    def test_check_output(self):
        self.check_output()
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))


class BnReluGlobalStatsOneDnnFusePassTest(InferencePassTest):
    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name="data", shape=[-1, 3, 100, 100], dtype="float32")
            bn_out = fluid.layers.batch_norm(
                input=data, is_test=True, use_global_stats=self.global_stats)
            relu_out = fluid.layers.relu(bn_out)

        self.feeds = {
            "data": np.random.random((1, 3, 100, 100)).astype("float32")
        }
        self.fetch_list = [relu_out]
        self.enable_mkldnn = True

    def set_params(self):
        self.global_stats = True
        self.pass_name = "batch_norm_act_fuse_pass"

    def test_check_output(self):
        self.check_output()
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))


if __name__ == "__main__":
    unittest.main()
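For reference, a minimal way to run these cases standalone (assuming a Paddle build with oneDNN support and the inference_pass_test helper importable, e.g. when launched from the ir/inference unittests directory):

import unittest

# Loads both fuse-pass test classes; each builds the program, runs CPU
# inference with MKL-DNN enabled, and verifies outputs and pass version.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_mkldnn_batch_norm_act_fuse_pass")
unittest.TextTestRunner(verbosity=2).run(suite)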