Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
edc06c6a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
edc06c6a
编写于
12月 24, 2020
作者:
J
jakpiase
提交者:
GitHub
12月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added fc + activation fuse pass (currently only gelu, sigmoid and tanh are supported) (#29772)
上级
0e0bb1b9
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
744 addition
and
3 deletion
+744
-3
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+17
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+21
-0
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
+100
-0
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
+45
-0
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
...uid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
+398
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+2
-1
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+2
-0
paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
...ference/tests/api/analyzer_image_classification_tester.cc
+6
-2
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
...id/inference/tests/api/analyzer_seq_pool1_tester_helper.h
+1
-0
paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
...nference/tests/api/analyzer_transformer_compare_tester.cc
+1
-0
paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
...nference/tests/api/analyzer_transformer_profile_tester.cc
+1
-0
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+2
-0
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+30
-0
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
...ts/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
+116
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
edc06c6a
...
@@ -105,6 +105,7 @@ if(WITH_MKLDNN)
...
@@ -105,6 +105,7 @@ if(WITH_MKLDNN)
pass_library
(
cpu_bfloat16_placement_pass inference DIR mkldnn
)
pass_library
(
cpu_bfloat16_placement_pass inference DIR mkldnn
)
pass_library
(
cpu_bfloat16_pass inference DIR mkldnn
)
pass_library
(
cpu_bfloat16_pass inference DIR mkldnn
)
pass_library
(
fc_mkldnn_pass inference DIR mkldnn
)
pass_library
(
fc_mkldnn_pass inference DIR mkldnn
)
pass_library
(
fc_act_mkldnn_fuse_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_placement_pass base DIR mkldnn
)
pass_library
(
cpu_quantize_placement_pass base DIR mkldnn
)
pass_library
(
cpu_quantize_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_squash_pass inference DIR mkldnn
)
pass_library
(
cpu_quantize_squash_pass inference DIR mkldnn
)
...
@@ -155,6 +156,7 @@ if (WITH_MKLDNN)
...
@@ -155,6 +156,7 @@ if (WITH_MKLDNN)
cc_test
(
test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass
)
cc_test
(
test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass
)
cc_test
(
test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
cc_test
(
test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass
)
cc_test
(
test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass
)
cc_test
(
test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
)
cc_test
(
test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass
)
set
(
TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context
)
set
(
TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
edc06c6a
...
@@ -1017,6 +1017,23 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
...
@@ -1017,6 +1017,23 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
return
fc_out_var
;
return
fc_out_var
;
}
}
PDNode
*
patterns
::
FCActOneDNN
::
operator
()(
const
std
::
string
&
act_type
)
{
auto
*
fc
=
pattern
->
NewNode
(
fc_repr
())
->
assert_is_op
(
"fc"
);
auto
*
fc_out
=
pattern
->
NewNode
(
fc_out_repr
())
->
assert_is_op_output
(
"fc"
,
"Out"
)
->
assert_is_op_input
(
act_type
);
auto
*
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
act_type
)
->
AsIntermediate
();
auto
*
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
assert_is_op_output
(
act_type
,
"Out"
)
->
AsOutput
();
fc
->
LinksTo
({
fc_out
});
act
->
LinksFrom
({
fc_out
}).
LinksTo
({
act_out
});
return
act_out
;
}
PDNode
*
patterns
::
Embedding
::
operator
()(
PDNode
*
x
)
{
PDNode
*
patterns
::
Embedding
::
operator
()(
PDNode
*
x
)
{
x
->
assert_is_op_input
(
"lookup_table"
,
"Ids"
);
x
->
assert_is_op_input
(
"lookup_table"
,
"Ids"
);
auto
*
lookup_table_op
=
auto
*
lookup_table_op
=
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
edc06c6a
...
@@ -552,6 +552,27 @@ struct FCMKLDNN : public PatternBase {
...
@@ -552,6 +552,27 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE
(
output
);
PATTERN_DECL_NODE
(
output
);
};
};
//
// \brief Pattern looking for fc and a directly following activation
// operator.
//
// \note Currently only gelu and tanh are supported as an activation
// function.
// Formula: act(fc(x))
// Op: fc + act
struct
FCActOneDNN
:
public
PatternBase
{
FCActOneDNN
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fc_act_onednn"
)
{}
PDNode
*
operator
()(
const
std
::
string
&
act_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
fc
);
PATTERN_DECL_NODE
(
act
);
PATTERN_DECL_NODE
(
fc_out
);
PATTERN_DECL_NODE
(
act_out
);
};
// Embedding
// Embedding
struct
Embedding
:
public
PatternBase
{
struct
Embedding
:
public
PatternBase
{
Embedding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
Embedding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
...
...
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
0 → 100644
浏览文件 @
edc06c6a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
string
::
PrettyLogDetail
;
void
FuseFCActOneDNNPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
std
::
vector
<
std
::
string
>
act_types
=
{
"gelu"
,
"tanh"
,
"sigmoid"
};
for
(
std
::
string
act_type
:
act_types
)
FuseFCAct
(
graph
,
act_type
);
}
void
FuseFCActOneDNNPass
::
FuseFCAct
(
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
FusePassBase
::
Init
(
"fc_act"
,
graph
);
GraphPatternDetector
gpd
;
patterns
::
FCActOneDNN
fc_act_pattern
(
gpd
.
mutable_pattern
(),
"fc_act"
);
fc_act_pattern
(
act_type
);
int
found_fc_act_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
VLOG
(
4
)
<<
"Fuse fc with activation op."
;
// FC output
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
fc_out
,
fc_act_pattern
);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH
(
act_out
,
act_out
,
fc_act_pattern
);
// ops
GET_IR_NODE_FROM_SUBGRAPH
(
fc
,
fc
,
fc_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act
,
act
,
fc_act_pattern
);
auto
*
fc_op
=
fc
->
Op
();
auto
*
act_op
=
act
->
Op
();
if
(
fc_op
->
HasAttr
(
"use_mkldnn"
))
{
PADDLE_ENFORCE
(
BOOST_GET_CONST
(
bool
,
fc_op
->
GetAttr
(
"use_mkldnn"
)),
platform
::
errors
::
PreconditionNotMet
(
"The FC+Act fusion may happen only when oneDNN library "
"is used."
));
}
if
(
act_type
==
"gelu"
&&
act_op
->
HasAttr
(
"approximate"
))
{
bool
approximate
=
BOOST_GET_CONST
(
bool
,
act_op
->
GetAttr
(
"approximate"
));
std
::
string
type
=
approximate
?
"_tanh"
:
"_erf"
;
fc_op
->
SetAttr
(
"activation_type"
,
act_type
+
type
);
}
else
fc_op
->
SetAttr
(
"activation_type"
,
act_type
);
fc_op
->
SetAttr
(
"use_mkldnn"
,
true
);
fc_op
->
SetOutput
(
"Out"
,
{
act_out
->
Name
()});
IR_OP_VAR_LINK
(
fc
,
act_out
);
GraphSafeRemoveNodes
(
g
,
{
act
,
fc_out
});
found_fc_act_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_fc_act_count
);
PrettyLogDetail
(
"--- fused %d fc with %s activation"
,
found_fc_act_count
,
act_type
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fc_act_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
FuseFCActOneDNNPass
);
REGISTER_PASS_CAPABILITY
(
fc_act_mkldnn_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"fc"
,
0
)
.
LE
(
"gelu"
,
0
)
.
LE
(
"sigmoid"
,
0
)
.
LE
(
"tanh"
,
0
));
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
0 → 100644
浏览文件 @
edc06c6a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
/*
* \brief Fuse the FC and activation operators into single OneDNN's
* FC with post-op.
*
* \note Currently only GeLU, sigmoid and tanh are supported as an activation
* function.
*/
class
FuseFCActOneDNNPass
:
public
FusePassBase
{
public:
virtual
~
FuseFCActOneDNNPass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
FuseFCAct
(
ir
::
Graph
*
graph
,
const
std
::
string
&
act_types
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddlea
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
0 → 100644
浏览文件 @
edc06c6a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <exception>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
// -------------------------- helper functions --------------------------------
namespace
{
using
InOutVarNamePair
=
std
::
pair
<
std
::
string
,
std
::
string
>
;
using
OpTypeCountPair
=
std
::
pair
<
std
::
string
,
int
>
;
///
/// @brief Creates the specified operator and sets up its inputs/outputs.
///
/// @param prog The program descriptor to which we add new op.
/// @param[in] op_type_name The operator type name.
/// @param[in] inputs The vector of input pairs: {input_name, variable
/// name}
/// @param[in] outputs The vector of output pairs {output_name, variable}
/// @param[in] use_mkldnn The flag deciding whether or not to set
/// 'use_mkldnn' attribute.
///
/// @return Returns pointer to the created operator descriptor.
///
OpDesc
*
CreateOp
(
ProgramDesc
*
prog
,
const
std
::
string
&
op_type_name
,
const
std
::
vector
<
InOutVarNamePair
>&
inputs
,
const
std
::
vector
<
InOutVarNamePair
>&
outputs
,
bool
use_mkldnn
=
true
)
{
auto
op
=
prog
->
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
op_type_name
);
op
->
SetAttr
(
"use_mkldnn"
,
use_mkldnn
);
for
(
const
auto
&
input
:
inputs
)
{
op
->
SetInput
(
input
.
first
,
{
input
.
second
});
}
for
(
const
auto
&
output
:
outputs
)
{
op
->
SetOutput
(
output
.
first
,
{
output
.
second
});
}
return
op
;
}
///
/// @brief Check whether node 'to' is reachable from node 'from' in graph.
///
/// @param[in] graph The graph we're checking for reachability.
/// @param[in] from The 'from' node name.
/// @param[in] to The 'to' node name.
///
/// @return True if there is connection between nodes 'from' and 'to'.
///
bool
TestIsReachable
(
const
Graph
&
graph
,
std
::
string
from
,
std
::
string
to
)
{
auto
hash
=
[](
const
Node
*
node
)
->
std
::
string
{
return
node
->
Name
()
+
std
::
to_string
(
node
->
id
());
};
auto
find_node
=
[
&
](
const
Graph
&
graph
,
const
std
::
string
&
name
)
->
Node
*
{
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
if
(
name
==
hash
(
&
node
))
{
return
&
node
;
}
}
return
nullptr
;
};
if
(
from
==
to
)
return
true
;
std
::
map
<
std
::
string
,
bool
>
visited
;
// update the from and to strings to hashed equivs in loop from graph traits
for
(
auto
&
node
:
GraphTraits
::
DFS
(
graph
))
{
auto
hashed
=
hash
(
&
node
);
if
(
node
.
Name
()
==
from
)
{
from
=
hashed
;
}
if
(
node
.
Name
()
==
to
)
{
to
=
hashed
;
}
visited
[
hashed
]
=
false
;
}
visited
[
from
]
=
true
;
std
::
list
<
std
::
string
>
queue
;
queue
.
push_back
(
from
);
while
(
!
queue
.
empty
())
{
auto
cur
=
find_node
(
graph
,
queue
.
front
());
queue
.
pop_front
();
if
(
cur
==
nullptr
)
{
return
false
;
}
for
(
auto
n
:
cur
->
outputs
)
{
auto
hashed_name
=
hash
(
n
);
if
(
hashed_name
==
to
)
{
return
true
;
}
if
(
!
visited
[
hashed_name
])
{
visited
[
hashed_name
]
=
true
;
queue
.
push_back
(
hashed_name
);
}
}
}
return
false
;
}
///
/// @brief Search through graph and counts provided operator occurences.
///
/// @param[in] graph The graph we search through.
/// @param[in] op_type_count The vector of pairs {op_type_name, op count}
///
/// @note After going through all graph nodes this function asserts
/// whether counted number for each requested op is as expected.
///
void
AssertOpsCount
(
const
Graph
&
graph
,
std
::
vector
<
OpTypeCountPair
>
op_type_count
)
{
for
(
auto
*
node
:
graph
.
Nodes
())
{
if
(
!
node
->
IsOp
())
{
continue
;
}
const
std
::
string
op_type_name
=
node
->
Op
()
->
Type
();
auto
op_it
=
std
::
find_if
(
std
::
begin
(
op_type_count
),
std
::
end
(
op_type_count
),
[
op_type_name
](
const
OpTypeCountPair
&
p
)
{
return
op_type_name
==
p
.
first
;
});
if
(
op_it
!=
std
::
end
(
op_type_count
))
{
op_it
->
second
--
;
}
}
for
(
const
OpTypeCountPair
&
p
:
op_type_count
)
{
EXPECT_EQ
(
p
.
second
,
0
);
}
}
///
/// @brief Builds a program descriptor.
///
/// @param[in] transient_vars The vector of transient variables names.
/// @param[in] persistent_vars The vector of persistent variables names. Those
/// will have persistable attribute set to true.
///
/// @return The program descriptor object.
///
ProgramDesc
BuildProgramDesc
(
const
std
::
vector
<
std
::
string
>&
transient_vars
,
const
std
::
vector
<
std
::
string
>&
persistent_vars
)
{
ProgramDesc
prog
;
auto
add_var_to_prog
=
[
&
prog
](
const
std
::
string
&
var_name
)
->
VarDesc
*
{
auto
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
var_name
);
var
->
SetType
(
proto
::
VarType
::
LOD_TENSOR
);
return
var
;
};
for
(
const
auto
&
v
:
transient_vars
)
{
add_var_to_prog
(
v
);
}
for
(
const
auto
&
v
:
persistent_vars
)
{
auto
*
var
=
add_var_to_prog
(
v
);
var
->
SetPersistable
(
true
);
}
return
prog
;
}
///
/// @brief Execute pass on provided graph and perform checks.
///
/// @param graph The graph we run pass on.
/// @param[in] from The name of a 'starting' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] to The name of a 'ending' node sequence in a
/// graph. This would be used to test for
/// correct node connections.
/// @param[in] removed_nodes_count The number of nodes we expect will be
/// removed/fused after pass execution.
/// @param[in] added_nodes_count The number of nodes we expect will be
/// added after pass execution.
///
void
RunPassAndAssert
(
Graph
*
graph
,
const
std
::
string
&
from
,
const
std
::
string
&
to
,
int
removed_nodes_count
,
int
added_nodes_count
=
0
)
{
EXPECT_TRUE
(
TestIsReachable
(
*
graph
,
from
,
to
));
int
original_nodes_num
=
graph
->
Nodes
().
size
();
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"fc_act_mkldnn_fuse_pass"
);
pass
->
Apply
(
graph
);
int
current_nodes_num
=
graph
->
Nodes
().
size
();
EXPECT_TRUE
(
TestIsReachable
(
*
graph
,
from
,
to
));
EXPECT_EQ
(
original_nodes_num
-
removed_nodes_count
+
added_nodes_count
,
current_nodes_num
);
}
}
// namespace
// ------------------------------ Test cases -----------------------------------
TEST
(
FuseFCActOneDNNPass
,
ThrowUseMkldnn
)
{
auto
prog
=
BuildProgramDesc
({
"x"
,
"fc_y"
,
"act_y"
},
{
"weights"
,
"bias"
});
CreateOp
(
&
prog
,
"fc"
,
{
{
"Input"
,
"x"
},
{
"Weights"
,
"weights"
},
{
"Bias"
,
"bias"
},
},
{{
"Out"
,
"fc_y"
}},
false
);
CreateOp
(
&
prog
,
"gelu"
,
{{
"Input"
,
"fc_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
// No fusion in this attribute configuration
constexpr
int
removed_nodes_count
=
0
;
EXPECT_THROW
(
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
),
paddle
::
platform
::
EnforceNotMet
);
}
TEST
(
FuseFCActOneDNNPass
,
FuseWithGeluTanh
)
{
auto
prog
=
BuildProgramDesc
({
"x"
,
"fc_y"
,
"act_y"
},
{
"weights"
,
"bias"
});
CreateOp
(
&
prog
,
"fc"
,
{
{
"Input"
,
"x"
},
{
"Weights"
,
"weights"
},
{
"Bias"
,
"bias"
},
},
{{
"Out"
,
"fc_y"
}});
auto
*
act_op
=
CreateOp
(
&
prog
,
"gelu"
,
{{
"Input"
,
"fc_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
act_op
->
SetAttr
(
"approximate"
,
true
);
Graph
graph
(
prog
);
constexpr
int
removed_nodes_count
=
2
;
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
);
AssertOpsCount
(
graph
,
{{
"fc"
,
1
},
{
"gelu"
,
0
}});
for
(
const
auto
*
node
:
graph
.
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"fc"
)
{
const
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"use_mkldnn"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"use_mkldnn"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"activation_type"
));
auto
act_type
=
BOOST_GET_CONST
(
std
::
string
,
op
->
GetAttr
(
"activation_type"
));
EXPECT_TRUE
(
act_type
.
compare
(
"gelu_tanh"
)
==
0
);
}
}
}
TEST
(
FuseFCActOneDNNPass
,
FuseWithGeluErf
)
{
auto
prog
=
BuildProgramDesc
({
"x"
,
"fc_y"
,
"act_y"
},
{
"weights"
,
"bias"
});
CreateOp
(
&
prog
,
"fc"
,
{
{
"Input"
,
"x"
},
{
"Weights"
,
"weights"
},
{
"Bias"
,
"bias"
},
},
{{
"Out"
,
"fc_y"
}});
auto
*
act_op
=
CreateOp
(
&
prog
,
"gelu"
,
{{
"Input"
,
"fc_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
act_op
->
SetAttr
(
"approximate"
,
false
);
Graph
graph
(
prog
);
constexpr
int
removed_nodes_count
=
2
;
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
);
AssertOpsCount
(
graph
,
{{
"fc"
,
1
},
{
"gelu"
,
0
}});
for
(
const
auto
*
node
:
graph
.
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"fc"
)
{
const
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"use_mkldnn"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"use_mkldnn"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"activation_type"
));
auto
act_type
=
BOOST_GET_CONST
(
std
::
string
,
op
->
GetAttr
(
"activation_type"
));
EXPECT_TRUE
(
act_type
.
compare
(
"gelu_erf"
)
==
0
);
}
}
}
TEST
(
FuseFCActOneDNNPass
,
FuseWithGeluAuto
)
{
auto
prog
=
BuildProgramDesc
({
"x"
,
"fc_y"
,
"act_y"
},
{
"weights"
,
"bias"
});
CreateOp
(
&
prog
,
"fc"
,
{
{
"Input"
,
"x"
},
{
"Weights"
,
"weights"
},
{
"Bias"
,
"bias"
},
},
{{
"Out"
,
"fc_y"
}});
CreateOp
(
&
prog
,
"gelu"
,
{{
"Input"
,
"fc_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
constexpr
int
removed_nodes_count
=
2
;
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
);
AssertOpsCount
(
graph
,
{{
"fc"
,
1
},
{
"gelu"
,
0
}});
for
(
const
auto
*
node
:
graph
.
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"fc"
)
{
const
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"use_mkldnn"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"use_mkldnn"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"activation_type"
));
auto
act_type
=
BOOST_GET_CONST
(
std
::
string
,
op
->
GetAttr
(
"activation_type"
));
EXPECT_TRUE
(
act_type
.
compare
(
"gelu"
)
==
0
);
}
}
}
TEST
(
FuseFCActOneDNNPass
,
FuseWithTanh
)
{
auto
prog
=
BuildProgramDesc
({
"x"
,
"fc_y"
,
"act_y"
},
{
"weights"
,
"bias"
});
CreateOp
(
&
prog
,
"fc"
,
{
{
"Input"
,
"x"
},
{
"Weights"
,
"weights"
},
{
"Bias"
,
"bias"
},
},
{{
"Out"
,
"fc_y"
}});
CreateOp
(
&
prog
,
"tanh"
,
{{
"Input"
,
"fc_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
constexpr
int
removed_nodes_count
=
2
;
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
);
AssertOpsCount
(
graph
,
{{
"fc"
,
1
},
{
"tanh"
,
0
}});
for
(
const
auto
*
node
:
graph
.
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"fc"
)
{
const
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"use_mkldnn"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"use_mkldnn"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"activation_type"
));
auto
act_type
=
BOOST_GET_CONST
(
std
::
string
,
op
->
GetAttr
(
"activation_type"
));
EXPECT_TRUE
(
act_type
.
compare
(
"tanh"
)
==
0
);
}
}
}
TEST
(
FuseFCActOneDNNPass
,
FuseWithSigmoid
)
{
auto
prog
=
BuildProgramDesc
({
"x"
,
"fc_y"
,
"act_y"
},
{
"weights"
,
"bias"
});
CreateOp
(
&
prog
,
"fc"
,
{
{
"Input"
,
"x"
},
{
"Weights"
,
"weights"
},
{
"Bias"
,
"bias"
},
},
{{
"Out"
,
"fc_y"
}});
CreateOp
(
&
prog
,
"sigmoid"
,
{{
"Input"
,
"fc_y"
}},
{{
"Out"
,
"act_y"
}},
false
);
Graph
graph
(
prog
);
constexpr
int
removed_nodes_count
=
2
;
RunPassAndAssert
(
&
graph
,
"x"
,
"act_y"
,
removed_nodes_count
);
AssertOpsCount
(
graph
,
{{
"fc"
,
1
},
{
"sigmoid"
,
0
}});
for
(
const
auto
*
node
:
graph
.
Nodes
())
{
if
(
node
->
IsOp
()
&&
node
->
Op
()
->
Type
()
==
"fc"
)
{
const
auto
*
op
=
node
->
Op
();
ASSERT_TRUE
(
op
->
HasAttr
(
"use_mkldnn"
));
EXPECT_TRUE
(
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"use_mkldnn"
)));
ASSERT_TRUE
(
op
->
HasAttr
(
"activation_type"
));
auto
act_type
=
BOOST_GET_CONST
(
std
::
string
,
op
->
GetAttr
(
"activation_type"
));
EXPECT_TRUE
(
act_type
.
compare
(
"sigmoid"
)
==
0
);
}
}
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fc_act_mkldnn_fuse_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
edc06c6a
...
@@ -206,7 +206,8 @@ void CpuPassStrategy::EnableMKLDNN() {
...
@@ -206,7 +206,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"reshape_transpose_matmul_mkldnn_fuse_pass"
,
//
"reshape_transpose_matmul_mkldnn_fuse_pass"
,
//
"matmul_transpose_reshape_fuse_pass"
,
//
"matmul_transpose_reshape_fuse_pass"
,
//
// Disabled due to topology-dependent speed-up
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
//"fc_mkldnn_pass",
//"fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass"
,
"batch_norm_act_fuse_pass"
,
"mkldnn_inplace_pass"
,
// This pass should be activated after
"mkldnn_inplace_pass"
,
// This pass should be activated after
// fuses
// fuses
...
...
paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
浏览文件 @
edc06c6a
...
@@ -206,6 +206,7 @@ void profile(bool use_mkldnn = false) {
...
@@ -206,6 +206,7 @@ void profile(bool use_mkldnn = false) {
"relu"
,
"fc"
};
"relu"
,
"fc"
};
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
...
@@ -262,6 +263,7 @@ void compare(bool use_mkldnn = false) {
...
@@ -262,6 +263,7 @@ void compare(bool use_mkldnn = false) {
"relu"
};
"relu"
};
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
SetMKLDNNOp
(
op_list
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
浏览文件 @
edc06c6a
...
@@ -50,8 +50,10 @@ void profile(bool use_mkldnn = false) {
...
@@ -50,8 +50,10 @@ void profile(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
if
(
!
FLAGS_disable_mkldnn_fc
)
if
(
!
FLAGS_disable_mkldnn_fc
)
{
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
...
@@ -83,8 +85,10 @@ void compare(bool use_mkldnn = false) {
...
@@ -83,8 +85,10 @@ void compare(bool use_mkldnn = false) {
SetConfig
(
&
cfg
);
SetConfig
(
&
cfg
);
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
if
(
!
FLAGS_disable_mkldnn_fc
)
if
(
!
FLAGS_disable_mkldnn_fc
)
{
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
浏览文件 @
edc06c6a
...
@@ -163,6 +163,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
...
@@ -163,6 +163,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
->
EnableMKLDNN
();
cfg
->
EnableMKLDNN
();
cfg
->
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
->
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
->
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// Enable seqpool_concat_fuse_pass, disabled by default since it takes much
// time
// time
...
...
paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
浏览文件 @
edc06c6a
...
@@ -25,6 +25,7 @@ void compare(bool use_mkldnn = false) {
...
@@ -25,6 +25,7 @@ void compare(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
浏览文件 @
edc06c6a
...
@@ -26,6 +26,7 @@ void profile(bool use_mkldnn = false) {
...
@@ -26,6 +26,7 @@ void profile(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
浏览文件 @
edc06c6a
...
@@ -86,6 +86,7 @@ void profile(bool use_mkldnn = false) {
...
@@ -86,6 +86,7 @@ void profile(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
// cfg.pass_builder()->TurnOnDebug();
// cfg.pass_builder()->TurnOnDebug();
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
outputs
;
...
@@ -136,6 +137,7 @@ void compare(bool use_mkldnn = false) {
...
@@ -136,6 +137,7 @@ void compare(bool use_mkldnn = false) {
if
(
use_mkldnn
)
{
if
(
use_mkldnn
)
{
cfg
.
EnableMKLDNN
();
cfg
.
EnableMKLDNN
();
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_mkldnn_pass"
);
cfg
.
pass_builder
()
->
AppendPass
(
"fc_act_mkldnn_fuse_pass"
);
}
}
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
std
::
vector
<
std
::
vector
<
PaddleTensor
>>
input_slots_all
;
...
...
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
浏览文件 @
edc06c6a
...
@@ -459,6 +459,36 @@ class FCPrimitiveFactory {
...
@@ -459,6 +459,36 @@ class FCPrimitiveFactory {
constexpr
float
placeholder
=
1.0
f
;
// beta
constexpr
float
placeholder
=
1.0
f
;
// beta
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_relu
,
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_relu
,
negative_slope
,
placeholder
);
negative_slope
,
placeholder
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"gelu"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_gelu
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"gelu_tanh"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_gelu_tanh
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"gelu_erf"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_gelu_erf
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"tanh"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_tanh
,
alpha
,
beta
);
}
else
if
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
)
==
"sigmoid"
)
{
constexpr
float
scale
=
1.0
f
;
constexpr
float
alpha
=
0.0
f
;
constexpr
float
beta
=
0.0
f
;
post_operations
.
append_eltwise
(
scale
,
mkldnn
::
algorithm
::
eltwise_logistic
,
alpha
,
beta
);
}
}
attributes
.
set_post_ops
(
post_operations
);
attributes
.
set_post_ops
(
post_operations
);
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_act_fuse_pass.py
0 → 100644
浏览文件 @
edc06c6a
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of fc and activation."""
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
inference_pass_test
import
InferencePassTest
from
paddle
import
enable_static
from
paddle.fluid.core
import
PassVersionChecker
enable_static
()
class
FCGeluTanhOneDnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
128
,
768
],
dtype
=
"float32"
)
fc_out
=
fluid
.
layers
.
fc
(
input
=
data
,
size
=
3072
,
num_flatten_dims
=
2
)
gelu_out
=
fluid
.
layers
.
gelu
(
fc_out
,
approximate
=
False
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
128
,
768
)).
astype
(
"float32"
)}
self
.
fetch_list
=
[
gelu_out
]
self
.
enable_mkldnn
=
True
def
set_params
(
self
):
self
.
pass_name
=
"fc_act_mkldnn_fuse_pass"
def
test_check_output
(
self
):
self
.
check_output
()
class
FCGeluErfOneDnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
128
,
768
],
dtype
=
"float32"
)
fc_out
=
fluid
.
layers
.
fc
(
input
=
data
,
size
=
3072
,
num_flatten_dims
=
2
)
gelu_out
=
fluid
.
layers
.
gelu
(
fc_out
,
approximate
=
True
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
128
,
768
)).
astype
(
"float32"
)}
self
.
fetch_list
=
[
gelu_out
]
self
.
enable_mkldnn
=
True
def
set_params
(
self
):
self
.
pass_name
=
"fc_act_mkldnn_fuse_pass"
def
test_check_output
(
self
):
self
.
check_output
()
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
self
.
pass_name
))
class
FCTanhOneDnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
128
,
768
],
dtype
=
"float32"
)
fc_out
=
fluid
.
layers
.
fc
(
input
=
data
,
size
=
3072
,
num_flatten_dims
=
2
)
tanh_out
=
fluid
.
layers
.
tanh
(
fc_out
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
128
,
768
)).
astype
(
"float32"
)}
self
.
fetch_list
=
[
tanh_out
]
self
.
enable_mkldnn
=
True
def
set_params
(
self
):
self
.
pass_name
=
"fc_act_mkldnn_fuse_pass"
def
test_check_output
(
self
):
self
.
check_output
()
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
self
.
pass_name
))
class
FCSigmoidOneDnnFusePassTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
128
,
768
],
dtype
=
"float32"
)
fc_out
=
fluid
.
layers
.
fc
(
input
=
data
,
size
=
3072
,
num_flatten_dims
=
2
)
sigmoid_out
=
fluid
.
layers
.
sigmoid
(
fc_out
)
self
.
feeds
=
{
"data"
:
np
.
random
.
random
((
1
,
128
,
768
)).
astype
(
"float32"
)}
self
.
fetch_list
=
[
sigmoid_out
]
self
.
enable_mkldnn
=
True
def
set_params
(
self
):
self
.
pass_name
=
"fc_act_mkldnn_fuse_pass"
def
test_check_output
(
self
):
self
.
check_output
()
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
self
.
pass_name
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录