Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
edc06c6a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2297
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
edc06c6a
编写于
12月 24, 2020
作者:
J
jakpiase
提交者:
GitHub
12月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added fc + activation fuse pass (currently only gelu, sigmoid and tanh are supported) (#29772)
上级
0e0bb1b9
变更
15
显示空白变更内容
内联
并排
Showing 15 changed files with 744 additions and 3 deletions
+744
-3
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+17
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+21
-0
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
+100
-0
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
+45
-0
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
...uid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
+398
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+2
-1
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+2
-0
paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
...ference/tests/api/analyzer_image_classification_tester.cc
+6
-2
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
...id/inference/tests/api/analyzer_seq_pool1_tester_helper.h
+1
-0
paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
...nference/tests/api/analyzer_transformer_compare_tester.cc
+1
-0
paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
...nference/tests/api/analyzer_transformer_profile_tester.cc
+1
-0
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+2
-0
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+30
-0
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
...ts/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
+116
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
edc06c6a
...
@@ -105,6 +105,7 @@ if(WITH_MKLDNN)
...
@@ -105,6 +105,7 @@ if(WITH_MKLDNN)
pass_library
(
cpu_bfloat16_placement_pass inference DIR mkldnn
)
pass_library
(
cpu_bfloat16_placement_pass inference DIR mkldnn
)
pass_library
(
cpu_bfloat16_pass inference DIR mkldnn
)
pass_library
(
cpu_bfloat16_pass inference DIR mkldnn
)
pass_library
(
fc_mkldnn_pass inference DIR mkldnn
)
pass_library
(
fc_mkldnn_pass inference DIR mkldnn
)
pass_library
(
fc_act_mkldnn_fuse_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_placement_pass base DIR mkldnn
)
pass_library
(
cpu_quantize_placement_pass base DIR mkldnn
)
pass_library
(
cpu_quantize_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_squash_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_squash_pass inference DIR mkldnn
)
...
@@ -155,6 +156,7 @@ if (WITH_MKLDNN)
...
@@ -155,6 +156,7 @@ if (WITH_MKLDNN)
cc_test
(
test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass
)
cc_test
(
test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass
)
cc_test
(
test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
cc_test
(
test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass
)
cc_test
(
test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
)
cc_test
(
test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
)
set
(
TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context
)
set
(
TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
edc06c6a
...
@@ -1017,6 +1017,23 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
...
@@ -1017,6 +1017,23 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
return
fc_out_var
;
return
fc_out_var
;
}
}
// Builds the detection pattern for an fc op whose output feeds a single
// activation op of the given type: fc -> fc_out -> act -> act_out.
// The act node is marked AsIntermediate so it may be removed by the fuser;
// act_out becomes the pattern's output.
PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
  // fc op node.
  auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
  // fc's "Out" var, which must also be an input of the activation op.
  auto *fc_out = pattern->NewNode(fc_out_repr())
                     ->assert_is_op_output("fc", "Out")
                     ->assert_is_op_input(act_type);
  // The activation op itself; intermediate, i.e. erasable after fusion.
  auto *act =
      pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
  // The activation's "Out" var — the subgraph's final output.
  auto *act_out = pattern->NewNode(act_out_repr())
                      ->assert_is_op_output(act_type, "Out")
                      ->AsOutput();

  fc->LinksTo({fc_out});
  act->LinksFrom({fc_out}).LinksTo({act_out});

  return act_out;
}
PDNode
*
patterns
::
Embedding
::
operator
()(
PDNode
*
x
)
{
PDNode
*
patterns
::
Embedding
::
operator
()(
PDNode
*
x
)
{
x
->
assert_is_op_input
(
"lookup_table"
,
"Ids"
);
x
->
assert_is_op_input
(
"lookup_table"
,
"Ids"
);
auto
*
lookup_table_op
=
auto
*
lookup_table_op
=
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
edc06c6a
...
@@ -552,6 +552,27 @@ struct FCMKLDNN : public PatternBase {
...
@@ -552,6 +552,27 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE
(
output
);
PATTERN_DECL_NODE
(
output
);
};
};
//
// \brief   Pattern looking for fc and a directly following activation
//          operator.
//
// \note    Currently only gelu, sigmoid and tanh are supported as an
//          activation function (see FuseFCActOneDNNPass::ApplyImpl).
//          Formula: act(fc(x))
//          Op: fc + act
struct FCActOneDNN : public PatternBase {
  FCActOneDNN(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "fc_act_onednn") {}

  // Builds the pattern for the given activation op type.
  PDNode *operator()(const std::string &act_type);

  // declare operator node's name
  PATTERN_DECL_NODE(fc);
  PATTERN_DECL_NODE(act);
  PATTERN_DECL_NODE(fc_out);
  PATTERN_DECL_NODE(act_out);
};
// Embedding
// Embedding
struct
Embedding
:
public
PatternBase
{
struct
Embedding
:
public
PatternBase
{
Embedding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
Embedding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
...
...
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
0 → 100644
浏览文件 @
edc06c6a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
string
::
PrettyLogDetail
;
// Entry point of the pass: runs the FC+activation fusion once for every
// supported activation type.
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
  // Activations that oneDNN FC can absorb as a post-op.
  const std::vector<std::string> act_types = {"gelu", "tanh", "sigmoid"};

  // Iterate by const reference — the original copied each std::string.
  for (const auto &act_type : act_types) FuseFCAct(graph, act_type);
}
// Finds every fc -> <act_type> pair in the graph and folds the activation
// into the fc op: the activation name is stored in fc's "activation_type"
// attribute, fc is rewired to produce the activation's output variable,
// and the activation op plus the intermediate variable are removed.
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                    const std::string &act_type) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init("fc_act", graph);

  GraphPatternDetector gpd;
  patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
  fc_act_pattern(act_type);

  int found_fc_act_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "Fuse fc with activation op.";
    // FC output
    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
    // ACT output
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
    // ops
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);

    auto *fc_op = fc->Op();
    auto *act_op = act->Op();

    // Fusion is only valid when fc runs on oneDNN: if the attribute is
    // present it must be true, otherwise the pass aborts with an error.
    if (fc_op->HasAttr("use_mkldnn")) {
      PADDLE_ENFORCE(
          BOOST_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
          platform::errors::PreconditionNotMet(
              "The FC+Act fusion may happen only when oneDNN library "
              "is used."));
    }

    // gelu has two variants selected by its "approximate" attribute:
    // the tanh approximation ("gelu_tanh") or the exact erf form
    // ("gelu_erf"). The fc kernel dispatches on this suffixed name.
    if (act_type == "gelu" && act_op->HasAttr("approximate")) {
      bool approximate = BOOST_GET_CONST(bool, act_op->GetAttr("approximate"));
      std::string type = approximate ? "_tanh" : "_erf";
      fc_op->SetAttr("activation_type", act_type + type);
    } else
      fc_op->SetAttr("activation_type", act_type);

    fc_op->SetAttr("use_mkldnn", true);

    // fc now writes straight to the activation's output variable.
    fc_op->SetOutput("Out", {act_out->Name()});

    IR_OP_VAR_LINK(fc, act_out);
    // Remove the activation op and the dead intermediate fc output var.
    GraphSafeRemoveNodes(g, {act, fc_out});
    found_fc_act_count++;
  };

  gpd(graph, handler);
  AddStatis(found_fc_act_count);
  PrettyLogDetail("--- fused %d fc with %s activation", found_fc_act_count,
                  act_type);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
// Register the pass under the name used by pass builders and tests.
REGISTER_PASS(fc_act_mkldnn_fuse_pass,
              paddle::framework::ir::FuseFCActOneDNNPass);
// Op-version compatibility: valid for version 0 of fc and of each
// fusable activation op.
REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("fc", 0)
            .LE("gelu", 0)
            .LE("sigmoid", 0)
            .LE("tanh", 0));
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
0 → 100644
浏览文件 @
edc06c6a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/*
 * \brief Fuse the FC and activation operators into single OneDNN's
 *        FC with post-op.
 *
 * \note Currently only GeLU, sigmoid and tanh are supported as an
 *       activation function.
 */
class FuseFCActOneDNNPass : public FusePassBase {
 public:
  virtual ~FuseFCActOneDNNPass() = default;

 protected:
  // Runs FuseFCAct once per supported activation type.
  void ApplyImpl(ir::Graph *graph) const override;

  // Fuses every fc -> <act_type> pair into a single oneDNN FC with the
  // activation as a post-op. Parameter renamed from 'act_types': each call
  // receives exactly one activation type name (matches the definition).
  void FuseFCAct(ir::Graph *graph, const std::string &act_type) const;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
0 → 100644
浏览文件 @
edc06c6a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// -------------------------- helper functions --------------------------------
namespace
{
using
InOutVarNamePair
=
std
::
pair
<
std
::
string
,
std
::
string
>
;
using
OpTypeCountPair
=
std
::
pair
<
std
::
string
,
int
>
;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param      prog          The program descriptor to which we add new op.
/// @param[in]  op_type_name  The operator type name.
/// @param[in]  inputs        The vector of input pairs: {input_name,
///                           variable name}
/// @param[in]  outputs       The vector of output pairs {output_name,
///                           variable}
/// @param[in]  use_mkldnn    The flag deciding whether or not to set
///                           'use_mkldnn' attribute.
///
/// @return Returns pointer to the created operator descriptor.
///
OpDesc* CreateOp(ProgramDesc* prog, const std::string& op_type_name,
                 const std::vector<InOutVarNamePair>& inputs,
                 const std::vector<InOutVarNamePair>& outputs,
                 bool use_mkldnn = true) {
  auto* op_desc = prog->MutableBlock(0)->AppendOp();
  op_desc->SetType(op_type_name);
  op_desc->SetAttr("use_mkldnn", use_mkldnn);

  for (const auto& in_pair : inputs) {
    op_desc->SetInput(in_pair.first, {in_pair.second});
  }
  for (const auto& out_pair : outputs) {
    op_desc->SetOutput(out_pair.first, {out_pair.second});
  }

  return op_desc;
}
///
/// @brief Check whether node 'to' is reachable from node 'from' in graph.
///
/// @param[in]  graph  The graph we're checking for reachability.
/// @param[in]  from   The 'from' node name.
/// @param[in]  to     The 'to' node name.
///
/// @return True if there is connection between nodes 'from' and 'to'.
///
bool TestIsReachable(const Graph& graph, std::string from, std::string to) {
  // Node names are not unique in a graph, so identify nodes by
  // name + numeric id.
  auto hash = [](const Node* node) -> std::string {
    return node->Name() + std::to_string(node->id());
  };

  // Linear scan over a DFS traversal to map a hashed name back to a node.
  auto find_node = [&](const Graph& graph, const std::string& name) -> Node* {
    for (auto& node : GraphTraits::DFS(graph)) {
      if (name == hash(&node)) {
        return &node;
      }
    }
    return nullptr;
  };

  if (from == to) return true;

  std::map<std::string, bool> visited;
  // update the from and to strings to hashed equivs in loop from graph traits
  for (auto& node : GraphTraits::DFS(graph)) {
    auto hashed = hash(&node);
    if (node.Name() == from) {
      from = hashed;
    }
    if (node.Name() == to) {
      to = hashed;
    }
    visited[hashed] = false;
  }

  visited[from] = true;

  // Breadth-first search from 'from' following output edges.
  std::list<std::string> queue;
  queue.push_back(from);

  while (!queue.empty()) {
    auto cur = find_node(graph, queue.front());
    queue.pop_front();
    if (cur == nullptr) {
      return false;
    }

    for (auto n : cur->outputs) {
      auto hashed_name = hash(n);
      if (hashed_name == to) {
        return true;
      }
      if (!visited[hashed_name]) {
        visited[hashed_name] = true;
        queue.push_back(hashed_name);
      }
    }
  }
  return false;
}
///
/// @brief Search through graph and counts provided operator occurrences.
///
/// @param[in]  graph          The graph we search through.
/// @param[in]  op_type_count  The vector of pairs {op_type_name, op count}
///
/// @note After going through all graph nodes this function asserts
///       whether counted number for each requested op is as expected.
///
void AssertOpsCount(const Graph& graph,
                    std::vector<OpTypeCountPair> op_type_count) {
  for (auto* node : graph.Nodes()) {
    if (!node->IsOp()) {
      continue;
    }

    const std::string op_type_name = node->Op()->Type();
    // Capture the type name by reference — the original captured the
    // std::string by value, copying it for every op node visited.
    auto op_it =
        std::find_if(std::begin(op_type_count), std::end(op_type_count),
                     [&op_type_name](const OpTypeCountPair& p) {
                       return op_type_name == p.first;
                     });
    if (op_it != std::end(op_type_count)) {
      op_it->second--;
    }
  }

  // Every requested op type must have been seen exactly 'count' times.
  for (const OpTypeCountPair& p : op_type_count) {
    EXPECT_EQ(p.second, 0);
  }
}
///
/// @brief Builds a program descriptor.
///
/// @param[in]  transient_vars   The vector of transient variables names.
/// @param[in]  persistent_vars  The vector of persistent variables names.
///                              Those will have persistable attribute set
///                              to true.
///
/// @return The program descriptor object.
///
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
                             const std::vector<std::string>& persistent_vars) {
  ProgramDesc prog;

  // Registers a LOD_TENSOR variable in block 0 and returns its descriptor.
  auto register_var = [&prog](const std::string& name) -> VarDesc* {
    auto* var_desc = prog.MutableBlock(0)->Var(name);
    var_desc->SetType(proto::VarType::LOD_TENSOR);
    return var_desc;
  };

  for (const auto& name : transient_vars) {
    register_var(name);
  }

  for (const auto& name : persistent_vars) {
    register_var(name)->SetPersistable(true);
  }

  return prog;
}
///
/// @brief Execute pass on provided graph and perform checks.
///
/// @param      graph                The graph we run pass on.
/// @param[in]  from                 The name of a 'starting' node sequence
///                                  in a graph, used to test for correct
///                                  node connections.
/// @param[in]  to                   The name of an 'ending' node sequence
///                                  in a graph, used to test for correct
///                                  node connections.
/// @param[in]  removed_nodes_count  The number of nodes we expect will be
///                                  removed/fused after pass execution.
/// @param[in]  added_nodes_count    The number of nodes we expect will be
///                                  added after pass execution.
///
void RunPassAndAssert(Graph* graph, const std::string& from,
                      const std::string& to, int removed_nodes_count,
                      int added_nodes_count = 0) {
  EXPECT_TRUE(TestIsReachable(*graph, from, to));

  int nodes_before = graph->Nodes().size();
  auto pass = PassRegistry::Instance().Get("fc_act_mkldnn_fuse_pass");
  pass->Apply(graph);
  int nodes_after = graph->Nodes().size();

  // Connectivity must survive the fusion, and the node count must match
  // the expected delta exactly.
  EXPECT_TRUE(TestIsReachable(*graph, from, to));
  EXPECT_EQ(nodes_before - removed_nodes_count + added_nodes_count,
            nodes_after);
}
}
// namespace
// ------------------------------ Test cases -----------------------------------
// Verifies that running the pass on an fc op explicitly marked
// use_mkldnn=false makes the pass' precondition enforcement throw.
TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) {
  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  CreateOp(&prog, "fc",
           {
               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
           },
           {{"Out", "fc_y"}}, false);
  CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // No fusion in this attribute configuration
  constexpr int removed_nodes_count = 0;

  EXPECT_THROW(RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count),
               paddle::platform::EnforceNotMet);
}
// Verifies that fc + gelu(approximate=true) is fused into a single fc op
// whose activation_type attribute is "gelu_tanh".
TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  CreateOp(&prog, "fc",
           {
               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
           },
           {{"Out", "fc_y"}});
  auto* act_op =
      CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
  act_op->SetAttr("approximate", true);

  Graph graph(prog);
  // The act op and the intermediate fc_y var disappear after fusion.
  constexpr int removed_nodes_count = 2;

  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
  AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      // EXPECT_EQ prints both values on failure, unlike the original
      // EXPECT_TRUE(act_type.compare("gelu_tanh") == 0).
      EXPECT_EQ(act_type, "gelu_tanh");
    }
  }
}
// Verifies that fc + gelu(approximate=false) is fused into a single fc op
// whose activation_type attribute is "gelu_erf".
TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  CreateOp(&prog, "fc",
           {
               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
           },
           {{"Out", "fc_y"}});
  auto* act_op =
      CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
  act_op->SetAttr("approximate", false);

  Graph graph(prog);
  // The act op and the intermediate fc_y var disappear after fusion.
  constexpr int removed_nodes_count = 2;

  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
  AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      // EXPECT_EQ prints both values on failure, unlike the original
      // EXPECT_TRUE(act_type.compare("gelu_erf") == 0).
      EXPECT_EQ(act_type, "gelu_erf");
    }
  }
}
// Verifies that fc + gelu without an "approximate" attribute is fused with
// the plain "gelu" activation_type.
TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  CreateOp(&prog, "fc",
           {
               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
           },
           {{"Out", "fc_y"}});
  CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // The act op and the intermediate fc_y var disappear after fusion.
  constexpr int removed_nodes_count = 2;

  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
  AssertOpsCount(graph, {{"fc", 1}, {"gelu", 0}});

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      // EXPECT_EQ prints both values on failure, unlike the original
      // EXPECT_TRUE(act_type.compare("gelu") == 0).
      EXPECT_EQ(act_type, "gelu");
    }
  }
}
// Verifies that fc + tanh is fused with activation_type "tanh".
TEST(FuseFCActOneDNNPass, FuseWithTanh) {
  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  CreateOp(&prog, "fc",
           {
               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
           },
           {{"Out", "fc_y"}});
  CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // The act op and the intermediate fc_y var disappear after fusion.
  constexpr int removed_nodes_count = 2;

  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
  AssertOpsCount(graph, {{"fc", 1}, {"tanh", 0}});

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      // EXPECT_EQ prints both values on failure, unlike the original
      // EXPECT_TRUE(act_type.compare("tanh") == 0).
      EXPECT_EQ(act_type, "tanh");
    }
  }
}
// Verifies that fc + sigmoid is fused with activation_type "sigmoid".
TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
  auto prog = BuildProgramDesc({"x", "fc_y", "act_y"}, {"weights", "bias"});
  CreateOp(&prog, "fc",
           {
               {"Input", "x"}, {"Weights", "weights"}, {"Bias", "bias"},
           },
           {{"Out", "fc_y"}});
  CreateOp(&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);

  Graph graph(prog);
  // The act op and the intermediate fc_y var disappear after fusion.
  constexpr int removed_nodes_count = 2;

  RunPassAndAssert(&graph, "x", "act_y", removed_nodes_count);
  AssertOpsCount(graph, {{"fc", 1}, {"sigmoid", 0}});

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "fc") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
      ASSERT_TRUE(op->HasAttr("activation_type"));
      auto act_type =
          BOOST_GET_CONST(std::string, op->GetAttr("activation_type"));
      // EXPECT_EQ prints both values on failure, unlike the original
      // EXPECT_TRUE(act_type.compare("sigmoid") == 0).
      EXPECT_EQ(act_type, "sigmoid");
    }
  }
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fc_act_mkldnn_fuse_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
edc06c6a
...
@@ -206,7 +206,8 @@ void CpuPassStrategy::EnableMKLDNN() {
...
@@ -206,7 +206,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"reshape_transpose_matmul_mkldnn_fuse_pass"
,
//
"reshape_transpose_matmul_mkldnn_fuse_pass"
,
//
"matmul_transpose_reshape_fuse_pass"
,
//
"matmul_transpose_reshape_fuse_pass"
,
//
// Disabled due to topology-dependent speed-up
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
//"fc_mkldnn_pass",
//"fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass"
,
"batch_norm_act_fuse_pass"
,
"mkldnn_inplace_pass"
,
// This pass should be activated after
"mkldnn_inplace_pass"
,
// This pass should be activated after
// fuses
// fuses
...
...
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
浏览文件 @
edc06c6a
...
@@ -206,6 +206,7 @@ void profile(bool use_mkldnn = false) {
...
@@ -206,6 +206,7 @@ void profile(bool use_mkldnn = false) {
"relu"
,
"fc"
};
"relu"
,
"fc"
};
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
...
@@ -262,6 +263,7 @@ void compare(bool use_mkldnn = false) {
...
@@ -262,6 +263,7 @@ void compare(bool use_mkldnn = false) {
"relu"
};
"relu"
};
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
浏览文件 @
edc06c6a
...
@@ -50,8 +50,10 @@ void profile(bool use_mkldnn = false) {
...
@@ -50,8 +50,10 @@ void profile(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
if
(
!
FLAGS_disable_mkldnn_fc
)
if
(
!
FLAGS_disable_mkldnn_fc
)
{
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
...
@@ -83,8 +85,10 @@ void compare(bool use_mkldnn = false) {
...
@@ -83,8 +85,10 @@ void compare(bool use_mkldnn = false) {
SetConfig
(
&
cfg
);
SetConfig
(
&
cfg
);
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
if
(
!
FLAGS_disable_mkldnn_fc
)
if
(
!
FLAGS_disable_mkldnn_fc
)
{
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
浏览文件 @
edc06c6a
...
@@ -163,6 +163,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
...
@@ -163,6 +163,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
->
EnableMKLDNN
();
cfg
->
EnableMKLDNN
();
cfg
->
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
->
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
->
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// time
// time
...
...
paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
浏览文件 @
edc06c6a
...
@@ -25,6 +25,7 @@ void compare(bool use_mkldnn = false) {
...
@@ -25,6 +25,7 @@ void compare(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
浏览文件 @
edc06c6a
...
@@ -26,6 +26,7 @@ void profile(bool use_mkldnn = false) {
...
@@ -26,6 +26,7 @@ void profile(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
浏览文件 @
edc06c6a
...
@@ -86,6 +86,7 @@ void profile(bool use_mkldnn = false) {
...
@@ -86,6 +86,7 @@ void profile(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
// cfg.pass_builder()->TurnOnDebug();
// cfg.pass_builder()->TurnOnDebug();
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
...
@@ -136,6 +137,7 @@ void compare(bool use_mkldnn = false) {
...
@@ -136,6 +137,7 @@ void compare(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
浏览文件 @
edc06c6a
...
@@ -459,6 +459,36 @@ class FCPrimitiveFactory {
...
@@ -459,6 +459,36 @@ class FCPrimitiveFactory {
constexpr
float
placeholder
=
1.0
f
;
// beta
constexpr
float
placeholder
=
1.0
f
;
// beta
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_relu
,
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_relu
,
negative_slope
,
placeholder
);
negative_slope
,
placeholder
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"gelu"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_gelu
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"gelu_tanh"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_gelu_tanh
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"gelu_erf"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_gelu_erf
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"tanh"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_tanh
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"sigmoid"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_logistic
,
alpha
,
beta
);
}
}
attributes
.
set_post_ops
(
post_operations
);
attributes
.
set_post_ops
(
post_operations
);
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
0 → 100644
浏览文件 @
edc06c6a
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of fc and activation."""
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
inference_pass_test
import
InferencePassTest
from
paddle
import
enable_static
from
paddle.fluid.core
import
PassVersionChecker
enable_static
()
class FCGeluTanhOneDnnFusePassTest(InferencePassTest):
    """Checks that fc followed by gelu(approximate=False) is handled by
    fc_act_mkldnn_fuse_pass without changing the computed output.

    NOTE(review): the class name says "Tanh" but approximate=False selects
    the erf-based gelu — the sibling class names appear swapped; confirm
    against the pass's intent before renaming.
    """

    def setUp(self):
        self.set_params()
        # Build the fc -> gelu subgraph the fuse pass is expected to match.
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(
                name="data", shape=[-1, 128, 768], dtype="float32")
            fc_out = fluid.layers.fc(input=data,
                                     size=3072,
                                     num_flatten_dims=2)
            gelu_out = fluid.layers.gelu(fc_out, approximate=False)

        self.feeds = {
            "data": np.random.random((1, 128, 768)).astype("float32")
        }
        self.fetch_list = [gelu_out]
        # The fusion under test is oneDNN-only.
        self.enable_mkldnn = True

    def set_params(self):
        # Name of the pass under test, also used by the version check.
        self.pass_name = "fc_act_mkldnn_fuse_pass"

    def test_check_output(self):
        self.check_output()
        # Consistency fix: the other fc + activation tests also verify pass
        # version compatibility; this test previously omitted the check.
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class FCGeluErfOneDnnFusePassTest(InferencePassTest):
    """Verifies fc + gelu(approximate=True) output under
    fc_act_mkldnn_fuse_pass and the pass's version compatibility.

    NOTE(review): approximate=True is the tanh approximation of gelu even
    though the class name says "Erf" — confirm the intended naming.
    """

    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            x = fluid.data(
                name="data", shape=[-1, 128, 768], dtype="float32")
            fc_output = fluid.layers.fc(
                input=x, size=3072, num_flatten_dims=2)
            act_output = fluid.layers.gelu(fc_output, approximate=True)

        # oneDNN must be enabled for this fusion to apply.
        self.enable_mkldnn = True
        self.fetch_list = [act_output]
        self.feeds = {
            "data": np.random.random((1, 128, 768)).astype("float32")
        }

    def set_params(self):
        self.pass_name = "fc_act_mkldnn_fuse_pass"

    def test_check_output(self):
        self.check_output()
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class FCTanhOneDnnFusePassTest(InferencePassTest):
    """Verifies fc + tanh output under fc_act_mkldnn_fuse_pass and the
    pass's version compatibility."""

    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            x = fluid.data(
                name="data", shape=[-1, 128, 768], dtype="float32")
            fc_output = fluid.layers.fc(
                input=x, size=3072, num_flatten_dims=2)
            act_output = fluid.layers.tanh(fc_output)

        # oneDNN must be enabled for this fusion to apply.
        self.enable_mkldnn = True
        self.fetch_list = [act_output]
        self.feeds = {
            "data": np.random.random((1, 128, 768)).astype("float32")
        }

    def set_params(self):
        self.pass_name = "fc_act_mkldnn_fuse_pass"

    def test_check_output(self):
        self.check_output()
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class FCSigmoidOneDnnFusePassTest(InferencePassTest):
    """Verifies fc + sigmoid output under fc_act_mkldnn_fuse_pass and the
    pass's version compatibility."""

    def setUp(self):
        self.set_params()
        with fluid.program_guard(self.main_program, self.startup_program):
            x = fluid.data(
                name="data", shape=[-1, 128, 768], dtype="float32")
            fc_output = fluid.layers.fc(
                input=x, size=3072, num_flatten_dims=2)
            act_output = fluid.layers.sigmoid(fc_output)

        # oneDNN must be enabled for this fusion to apply.
        self.enable_mkldnn = True
        self.fetch_list = [act_output]
        self.feeds = {
            "data": np.random.random((1, 128, 768)).astype("float32")
        }

    def set_params(self):
        self.pass_name = "fc_act_mkldnn_fuse_pass"

    def test_check_output(self):
        self.check_output()
        self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
if __name__ == "__main__":
    # Discover and run every fc + activation fuse-pass test in this module.
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录