Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
3c117179
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3c117179
编写于
9月 17, 2020
作者:
S
Shang Zhizhou
提交者:
GitHub
9月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add op version checker to ir passes (#27329)
上级
515efe42
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
268 addition
and
1 deletion
+268
-1
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
...uid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
+6
-0
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc
...mework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc
+8
-1
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
...e/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
+19
-0
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
.../framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
+7
-0
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
...mework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
+6
-0
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
...ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+7
-0
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc
...e/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc
+5
-0
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc
.../framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc
+8
-0
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+11
-0
paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
...e/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
+7
-0
paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
+6
-0
paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
+7
-0
python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
...unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
+171
-0
未找到文件。
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc
浏览文件 @
3c117179
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
...
@@ -334,3 +335,8 @@ void EmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
REGISTER_PASS
(
embedding_eltwise_layernorm_fuse_pass
,
REGISTER_PASS
(
embedding_eltwise_layernorm_fuse_pass
,
paddle
::
framework
::
ir
::
EmbeddingEltwiseLayerNormFusePass
);
paddle
::
framework
::
ir
::
EmbeddingEltwiseLayerNormFusePass
);
REGISTER_PASS_CAPABILITY
(
embedding_eltwise_layernorm_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"lookup_table"
,
0
)
.
EQ
(
"elementweise_add"
,
0
));
paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass_tester.cc
浏览文件 @
3c117179
...
@@ -16,12 +16,13 @@ limitations under the License. */
...
@@ -16,12 +16,13 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
TEST
(
SkipLayerN
ormFusePass
,
basic
)
{
TEST
(
EmbeddingElewiseLayern
ormFusePass
,
basic
)
{
// inputs operator output
// inputs operator output
// --------------------------------------------------------------------
// --------------------------------------------------------------------
// (x, y) elementwise_add -> elementwise_out
// (x, y) elementwise_add -> elementwise_out
...
@@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) {
...
@@ -91,6 +92,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse"
));
"The number of fusion nodes does not meet expectations after fuse"
));
}
}
TEST
(
EmbeddingElewiseLayernormFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"embedding_eltwise_layernorm_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
浏览文件 @
3c117179
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -84,6 +85,19 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG
(
3
)
<<
"do not perform "
+
type
()
+
"+bias fuse"
;
VLOG
(
3
)
<<
"do not perform "
+
type
()
+
"+bias fuse"
;
return
;
return
;
}
}
if
(
conv
->
Op
()
->
HasAttr
(
"dilations"
))
{
auto
dilations
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
conv
->
Op
()
->
GetAttr
(
"dilations"
));
for
(
const
auto
&
d
:
dilations
)
{
if
(
d
!=
1
)
{
LOG
(
WARNING
)
<<
"dilation conv not supported in MKLDNN, fuse not apply "
<<
"and set conv attribute use_mkldnn = false"
;
conv
->
Op
()
->
SetAttr
(
"use_mkldnn"
,
false
);
return
;
}
}
}
auto
*
eltwise_bias_tensor
=
auto
*
eltwise_bias_tensor
=
scope
->
FindVar
(
eltwise_bias
->
Name
())
->
GetMutable
<
LoDTensor
>
();
scope
->
FindVar
(
eltwise_bias
->
Name
())
->
GetMutable
<
LoDTensor
>
();
...
@@ -151,3 +165,8 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
...
@@ -151,3 +165,8 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
paddle
::
framework
::
ir
::
Conv2DTransposeBiasFusePass
);
paddle
::
framework
::
ir
::
Conv2DTransposeBiasFusePass
);
REGISTER_PASS
(
conv3d_bias_mkldnn_fuse_pass
,
REGISTER_PASS
(
conv3d_bias_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
Conv3DBiasFusePass
);
paddle
::
framework
::
ir
::
Conv3DBiasFusePass
);
REGISTER_PASS_CAPABILITY
(
conv_bias_mkldnn_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"conv2d"
,
0
)
.
EQ
(
"elementwise_add"
,
0
));
paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
浏览文件 @
3c117179
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) {
...
@@ -149,6 +150,12 @@ TEST(ConvBiasFusePass, conv2d_transpose) {
ASSERT_EQ
(
pass
.
type
(),
std
::
string
(
"conv2d_transpose"
));
ASSERT_EQ
(
pass
.
type
(),
std
::
string
(
"conv2d_transpose"
));
}
}
TEST
(
ConvBiasFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"conv_bias_mkldnn_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
3c117179
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <memory>
#include <memory>
#include <tuple>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
...
@@ -341,3 +342,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
REGISTER_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
,
REGISTER_PASS
(
conv_elementwise_add_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
ResidualConnectionMKLDNNFusePass
);
paddle
::
framework
::
ir
::
ResidualConnectionMKLDNNFusePass
);
REGISTER_PASS_CAPABILITY
(
conv_elementwise_add_mkldnn_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"conv2d"
,
0
)
.
EQ
(
"elementwise_add"
,
0
));
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
浏览文件 @
3c117179
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
...
@@ -267,6 +268,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
AssertOpsCount
(
graph
,
2
,
1
);
AssertOpsCount
(
graph
,
2
,
1
);
}
}
TEST
(
ConvElementwiseAddMKLDNNFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"conv_elementwise_add_mkldnn_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc
浏览文件 @
3c117179
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -57,3 +58,7 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS
(
depthwise_conv_mkldnn_pass
,
REGISTER_PASS
(
depthwise_conv_mkldnn_pass
,
paddle
::
framework
::
ir
::
DepthwiseConvMKLDNNPass
);
paddle
::
framework
::
ir
::
DepthwiseConvMKLDNNPass
);
REGISTER_PASS_CAPABILITY
(
depthwise_conv_mkldnn_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"depthwise_conv2d"
,
0
));
paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc
浏览文件 @
3c117179
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
ir
{
namespace
ir
{
...
@@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() {
...
@@ -70,6 +72,12 @@ ProgramDesc BuildProgramDesc() {
return
prog
;
return
prog
;
}
}
TEST
(
DepthwiseConvMKLDNNPass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"depthwise_conv_mkldnn_pass"
));
}
TEST
(
DepthwiseConvMKLDNNPass
,
basic
)
{
TEST
(
DepthwiseConvMKLDNNPass
,
basic
)
{
auto
prog
=
BuildProgramDesc
();
auto
prog
=
BuildProgramDesc
();
...
...
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
浏览文件 @
3c117179
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -707,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass,
...
@@ -707,3 +708,13 @@ REGISTER_PASS(multihead_matmul_fuse_pass,
REGISTER_PASS
(
multihead_matmul_fuse_pass_v2
,
REGISTER_PASS
(
multihead_matmul_fuse_pass_v2
,
paddle
::
framework
::
ir
::
MultiHeadMatmulV2FusePass
);
paddle
::
framework
::
ir
::
MultiHeadMatmulV2FusePass
);
REGISTER_PASS_CAPABILITY
(
multihead_matmul_fuse_pass_v2
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"mul"
,
0
)
.
EQ
(
"elementwise_add"
,
0
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"transpose2"
,
0
)
.
EQ
(
"scale"
,
0
)
.
EQ
(
"matmul"
,
0
)
.
EQ
(
"softmax"
,
0
));
paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc
浏览文件 @
3c117179
...
@@ -12,6 +12,7 @@ limitations under the License. */
...
@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) {
...
@@ -133,6 +134,12 @@ TEST(MultiHeadMatmulFusePass, basic) {
num_fused_nodes_after
));
num_fused_nodes_after
));
}
}
TEST
(
MultiHeadMatmulFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"multihead_matmul_fuse_pass_v2"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/ir/skip_layernorm_fuse_pass.cc
浏览文件 @
3c117179
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
...
@@ -180,3 +181,8 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS
(
skip_layernorm_fuse_pass
,
REGISTER_PASS
(
skip_layernorm_fuse_pass
,
paddle
::
framework
::
ir
::
SkipLayerNormFusePass
);
paddle
::
framework
::
ir
::
SkipLayerNormFusePass
);
REGISTER_PASS_CAPABILITY
(
skip_layernorm_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"elementwise_add"
,
0
)
.
EQ
(
"layer_norm"
,
0
));
paddle/fluid/framework/ir/skip_layernorm_fuse_pass_tester.cc
浏览文件 @
3c117179
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) {
...
@@ -54,6 +55,12 @@ TEST(SkipLayerNormFusePass, basic) {
"The number of fusion nodes does not meet expectations after fuse"
));
"The number of fusion nodes does not meet expectations after fuse"
));
}
}
TEST
(
SkipLayerNormFusePass
,
pass_op_version_check
)
{
ASSERT_TRUE
(
paddle
::
framework
::
compatible
::
PassVersionCheckerRegistrar
::
GetInstance
()
.
IsPassCompatible
(
"skip_layernorm_fuse_pass"
));
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py
0 → 100644
浏览文件 @
3c117179
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
AnalysisConfig
"""Test for fusion of conv and bias."""
#padding SAME
class
ConvBiasMkldnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"SAME"
,
bias_attr
=
param_attr
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#padding VALID
class
ConvBiasMkldnnFusePassTest1
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"VALID"
,
bias_attr
=
param_attr
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#padding number
class
ConvBiasMkldnnFusePassTest2
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
[
2
,
4
,
6
,
8
],
bias_attr
=
param_attr
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#dilation not supported yet, just print warning log and does not fuse
class
ConvBiasMkldnnFusePassTest3
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"VALID"
,
dilation
=
2
,
groups
=
3
,
bias_attr
=
param_attr
,
use_cudnn
=
False
,
act
=
"softmax"
,
data_format
=
"NCHW"
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
#all conv params except for dilation
class
ConvBiasMkldnnFusePassTest4
(
InferencePassTest
):
def
setUp
(
self
):
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
conv_out
=
fluid
.
layers
.
conv2d
(
input
=
data
,
num_filters
=
3
,
filter_size
=
3
,
padding
=
"VALID"
,
groups
=
3
,
bias_attr
=
param_attr
,
use_cudnn
=
False
,
act
=
"softmax"
,
data_format
=
"NCHW"
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
3
,
100
,
100
)).
astype
(
"float32"
)
}
self
.
fetch_list
=
[
conv_out
]
self
.
enable_mkldnn
=
True
def
test_check_output
(
self
):
use_gpu
=
False
self
.
check_output_with_option
(
use_gpu
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录