Commit 826e2781 (unverified)
Authored by Sławomir Siwek on Jul 11, 2022; committed via GitHub on Jul 11, 2022
Unify and generalize activation fuse passes (#44185)
* reduce redundancy
* python code style
* fix int8 ut
Parent: 526be01a
Showing 14 changed files with 203 additions and 450 deletions (+203, -450)
paddle/fluid/framework/ir/graph_pattern_detector.cc (+14, -95)
paddle/fluid/framework/ir/graph_pattern_detector.h (+6, -78)
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc (+14, -47)
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h (+3, -5)
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc (+26, -53)
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h (+3, -5)
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc (+6, -9)
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc (+25, -46)
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h (+2, -4)
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h (+1, -16)
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc (+1, -17)
paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h (+1, -35)
paddle/fluid/platform/mkldnn_reuse.h (+87, -26)
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py (+14, -14)
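Taken together, the diffs below replace the per-operator patterns (ConvActivation, ElementwiseActivation, FCActOneDNN, SoftplusActivation) with a single patterns::OperatorActivation pattern and move the shared activation bookkeeping (GetSupportedActivations, GetAttributeMap, AppendActivation) into paddle/fluid/platform/mkldnn_reuse.h. As a rough, hedged sketch that is not part of this commit (the pass name and the "some_op" operator are placeholders; the helper and macro names follow the diffs below), a new operator+activation fuse pass written against the unified pieces would look roughly like this:

    // Illustrative only: a hypothetical fuse pass built on the unified helpers.
    #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
    #include "paddle/fluid/platform/mkldnn_reuse.h"

    void SomeOpActivationFusePass::ApplyImpl(Graph* graph) const {
      // One shared activation list instead of a hard-coded per-pass vector.
      for (auto& act_type : paddle::platform::GetSupportedActivations()) {
        GraphPatternDetector gpd;
        // Generic "preceding op -> activation" pattern from this commit.
        patterns::OperatorActivation some_op_act(gpd.mutable_pattern(), "some_op_act");
        some_op_act("some_op", act_type);

        auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
          GET_IR_NODE_FROM_SUBGRAPH(some_op, preceding_op, some_op_act);
          GET_IR_NODE_FROM_SUBGRAPH(activation, activation, some_op_act);
          GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, some_op_act);

          // Shared attribute-name mapping replaces the per-pass if/else chains.
          auto attr_map = paddle::platform::GetAttributeMap(act_type);
          for (const auto& attr : attr_map) {
            if (activation->Op()->HasAttr(attr.first))
              some_op->Op()->SetAttr(attr.second, activation->Op()->GetAttr(attr.first));
          }
          some_op->Op()->SetAttr("fuse_activation", act_type);
          some_op->Op()->SetOutput("Out", {activation_out->Name()});
          // ... relink nodes and erase the activation op, as the passes below do.
        };
        gpd(graph, handler);
      }
    }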
paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   return bn_out_var;
 }
 
-PDNode *patterns::ConvActivation::operator()(
-    paddle::framework::ir::PDNode *conv_input,
-    std::string conv_type,
-    std::string activation_type) {
-  // Create Operators
-  conv_input->assert_is_op_input(conv_type, "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
-  auto *activation_op =
-      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // Create variables
-  // Filter
-  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
-                              ->AsInput()
-                              ->assert_is_persistable_var()
-                              ->assert_is_op_input(conv_type, "Filter");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *conv_out_var = pattern->NewNode(conv_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op(conv_type)
-                           ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out_var = pattern->NewNode(activation_out_repr())
-                                 ->AsOutput()
-                                 ->assert_is_op_output(activation_type);
-  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
-  activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
-  return activation_out_var;
-}
-
-PDNode *patterns::ElementwiseActivation::operator()(
-    paddle::framework::ir::PDNode *elementwise_a,
-    const std::string &elementwise_type,
-    const std::string &activation_type) {
-  // Create Operators
-  elementwise_a->assert_is_op_input(elementwise_type, "X");
-  auto *elementwise_op =
-      pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type);
+PDNode *patterns::OperatorActivation::operator()(
+    const std::string &operator_type, const std::string &activation_type) {
+  auto *preceding_op =
+      pattern->NewNode(preceding_op_repr())->assert_is_op(operator_type);
+  auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
+                               ->AsIntermediate()
+                               ->assert_is_only_output_of_op(operator_type)
+                               ->assert_is_op_input(activation_type);
   auto *activation_op =
       pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // Create variables
-  auto *elementwise_b = pattern->NewNode(elementwise_b_repr())
-                            ->AsInput()
-                            ->assert_is_op_input(elementwise_type, "Y");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
-                                  ->AsIntermediate()
-                                  ->assert_is_only_output_of_op(elementwise_type)
-                                  ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out_var = pattern->NewNode(activation_out_repr())
-                                 ->AsOutput()
-                                 ->assert_is_op_output(activation_type);
-  elementwise_op->LinksFrom({elementwise_a, elementwise_b})
-      .LinksTo({elementwise_out_var});
-  activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var});
-  return activation_out_var;
+  auto *activation_out = pattern->NewNode(activation_out_repr())
+                             ->AsOutput()
+                             ->assert_is_op_output(activation_type);
+  preceding_op->LinksTo({preceding_op_out});
+  activation_op->LinksFrom({preceding_op_out}).LinksTo({activation_out});
+  return activation_out;
 }
 
 PDNode *patterns::SeqConvEltAddRelu::operator()(
@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
   return fc_out_var;
 }
 
-PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
-  auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
-  auto *fc_out = pattern->NewNode(fc_out_repr())
-                     ->assert_is_op_output("fc", "Out")
-                     ->assert_is_op_input(act_type);
-  auto *act =
-      pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
-  auto *act_out = pattern->NewNode(act_out_repr())
-                      ->assert_is_op_output(act_type, "Out")
-                      ->AsOutput();
-  fc->LinksTo({fc_out});
-  act->LinksFrom({fc_out}).LinksTo({act_out});
-  return act_out;
-}
-
-PDNode *patterns::SoftplusActivation::operator()(std::string activation_type) {
-  // Create Operators
-  auto *softplus_op =
-      pattern->NewNode(softplus_repr())->assert_is_op("softplus");
-  auto *activation_op =
-      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *softplus_out = pattern->NewNode(softplus_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op("softplus")
-                           ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out = pattern->NewNode(activation_out_repr())
-                             ->AsOutput()
-                             ->assert_is_op_output(activation_type);
-  softplus_op->LinksTo({softplus_out});
-  activation_op->LinksFrom({softplus_out}).LinksTo({activation_out});
-  return activation_out;
-}
-
 PDNode *patterns::Embedding::operator()(PDNode *x) {
   x->assert_is_op_input("lookup_table", "Ids");
   auto *lookup_table_op =
paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase {
   PATTERN_DECL_NODE(bn_saved_variance);
 };
 
-// Conv with Activation
-// op: conv + activation
-// named nodes:
-// conv_input, conv_weight,
-// conv_out, conv,
-// activation_out, activation
-struct ConvActivation : public PatternBase {
-  ConvActivation(PDPattern *pattern, const std::string &name_scope)
-      : PatternBase(pattern, name_scope, "conv_activation") {}
-
-  PDNode *operator()(PDNode *conv_input,
-                     std::string conv_type = "conv2d",
-                     std::string activation_type = "relu");
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(conv);
-  PATTERN_DECL_NODE(activation);
-  // declare variable node's name
-  PATTERN_DECL_NODE(conv_weight);
-  PATTERN_DECL_NODE(conv_out);
-  PATTERN_DECL_NODE(activation_out);
-};
-
-// Elementwise with Activation
-// op: elementwise + activation
-// named nodes:
-// elementwise_a, elementwise_b,
-// elementwise_out, elementwise,
-// activation_out, activation
-struct ElementwiseActivation : public PatternBase {
-  ElementwiseActivation(PDPattern *pattern, const std::string &name_scope)
-      : PatternBase(pattern, name_scope, "elementwise_add_activation") {}
-
-  PDNode *operator()(PDNode *elementwise_a,
-                     const std::string &elementwise_type,
-                     const std::string &activation_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(elementwise);
-  PATTERN_DECL_NODE(activation);
-  // declare variable node's name
-  PATTERN_DECL_NODE(elementwise_b);
-  PATTERN_DECL_NODE(elementwise_out);
-  PATTERN_DECL_NODE(activation_out);
-};
+struct OperatorActivation : public PatternBase {
+  OperatorActivation(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "operator_activation") {}
+
+  PDNode *operator()(const std::string &operator_type,
+                     const std::string &activation_type);
+
+  PATTERN_DECL_NODE(preceding_op);
+  PATTERN_DECL_NODE(preceding_op_out);
+  PATTERN_DECL_NODE(activation);
+  PATTERN_DECL_NODE(activation_out);
+};
@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase {
   PATTERN_DECL_NODE(output);
 };
 
-//
-// \brief   Pattern looking for fc and a directly following activation
-// operator.
-//
-// \note    Currently only gelu and tanh are supported as an activation
-// function.
-//          Formula: act(fc(x))
-//          Op: fc + act
-struct FCActOneDNN : public PatternBase {
-  FCActOneDNN(PDPattern *pattern, const std::string &name_scope)
-      : PatternBase(pattern, name_scope, "fc_act_onednn") {}
-
-  PDNode *operator()(const std::string &act_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(fc);
-  PATTERN_DECL_NODE(act);
-  PATTERN_DECL_NODE(fc_out);
-  PATTERN_DECL_NODE(act_out);
-};
-
-// Fuse softplus with activation
-// ops: softplus + activation
-// nodes:
-// softplus, softplus_out,
-// activation, activation_out
-struct SoftplusActivation : public PatternBase {
-  SoftplusActivation(PDPattern *pattern, const std::string &name_scope)
-      : PatternBase(pattern, name_scope, "softplus_activation") {}
-
-  PDNode *operator()(std::string activation_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(softplus);
-  PATTERN_DECL_NODE(activation);
-  PATTERN_DECL_NODE(softplus_out);
-  PATTERN_DECL_NODE(activation_out);
-};
-
 // Embedding
 struct Embedding : public PatternBase {
   Embedding(PDPattern *pattern, const std::string &name_scope)
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -24,61 +25,27 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
-  std::vector<std::string> act_types = {
-      "relu", "mish", "swish", "sqrt", "hard_swish", "sigmoid", "abs",
-      "gelu", "relu6", "clip", "tanh", "hard_sigmoid", "leaky_relu"};
+  auto act_types = paddle::platform::GetSupportedActivations();
   std::vector<std::string> conv_types = {"conv2d"};
 
   for (const auto& conv_type : conv_types)
     for (auto& act_type : act_types) {
-      std::unordered_map<std::string, std::string> attrs_map;
-
-      if (act_type == "swish")
-        attrs_map.emplace("beta", "fuse_alpha");
-      else if (act_type == "relu6")
-        attrs_map.emplace("threshold", "fuse_alpha");
-      else if (act_type == "hard_sigmoid") {
-        attrs_map.emplace("slope", "fuse_alpha");
-        attrs_map.emplace("offset", "fuse_beta");
-      } else if (act_type == "clip") {
-        attrs_map.emplace("min", "fuse_alpha");
-        attrs_map.emplace("max", "fuse_beta");
-      } else {
-        attrs_map.emplace("alpha", "fuse_alpha");
-        attrs_map.emplace("beta", "fuse_beta");
-      }
-      FuseConvAct(graph, conv_type, act_type, attrs_map);
+      FuseConvAct(graph, conv_type, act_type);
     }
 }
 
-void ConvActivationMkldnnFusePass::FuseConvAct(
-    Graph* graph,
-    const std::string& conv_type,
-    std::string& act_type,
-    const std::unordered_map<std::string, std::string>& attrs_map) const {
+void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
+                                               const std::string& conv_type,
+                                               std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
 
   GraphPatternDetector gpd;
-  auto* conv_input = gpd.mutable_pattern()
-                         ->NewNode("conv_activation_mkldnn_fuse/conv_input")
-                         ->AsInput()
-                         ->assert_is_op_input(conv_type, "Input");
-  patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(),
-                                            "conv_activation_mkldnn_fuse");
-  conv_act_pattern(conv_input, conv_type, act_type);
+  patterns::OperatorActivation conv_act_pattern(gpd.mutable_pattern(),
+                                                "conv_activation_mkldnn_fuse");
+  conv_act_pattern(conv_type, act_type);
 
   int found_conv_activation_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct(
       return;
     }
 
-    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv, preceding_op, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, preceding_op_out, conv_act_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
 
     OpDesc* conv_op = conv->Op();
     OpDesc* act_op = activation->Op();
 
-    for (const auto& attrs : attrs_map) {
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto& attrs : attr_map) {
       if (act_op->HasAttr(attrs.first)) {
         conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
       }
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
  protected:
   void ApplyImpl(Graph* graph) const override;
 
-  void FuseConvAct(Graph* graph,
-                   const std::string& conv_type,
-                   std::string& act_type,
-                   const std::unordered_map<std::string, std::string>& attrs_map) const;
+  void FuseConvAct(Graph* graph,
+                   const std::string& conv_type,
+                   std::string& act_type) const;
 };
 
 }  // namespace ir
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -26,71 +27,40 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "relu", "tanh", "leaky_relu", "swish", "hard_swish", "sqrt",
-      "abs",  "clip", "gelu",       "relu6", "sigmoid"};
+  auto act_types = paddle::platform::GetSupportedActivations();
   std::vector<std::string> elt_types = {
       "elementwise_add", "elementwise_sub", "elementwise_mul"};
 
   for (const auto &elt_type : elt_types)
     for (const auto &act_type : act_types) {
-      std::unordered_map<std::string, std::string> attr_map;
-
-      if (act_type == "swish")
-        attr_map.emplace("beta", "activation_alpha");
-      else if (act_type == "relu6")
-        attr_map.emplace("threshold", "activation_alpha");
-      else if (act_type == "clip") {
-        attr_map.emplace("min", "activation_alpha");
-        attr_map.emplace("max", "activation_beta");
-      } else {
-        attr_map.emplace("alpha", "activation_alpha");
-        attr_map.emplace("beta", "activation_beta");
-      }
-      FuseElementwiseAct(graph, elt_type, act_type, attr_map);
+      FuseElementwiseAct(graph, elt_type, act_type);
     }
 }
 
 void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     Graph *graph,
     const std::string &elt_type,
-    const std::string &act_type,
-    const std::unordered_map<std::string, std::string> &attr_map) const {
+    const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
 
   GraphPatternDetector gpd;
-  auto *elementwise_input = gpd.mutable_pattern()
-                                ->NewNode(elt_type + "_act/elementwise_input")
-                                ->AsInput()
-                                ->assert_is_op_input(elt_type, "X");
-  patterns::ElementwiseActivation elementwise_act_pattern(gpd.mutable_pattern(),
-                                                          elt_type + "_act");
-  elementwise_act_pattern(elementwise_input, elt_type, act_type);
+  patterns::OperatorActivation elementwise_act_pattern(gpd.mutable_pattern(),
+                                                       elt_type + "_act");
+  elementwise_act_pattern(elt_type, act_type);
 
   int found_elementwise_activation_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
     VLOG(4) << "Fuse " << elt_type << " with activation op.";
-    // Elementwise output
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise_out, elementwise_out, elementwise_act_pattern);
-    // ACT output
-    GET_IR_NODE_FROM_SUBGRAPH(
-        activation_out, activation_out, elementwise_act_pattern);
-    // ops
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise, elementwise, elementwise_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise, preceding_op, elementwise_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise_out, preceding_op_out, elementwise_act_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        activation_out, activation_out, elementwise_act_pattern);
 
     auto *elementwise_op = elementwise->Op();
@@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     }
 
     auto *activation_op = activation->Op();
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
     for (const auto &attr : attr_map) {
       if (activation_op->HasAttr(attr.first)) {
         elementwise_op->SetAttr(attr.second,
@@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
         BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
-      elementwise_op->SetAttr("activation_type", std::string("gelu_tanh"));
+      elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
     else
-      elementwise_op->SetAttr("activation_type", act_type);
+      elementwise_op->SetAttr("fuse_activation", act_type);
 
     elementwise_op->SetOutput("Out", {activation_out->Name()});
@@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass)
             .LE("elementwise_add", 1)
             .LE("elementwise_sub", 1)
             .LE("elementwise_mul", 1)
-            .LE("relu", 0)
-            .LE("tanh", 0)
-            .LE("leaky_relu", 1)
-            .LE("swish", 0)
-            .LE("hard_swish", 0)
-            .LE("sqrt", 0)
-            .LE("abs", 0)
+            .EQ("abs", 0)
             .LE("clip", 1)
-            .LE("gelu", 0)
-            .LE("relu6", 0)
-            .LE("sigmoid", 0));
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
+            .LE("hard_swish", 0)
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase {
  protected:
   void ApplyImpl(Graph *graph) const override;
 
-  void FuseElementwiseAct(
-      Graph *graph,
-      const std::string &elt_types,
-      const std::string &act_types,
-      const std::unordered_map<std::string, std::string> &attr_map) const;
+  void FuseElementwiseAct(Graph *graph,
+                          const std::string &elt_types,
+                          const std::string &act_types) const;
 };
 
 }  // namespace ir
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
   FusePassBase::Init("fc_act", graph);
 
   GraphPatternDetector gpd;
-  patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
-  fc_act_pattern(act_type);
+  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  fc_act_pattern("fc", act_type);
 
   int found_fc_act_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
     VLOG(4) << "Fuse fc with activation op.";
-    // FC output
-    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
-    // ACT output
-    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
-    // ops
-    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc, preceding_op, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, preceding_op_out, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
 
     auto *fc_op = fc->Op();
     auto *act_op = act->Op();
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -26,59 +27,34 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt",
-      "abs",  "clip", "gelu",       "relu6", "sigmoid"};
+  auto act_types = paddle::platform::GetSupportedActivations();
 
   for (const auto &act_type : act_types) {
-    std::unordered_map<std::string, std::string> attr_map;
-
-    if (act_type == "swish")
-      attr_map.emplace("beta", "fuse_activation_alpha");
-    else if (act_type == "relu6")
-      attr_map.emplace("threshold", "fuse_activation_alpha");
-    else if (act_type == "clip") {
-      attr_map.emplace("min", "fuse_activation_alpha");
-      attr_map.emplace("max", "fuse_activation_beta");
-    } else {
-      attr_map.emplace("alpha", "fuse_activation_alpha");
-      attr_map.emplace("beta", "fuse_activation_beta");
-    }
-    FuseSoftplusActivation(graph, act_type, attr_map);
+    FuseSoftplusActivation(graph, act_type);
   }
 }
 
 void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
-    Graph *graph,
-    const std::string &fuse_activation_type,
-    const std::unordered_map<std::string, std::string> &attr_map) const {
+    Graph *graph, const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("softplus_activation", graph);
 
   GraphPatternDetector gpd;
-  patterns::SoftplusActivation softplus_activation_pattern(
+  patterns::OperatorActivation softplus_activation_pattern(
       gpd.mutable_pattern(), "softplus_activation");
-  softplus_activation_pattern(fuse_activation_type);
+  softplus_activation_pattern("softplus", act_type);
 
   int found_softplus_activation_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
     VLOG(4) << "Fuse softplus with activation op.";
-    GET_IR_NODE_FROM_SUBGRAPH(
-        softplus_out, softplus_out, softplus_activation_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        softplus_out, preceding_op_out, softplus_activation_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         activation_out, activation_out, softplus_activation_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(softplus, softplus, softplus_activation_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        softplus, preceding_op, softplus_activation_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         activation, activation, softplus_activation_pattern);
@@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
     }
 
     auto *activation_op = activation->Op();
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
     for (const auto &attr : attr_map) {
       if (activation_op->HasAttr(attr.first)) {
         softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
       }
     }
 
-    if (fuse_activation_type == "gelu" &&
-        activation_op->HasAttr("approximate") &&
+    if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
         BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
-      softplus_op->SetAttr("fuse_activation_type", std::string("gelu_tanh"));
+      softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
     else
-      softplus_op->SetAttr("fuse_activation_type", fuse_activation_type);
+      softplus_op->SetAttr("fuse_activation", act_type);
 
     softplus_op->SetAttr("use_mkldnn", true);
@@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
   if (!Has("disable_logs") || !Get<bool>("disable_logs"))
     PrettyLogDetail("---    fused %d softplus with %s activation",
                     found_softplus_activation_count,
-                    fuse_activation_type);
+                    act_type);
 }
 
 }  // namespace ir
 }  // namespace framework
@@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("softplus", 1)
-            .EQ("relu", 0)
-            .EQ("tanh", 0)
-            .LE("leaky_relu", 1)
-            .EQ("swish", 0)
-            .EQ("hard_swish", 0)
-            .EQ("sqrt", 0)
             .EQ("abs", 0)
-            .LE("relu6", 1)
             .LE("clip", 1)
-            .EQ("gelu", 0));
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
+            .LE("hard_swish", 0)
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase {
  protected:
   void ApplyImpl(ir::Graph *graph) const override;
 
-  void FuseSoftplusActivation(
-      ir::Graph *graph,
-      const std::string &fuse_activation_type,
-      const std::unordered_map<std::string, std::string> &attr_map) const;
+  void FuseSoftplusActivation(ir::Graph *graph,
+                              const std::string &act_type) const;
 };
 
 }  // namespace ir
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
  private:
   dnnl::post_ops get_post_ops(const framework::ExecutionContext &ctx) const {
     dnnl::post_ops post_operations;
-    if (ctx.HasAttr("activation_type")) {
-      const float scale = ctx.HasAttr("activation_scale")
-                              ? ctx.Attr<float>("activation_scale")
-                              : 1.0f;
-      const float alpha = ctx.HasAttr("activation_alpha")
-                              ? ctx.Attr<float>("activation_alpha")
-                              : 0.0f;
-      const float beta = ctx.HasAttr("activation_beta")
-                             ? ctx.Attr<float>("activation_beta")
-                             : 0.0f;
-
-      const auto activation_algorithm = platform::AcquireActivationAlgorithm(
-          ctx.Attr<std::string>("activation_type"));
-
-      post_operations.append_eltwise(scale, activation_algorithm, alpha, beta);
-    }
+    platform::AppendActivation(ctx, post_operations);
     return post_operations;
   }
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT
     dnnl::primitive_attr conv_attr;
     dnnl::post_ops post_operations;
 
-    const std::string fuse_activation =
-        ctx.Attr<std::string>("fuse_activation");
-    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    const float fuse_beta = ctx.Attr<float>("fuse_beta");
     const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
 
     float sum_scale = 1.0f;
@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT
       post_operations.append_sum(sum_scale);
     }
 
-    if (fuse_activation == "hard_sigmoid") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_linear,
-                                     fuse_alpha,
-                                     fuse_beta);
-      post_operations.append_eltwise(
-          activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
-    } else if (fuse_activation != "") {
-      const auto activation_algorithm =
-          platform::AcquireActivationAlgorithm(fuse_activation);
-      post_operations.append_eltwise(
-          activation_scale, activation_algorithm, fuse_alpha, fuse_beta);
-    }
+    platform::AppendActivation(ctx, post_operations, activation_scale);
 
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h
@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler
           1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
     }
 
-    AppendFusedActivationIfExists(ctx, &post_ops);
+    platform::AppendActivation(ctx, post_ops);
 
     dnnl::primitive_attr attrs;
     attrs.set_post_ops(post_ops);
@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler
     return this->AcquireMemoryFromPrimitive(
         this->fwd_pd_->src1_desc(), platform::to_void_cast<float>(beta));
   }
-
- private:
-  void AppendFusedActivationIfExists(const framework::ExecutionContext &ctx,
-                                     dnnl::post_ops *post_ops) {
-    const auto &fused_activation_type =
-        algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));
-
-    if (fused_activation_type != algo_map.end()) {
-      auto scale_out =
-          ctx.Attr<float>("fuse_activation_scale");  // for future int8 support
-      post_ops->append_eltwise(scale_out,
-                               fused_activation_type->second,
-                               ctx.Attr<float>("fuse_activation_alpha"),
-                               ctx.Attr<float>("fuse_activation_beta"));
-    }
-  }
-
-  static const std::unordered_map<std::string, dnnl::algorithm> algo_map;
 };
 
-template <typename T>
-const std::unordered_map<std::string, dnnl::algorithm>
-    SoftplusMKLDNNHandler<T>::algo_map = {
-        {"relu", dnnl::algorithm::eltwise_relu},
-        {"tanh", dnnl::algorithm::eltwise_tanh},
-        {"leaky_relu", dnnl::algorithm::eltwise_relu},
-        {"swish", dnnl::algorithm::eltwise_swish},
-        {"hardswish", dnnl::algorithm::eltwise_hardswish},
-        {"sqrt", dnnl::algorithm::eltwise_sqrt},
-        {"abs", dnnl::algorithm::eltwise_abs},
-        {"clip", dnnl::algorithm::eltwise_clip},
-        {"gelu", dnnl::algorithm::eltwise_gelu_erf},
-        {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-        {"relu6", dnnl::algorithm::eltwise_bounded_relu},
-        {"sigmoid", dnnl::algorithm::eltwise_logistic}};
-
 template <typename T>
 void custom_softplus_eltwise_forward(const framework::ExecutionContext &ctx) {
   const auto &dev_ctx =
paddle/fluid/platform/mkldnn_reuse.h
@@ -1013,32 +1013,93 @@ class ActivationMKLDNNHandler
   }
 };
 
-static const dnnl::algorithm AcquireActivationAlgorithm(
-    std::string activation_name) {
-  std::unordered_map<std::string, dnnl::algorithm> activation_map = {
-      {"abs", dnnl::algorithm::eltwise_abs},
-      {"clip", dnnl::algorithm::eltwise_clip},
-      {"gelu", dnnl::algorithm::eltwise_gelu_erf},
-      {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
-      {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-      {"hard_swish", dnnl::algorithm::eltwise_hardswish},
-      {"leaky_relu", dnnl::algorithm::eltwise_relu},
-      {"mish", dnnl::algorithm::eltwise_mish},
-      {"relu", dnnl::algorithm::eltwise_relu},
-      {"relu6", dnnl::algorithm::eltwise_bounded_relu},
-      {"sigmoid", dnnl::algorithm::eltwise_logistic},
-      {"sqrt", dnnl::algorithm::eltwise_sqrt},
-      {"swish", dnnl::algorithm::eltwise_swish},
-      {"tanh", dnnl::algorithm::eltwise_tanh}};
-
-  const auto &activation_type = activation_map.find(activation_name);
-
-  PADDLE_ENFORCE_NE(
-      activation_type,
-      activation_map.end(),
-      platform::errors::InvalidArgument(
-          "Activation '%s' not found in oneDNN algorithms mapper",
-          activation_name));
-
-  return activation_type->second;
+static void AppendActivation(const framework::ExecutionContext &ctx,
+                             dnnl::post_ops &post_ops,
+                             float activation_scale = 1.0f) {
+  const auto invalid_attribute =
+      ctx.HasAttr("fuse_activation")
+          ? ctx.Attr<std::string>("fuse_activation").empty()
+          : true;
+  if (invalid_attribute) return;
+
+  const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
+  const auto fuse_alpha =
+      ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
+  const auto fuse_beta =
+      ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
+
+  if (fuse_activation == "hard_sigmoid") {
+    post_ops.append_eltwise(activation_scale,
+                            dnnl::algorithm::eltwise_linear,
+                            fuse_alpha,
+                            fuse_beta);
+    post_ops.append_eltwise(
+        activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
+  } else {
+    const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
+        {"abs", dnnl::algorithm::eltwise_abs},
+        {"clip", dnnl::algorithm::eltwise_clip},
+        {"gelu", dnnl::algorithm::eltwise_gelu_erf},
+        {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
+        {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
+        {"hard_swish", dnnl::algorithm::eltwise_hardswish},
+        {"leaky_relu", dnnl::algorithm::eltwise_relu},
+        {"mish", dnnl::algorithm::eltwise_mish},
+        {"relu", dnnl::algorithm::eltwise_relu},
+        {"relu6", dnnl::algorithm::eltwise_bounded_relu},
+        {"sigmoid", dnnl::algorithm::eltwise_logistic},
+        {"sqrt", dnnl::algorithm::eltwise_sqrt},
+        {"swish", dnnl::algorithm::eltwise_swish},
+        {"tanh", dnnl::algorithm::eltwise_tanh}};
+
+    const auto &activation_type = activation_map.find(fuse_activation);
+
+    PADDLE_ENFORCE_NE(
+        activation_type,
+        activation_map.end(),
+        platform::errors::InvalidArgument(
+            "Activation '%s' not found in oneDNN algorithms mapper",
+            fuse_activation));
+
+    post_ops.append_eltwise(
+        activation_scale, activation_type->second, fuse_alpha, fuse_beta);
+  }
+}
+
+static std::unordered_map<std::string, std::string> GetAttributeMap(
+    std::string act_type) {
+  std::unordered_map<std::string, std::string> attr_map;
+  if (act_type == "swish")
+    attr_map.emplace("beta", "fuse_alpha");
+  else if (act_type == "relu6")
+    attr_map.emplace("threshold", "fuse_alpha");
+  else if (act_type == "hard_sigmoid") {
+    attr_map.emplace("slope", "fuse_alpha");
+    attr_map.emplace("offset", "fuse_beta");
+  } else if (act_type == "clip") {
+    attr_map.emplace("min", "fuse_alpha");
+    attr_map.emplace("max", "fuse_beta");
+  } else {
+    attr_map.emplace("alpha", "fuse_alpha");
+    attr_map.emplace("beta", "fuse_beta");
+  }
+  return attr_map;
+}
+
+static std::vector<std::string> GetSupportedActivations() {
+  return std::vector<std::string>{"abs",
+                                  "clip",
+                                  "gelu",
+                                  "hard_sigmoid",
+                                  "hard_swish",
+                                  "leaky_relu",
+                                  "mish",
+                                  "relu",
+                                  "relu6",
+                                  "sigmoid",
+                                  "sqrt",
+                                  "swish",
+                                  "tanh"};
 }
 
 class ReorderMKLDNNHandler {
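The oneDNN kernels earlier in this commit (elementwise, conv, softplus) now obtain their post-ops through this single AppendActivation helper. As a minimal, hedged sketch of the resulting call pattern (the function name CreatePostOps is illustrative; the helper, the "fuse_activation"/"fuse_alpha"/"fuse_beta" attributes, and the dnnl::post_ops usage follow the diffs above):

    // Illustrative only: how a oneDNN kernel builds its primitive attributes now.
    dnnl::primitive_attr CreatePostOps(const framework::ExecutionContext& ctx) {
      dnnl::primitive_attr attrs;
      dnnl::post_ops post_operations;
      // Reads "fuse_activation" (plus fuse_alpha/fuse_beta) from the op and appends
      // the matching eltwise post-op; it is a no-op when the attribute is unset.
      platform::AppendActivation(ctx, post_operations);
      attrs.set_post_ops(post_operations);
      return attrs;
    }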
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py
@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker
 class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
-    fuse_activation_alpha = None
-    fuse_activation_beta = None
+    fuse_alpha = None
+    fuse_beta = None
     pass_name = 'softplus_activation_mkldnn_fuse_pass'
 
     def setUp(self):
@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
                 shape=[-1, 3, 100, 100],
                 dtype="float32")
             softplus_out = fluid.layers.softplus(data)
-            if self.fuse_activation_beta is not None:
-                activation_out = self.fuse_activation(
-                    softplus_out, self.fuse_activation_alpha,
-                    self.fuse_activation_beta)
-            elif self.fuse_activation_alpha is not None:
-                activation_out = self.fuse_activation(
-                    softplus_out, self.fuse_activation_alpha)
+            if self.fuse_beta is not None:
+                activation_out = self.fuse_activation(
+                    softplus_out, self.fuse_alpha, self.fuse_beta)
+            elif self.fuse_alpha is not None:
+                activation_out = self.fuse_activation(
+                    softplus_out, self.fuse_alpha)
             else:
                 activation_out = self.fuse_activation(softplus_out)
@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.leaky_relu
-        self.fuse_activation_alpha = 0.3
+        self.fuse_alpha = 0.3
 
 
 class SoftplusActivationSwishOneDNNFusePassTest(
@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.swish
-        self.fuse_activation_alpha = 3
+        self.fuse_alpha = 3
 
 
 class SoftplusActivationHardSwishOneDNNFusePassTest(
@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.clip
-        self.fuse_activation_alpha = 1.1
-        self.fuse_activation_beta = 5.2
+        self.fuse_alpha = 1.1
+        self.fuse_beta = 5.2
 
 
 class SoftplusActivationGeluErfOneDNNFusePassTest(
@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.gelu
-        self.fuse_activation_alpha = True  # simulated "Approximate" attr
+        self.fuse_alpha = True  # simulated "Approximate" attr
 
 
 class SoftplusActivationRelu6OneDNNFusePassTest(