Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7db747d9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7db747d9
编写于
10月 26, 2020
作者:
A
Adam Osewski
提交者:
GitHub
10月 26, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
oneDNN BatchNorm + Act fusion pass. (#27912)
上级
fb7f8529
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
657 addition
and
0 deletion
+657
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+20
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+21
-0
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc
+108
-0
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h
+44
-0
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
...id/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
+382
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py
...ests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py
+79
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
7db747d9
...
...
@@ -110,6 +110,7 @@ if(WITH_MKLDNN)
pass_library
(
cpu_quantize_squash_pass inference DIR mkldnn
)
pass_library
(
reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn
)
pass_library
(
matmul_transpose_reshape_fuse_pass inference DIR mkldnn
)
pass_library
(
batch_norm_act_fuse_pass inference DIR mkldnn
)
endif
()
cc_library
(
fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector
)
...
...
@@ -151,6 +152,7 @@ if (WITH_MKLDNN)
cc_test
(
test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass
)
cc_test
(
test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
cc_test
(
test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
)
set
(
TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context
)
if
(
WITH_GPU
)
set
(
TEST_CONV_BN_PASS_DEPS
${
TEST_CONV_BN_PASS_DEPS
}
depthwise_conv
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
7db747d9
...
...
@@ -1188,6 +1188,26 @@ PDNode *patterns::BatchNormActGrad::operator()(
return
bn_grad
;
}
PDNode
*
patterns
::
BatchNormActOneDNN
::
operator
()(
const
std
::
string
&
act_type
)
{
auto
*
bn_x
=
pattern
->
NewNode
(
bn_in_repr
())
->
AsInput
()
->
assert_is_op_input
(
"batch_norm"
,
"X"
);
auto
*
bn
=
pattern
->
NewNode
(
batch_norm_repr
())
->
assert_is_op
(
"batch_norm"
);
auto
*
bn_out
=
pattern
->
NewNode
(
bn_out_repr
())
->
assert_is_op_output
(
"batch_norm"
,
"Y"
)
->
assert_is_op_input
(
act_type
);
auto
*
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
act_type
)
->
AsIntermediate
();
auto
*
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
assert_is_op_output
(
act_type
,
"Out"
)
->
AsOutput
();
bn
->
LinksFrom
({
bn_x
}).
LinksTo
({
bn_out
});
act
->
LinksFrom
({
bn_out
}).
LinksTo
({
act_out
});
return
act_out
;
}
PDNode
*
patterns
::
ElewiseAddAct
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
ele_x_var
,
std
::
unordered_set
<
std
::
string
>
act_types
)
{
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
7db747d9
...
...
@@ -664,6 +664,27 @@ struct BatchNormActGrad : public PatternBase {
PATTERN_DECL_NODE
(
d_bn_bias
);
};
//
// \brief Pattern looking for batch_norm and a directly following activation
// operator.
//
// \note Currently only ReLU is supported as an activation function.
// Formula: act(bn(x))
// Op: batch_norm + act
struct
BatchNormActOneDNN
:
public
PatternBase
{
BatchNormActOneDNN
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"bn_act_onednn"
)
{}
PDNode
*
operator
()(
const
std
::
string
&
act_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
bn_in
);
PATTERN_DECL_NODE
(
batch_norm
);
PATTERN_DECL_NODE
(
act
);
PATTERN_DECL_NODE
(
bn_out
);
PATTERN_DECL_NODE
(
act_out
);
};
// The following patterns are used to fuse elewise_add and act
// formula: act(ele_add(x, y))
// op: elementwise_add + act
...
...
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.cc
0 → 100644
浏览文件 @
7db747d9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
string
::
PrettyLogDetail
;
void
FuseBatchNormActOneDNNPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
std
::
string
act_type
(
"relu"
);
FuseBatchNormAct
(
graph
,
act_type
);
}
void
FuseBatchNormActOneDNNPass
::
FuseBatchNormAct
(
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"The input graph of "
"FuseBatchNormActOneDNNPass should not be nullptr."
));
FusePassBase
::
Init
(
"bn_act"
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
BatchNormActOneDNN
bn_act_pattern
(
gpd
.
mutable_pattern
(),
"bn_act"
);
bn_act_pattern
(
act_type
);
int
found_bn_act_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"Fuse BatchNorm with ReLU activation op."
;
// BN output
GET_IR_NODE_FROM_SUBGRAPH
(
bn_out
,
bn_out
,
bn_act_pattern
);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH
(
act_out
,
act_out
,
bn_act_pattern
);
// ops
GET_IR_NODE_FROM_SUBGRAPH
(
batch_norm
,
batch_norm
,
bn_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act
,
act
,
bn_act_pattern
);
auto
*
bn_op
=
batch_norm
->
Op
();
if
(
bn_op
->
HasAttr
(
"use_mkldnn"
))
{
PADDLE_ENFORCE
(
BOOST_GET_CONST
(
bool
,
bn_op
->
GetAttr
(
"use_mkldnn"
)),
platform
::
errors
::
PreconditionNotMet
(
"The BatchNorm+Act fusion may happen only when oneDNN library "
"is used."
));
}
if
(
bn_op
->
HasAttr
(
"trainable_statistics"
))
{
PADDLE_ENFORCE
(
!
BOOST_GET_CONST
(
bool
,
bn_op
->
GetAttr
(
"trainable_statistics"
)),
platform
::
errors
::
PreconditionNotMet
(
"The BatchNorm+Act fusion may happen only when mean and variance "
"are not calculated by current batch statistics."
));
}
if
(
bn_op
->
HasAttr
(
"is_test"
))
{
PADDLE_ENFORCE
(
BOOST_GET_CONST
(
bool
,
bn_op
->
GetAttr
(
"is_test"
)),
platform
::
errors
::
PreconditionNotMet
(
"The BatchNorm+Act fusion may happen only during inference."
));
}
bn_op
->
SetAttr
(
"use_mkldnn"
,
true
);
bn_op
->
SetAttr
(
"is_test"
,
true
);
bn_op
->
SetAttr
(
"fuse_with_relu"
,
true
);
bn_op
->
SetAttr
(
"trainable_statistics"
,
false
);
bn_op
->
SetOutput
(
"Y"
,
{
act_out
->
Name
()});
IR_OP_VAR_LINK
(
batch_norm
,
act_out
);
GraphSafeRemoveNodes
(
g
,
{
act
,
bn_out
});
found_bn_act_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_bn_act_count
);
PrettyLogDetail
(
"--- fused %d batch norm with relu activation"
,
found_bn_act_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
batch_norm_act_fuse_pass
,
paddle
::
framework
::
ir
::
FuseBatchNormActOneDNNPass
);
REGISTER_PASS_CAPABILITY
(
batch_norm_act_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"batch_norm"
,
0
)
.
EQ
(
"relu"
,
0
));
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h
0 → 100644
浏览文件 @
7db747d9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/*
* \brief Fuse the BatchNorm and activation operators into single OneDNN's
* BatchNorm with post-op.
*
* \note Currently only ReLU is supported as an activation function.
*/
class
FuseBatchNormActOneDNNPass
:
public
FusePassBase
{
public:
virtual
~
FuseBatchNormActOneDNNPass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
FuseBatchNormAct
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_types
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass_tester.cc
0 → 100644
浏览文件 @
7db747d9
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/batch_norm_act_fuse_pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// -------------------------- helper functions --------------------------------
namespace
{
using
InOutVarNamePair
=
std
::
pair
<
std
::
string
,
std
::
string
>
;
using
OpTypeCountPair
=
std
::
pair
<
std
::
string
,
int
>
;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param prog The program descriptor to which we add new op.
/// @param[in] op_type_name The operator type name.
/// @param[in] inputs The vector of input pairs: {input_name, variable
/// name}
/// @param[in] outputs The vector of output pairs {output_name, variable}
/// @param[in] use_mkldnn The flag deciding whether or not to set
/// 'use_mkldnn' attribute.
///
/// @return Returns pointer to the created operator descriptor.
///
OpDesc
*
CreateOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
op_type_name
,
const
std
::
vector
<
InOutVarNamePair
>&
inputs
,
const
std
::
vector
<
InOutVarNamePair
>&
outputs
,
bool
use_mkldnn
=
true
)
{
auto
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
op_type_name
);
op
->
SetAttr
(
"use_mkldnn"
,
use_mkldnn
);
for
(
const
auto
&
input
:
inputs
)
{
op
->
SetInput
(
input
.
first
,
{
input
.
second
});
}
for
(
const
auto
&
output
:
outputs
)
{
op
->
SetOutput
(
output
.
first
,
{
output
.
second
});
}
return
op
;
}
///
/// @brief Check whether node 'to' is reachable from node 'from' in graph.
///
/// @param[in] graph The graph we're checking for reachability.
/// @param[in] from The 'from' node name.
/// @param[in] to The 'to' node name.
///
/// @return True if there is connection between nodes 'from' and 'to'.
///
bool
TestIsReachable
(
const
Graph
&
graph
,
std
::
string
from
,
std
::
string
to
)
{
auto
hash
=
[](
const
Node
*
node
)
->
std
::
string
{
return
node
->
Name
()
+
std
::
to_string
(
node
->
id
());
};
auto
find_node
=
[
&
](
const
Graph
&
graph
,
const
std
::
string
&
name
)
->
Node
*
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
if
(
name
==
hash
(
&
node
))
{
return
&
node
;
}
}
return
nullptr
;
};
if
(
from
==
to
)
return
true
;
std
::
map
<
std
::
string
,
bool
>
visited
;
// update the from and to strings to hashed equivs in loop from graph traits
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
auto
hashed
=
hash
(
&
node
);
if
(
node
.
Name
()
==
from
)
{
from
=
hashed
;
}
if
(
node
.
Name
()
==
to
)
{
to
=
hashed
;
}
visited
[
hashed
]
=
false
;
}
visited
[
from
]
=
true
;
std
::
list
<
std
::
string
>
queue
;
queue
.
push_back
(
from
);
while
(
!
queue
.
empty
())
{
auto
cur
=
find_node
(
graph
,
queue
.
front
());
queue
.
pop_front
();
if
(
cur
==
nullptr
)
{
return
false
;
}
for
(
auto
n
:
cur
->
outputs
)
{
auto
hashed_name
=
hash
(
n
);
if
(
hashed_name
==
to
)
{
return
true
;
}
if
(
!
visited
[
hashed_name
])
{
visited
[
hashed_name
]
=
true
;
queue
.
push_back
(
hashed_name
);
}
}
}
return
false
;
}
///
/// @brief Search through graph and counts provided operator occurences.
///
/// @param[in] graph The graph we search through.
/// @param[in] op_type_count The vector of pairs {op_type_name, op count}
///
/// @note After going through all graph nodes this function asserts
/// whether counted number for each requested op is as expected.
///
void
AssertOpsCount
(
const
Graph
&
graph
,
std
::
vector
<
OpTypeCountPair
>
op_type_count
)
{
for
(
auto
*
node
:
graph
.
Nodes
())
{
if
(
!
node
->
IsOp
())
{
continue
;
}
const
std
::
string
op_type_name
=
node
->
Op
()
->
Type
();
auto
op_it
=
std
::
find_if
(
std
::
begin
(
op_type_count
),
std
::
end
(
op_type_count
),
[
op_type_name
](
const
OpTypeCountPair
&
p
)
{
return
op_type_name
==
p
.
first
;
});
if
(
op_it
!=
std
::
end
(
op_type_count
))
{
op_it
->
second
--
;
}
}
for
(
const
OpTypeCountPair
&
p
:
op_type_count
)
{
EXPECT_EQ
(
p
.
second
,
0
);
}
}
///
/// @brief Builds a program descriptor.
///
/// @param[in] transient_vars The vector of transient variables names.
/// @param[in] persistent_vars The vector of persistent variables names. Those
/// will have persistable attribute set to true.
///
/// @return The program descriptor object.
///
ProgramDesc
BuildProgramDesc
(
const
std
::
vector
<
std
::
string
>&
transient_vars
,
const
std
::
vector
<
std
::
string
>&
persistent_vars
)
{
ProgramDesc
prog
;
auto
add_var_to_prog
=
[
&
prog
](
const
std
::
string
&
var_name
)
->
VarDesc
*
{
auto
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
var_name
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
return
var
;
};
for
(
const
auto
&
v
:
transient_vars
)
{
add_var_to_prog
(
v
);
}
for
(
const
auto
&
v
:
persistent_vars
)
{
auto
*
var
=
add_var_to_prog
(
v
);
var
->
SetPersistable
(
true
);
}
return
prog
;
}
///
/// @brief Execute pass on provided graph and perform checks.
///
/// @param graph The graph we run pass on.
/// @param[in] from The name of a 'starting' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] to The name of a 'ending' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] removed_nodes_count The number of nodes we expect will be
/// removed/fused after pass execution.
/// @param[in] added_nodes_count The number of nodes we expect will be
/// added after pass execution.
///
void
RunPassAndAssert
(
Graph
*
graph
,
const
std
::
string
&
from
,
const
std
::
string
&
to
,
int
removed_nodes_count
,
int
added_nodes_count
=
0
)
{
EXPECT_TRUE
(
TestIsReachable
(
*
graph
,
from
,
to
));
int
original_nodes_num
=
graph
->
Nodes
().
size
();
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"batch_norm_act_fuse_pass"
);
pass
->
Apply
(
graph
);
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
TestIsReachable
(
*
graph
,
from
,
to
));
EXPECT_EQ
(
original_nodes_num
-
removed_nodes_count
+
added_nodes_count
,
current_nodes_num
);
}
void
SetBatchNormAttrs
(
OpDesc
*
bn_op
,
bool
is_test
=
true
,
bool
trainable_stats
=
true
)
{
bn_op
->
SetAttr
(
"is_test"
,
is_test
);
bn_op
->
SetAttr
(
"trainable_statistics"
,
trainable_stats
);
bn_op
->
SetAttr
(
"fuse_with_relu"
,
false
);
}
}
// namespace
// ------------------------------ Test cases -----------------------------------
// The below test cases are distinguished by whether following attributes have
// true or false value:
// - is_test
// - trainable_statistics
// The test case name would have only attributes with true value in its name.
TEST
(
FuseBatchNormActOneDNNPass
,
ThrowIsTestTrainableStats
)
{
auto
prog
=
BuildProgramDesc
(
{
"x"
,
"m"
,
"v"
,
"bn_y"
,
"act_y"
,
"m_out"
,
"var_out"
,
"sm"
,
"sv"
},
{
"scale"
,
"bias"
});
auto
*
bn_op
=
CreateOp
(
&
prog
,
"batch_norm"
,
{{
"X"
,
"x"
},
{
"Scale"
,
"scale"
},
{
"Bias"
,
"bias"
},
{
"Mean"
,
"m"
},
{
"Variance"
,
"v"
}},
{{
"Y"
,
"bn_y"
},
{
"MeanOut"
,
"m_out"
},
{
"VarianceOut"
,
"var_out"
},
{
"SavedMean"
,
"sm"
},
{
"SavedVariance"
,
"sv"
}});
SetBatchNormAttrs
(
bn_op
,
true
,
true
);
CreateOp
(
&
prog
,
"relu"
,
{{
"X"
,
"bn_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
EXPECT_THROW
(
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
),
paddle
::
platform
::
EnforceNotMet
);
}
TEST
(
FuseBatchNormActOneDNNPass
,
FuseIsTest
)
{
auto
prog
=
BuildProgramDesc
({
"x"
,
"m"
,
"v"
,
"bn_y"
,
"act_y"
},
{
"scale"
,
"bias"
});
auto
*
bn_op
=
CreateOp
(
&
prog
,
"batch_norm"
,
{{
"X"
,
"x"
},
{
"Scale"
,
"scale"
},
{
"Bias"
,
"bias"
},
{
"Mean"
,
"m"
},
{
"Variance"
,
"v"
}},
{{
"Y"
,
"bn_y"
}});
SetBatchNormAttrs
(
bn_op
,
true
,
false
);
CreateOp
(
&
prog
,
"relu"
,
{{
"X"
,
"bn_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
constexpr
int
removed_nodes_count
=
2
;
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
);
AssertOpsCount
(
graph
,
{{
"batch_norm"
,
1
},
{
"relu"
,
0
}});
for
(
const
auto
*
node
:
graph
.
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"batch_norm"
)
{
const
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"use_mkldnn"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"use_mkldnn"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"fuse_with_relu"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"fuse_with_relu"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"trainable_statistics"
));
EXPECT_FALSE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"trainable_statistics"
)));
}
}
}
TEST
(
FuseBatchNormActOneDNNPass
,
ThrowTrainableStats
)
{
auto
prog
=
BuildProgramDesc
(
{
"x"
,
"m"
,
"v"
,
"bn_y"
,
"act_y"
,
"m_out"
,
"var_out"
,
"sm"
,
"sv"
},
{
"scale"
,
"bias"
});
auto
*
bn_op
=
CreateOp
(
&
prog
,
"batch_norm"
,
{{
"X"
,
"x"
},
{
"Scale"
,
"scale"
},
{
"Bias"
,
"bias"
},
{
"Mean"
,
"m"
},
{
"Variance"
,
"v"
}},
{{
"Y"
,
"bn_y"
},
{
"MeanOut"
,
"m_out"
},
{
"VarianceOut"
,
"var_out"
},
{
"SavedMean"
,
"sm"
},
{
"SavedVariance"
,
"sv"
}});
SetBatchNormAttrs
(
bn_op
,
false
,
true
);
CreateOp
(
&
prog
,
"relu"
,
{{
"X"
,
"bn_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
EXPECT_THROW
(
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
),
paddle
::
platform
::
EnforceNotMet
);
}
TEST
(
FuseBatchNormActOneDNNPass
,
AllAttrsFalse
)
{
auto
prog
=
BuildProgramDesc
(
{
"x"
,
"m"
,
"v"
,
"bn_y"
,
"act_y"
,
"m_out"
,
"var_out"
,
"sm"
,
"sv"
},
{
"scale"
,
"bias"
});
auto
*
bn_op
=
CreateOp
(
&
prog
,
"batch_norm"
,
{{
"X"
,
"x"
},
{
"Scale"
,
"scale"
},
{
"Bias"
,
"bias"
},
{
"Mean"
,
"m"
},
{
"Variance"
,
"v"
}},
{{
"Y"
,
"bn_y"
},
{
"MeanOut"
,
"m_out"
},
{
"VarianceOut"
,
"var_out"
},
{
"SavedMean"
,
"sm"
},
{
"SavedVariance"
,
"sv"
}});
SetBatchNormAttrs
(
bn_op
,
false
,
false
);
CreateOp
(
&
prog
,
"relu"
,
{{
"X"
,
"bn_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
EXPECT_THROW
(
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
),
paddle
::
platform
::
EnforceNotMet
);
}
TEST
(
FuseBatchNormActOneDNNPass
,
ThrowUseMkldnn
)
{
auto
prog
=
BuildProgramDesc
(
{
"x"
,
"m"
,
"v"
,
"bn_y"
,
"act_y"
,
"m_out"
,
"var_out"
,
"sm"
,
"sv"
},
{
"scale"
,
"bias"
});
auto
*
bn_op
=
CreateOp
(
&
prog
,
"batch_norm"
,
{{
"X"
,
"x"
},
{
"Scale"
,
"scale"
},
{
"Bias"
,
"bias"
},
{
"Mean"
,
"m"
},
{
"Variance"
,
"v"
}},
{{
"Y"
,
"bn_y"
},
{
"MeanOut"
,
"m_out"
},
{
"VarianceOut"
,
"var_out"
},
{
"SavedMean"
,
"sm"
},
{
"SavedVariance"
,
"sv"
}},
false
);
SetBatchNormAttrs
(
bn_op
,
false
,
false
);
CreateOp
(
&
prog
,
"relu"
,
{{
"X"
,
"bn_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
EXPECT_THROW
(
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
),
paddle
::
platform
::
EnforceNotMet
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
batch_norm_act_fuse_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
7db747d9
...
...
@@ -207,6 +207,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"matmul_transpose_reshape_fuse_pass"
,
//
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
"batch_norm_act_fuse_pass"
,
"mkldnn_inplace_pass"
,
// This pass should be activated after
// fuses
}))
{
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_batch_norm_act_fuse_pass.py
0 → 100644
浏览文件 @
7db747d9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of batch norm and activation."""
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
inference_pass_test
import
InferencePassTest
from
paddle
import
enable_static
from
paddle.fluid.core
import
PassVersionChecker
enable_static
()
class
BnReluOneDnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
bn_out
=
fluid
.
layers
.
batch_norm
(
input
=
data
,
is_test
=
True
,
use_global_stats
=
self
.
global_stats
)
relu_out
=
fluid
.
layers
.
relu
(
bn_out
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
relu_out
]
self
.
enable_mkldnn
=
True
def
set_params
(
self
):
self
.
global_stats
=
False
self
.
pass_name
=
"batch_norm_act_fuse_pass"
def
test_check_output
(
self
):
self
.
check_output
()
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
self
.
pass_name
))
class
BnReluGlobalStatsOneDnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
bn_out
=
fluid
.
layers
.
batch_norm
(
input
=
data
,
is_test
=
True
,
use_global_stats
=
self
.
global_stats
)
relu_out
=
fluid
.
layers
.
relu
(
bn_out
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
relu_out
]
self
.
enable_mkldnn
=
True
def
set_params
(
self
):
self
.
global_stats
=
True
self
.
pass_name
=
"batch_norm_act_fuse_pass"
def
test_check_output
(
self
):
self
.
check_output
()
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
self
.
pass_name
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录