Commit 826e2781 — PaddlePaddle / Paddle

Authored on Jul 11, 2022 by Sławomir Siwek; committed via GitHub on Jul 11, 2022.

Unify and generalize activation fuse passes (#44185)

* reduce redundancy
* python code style
* fix int8 ut

Parent: 526be01a
Showing 14 changed files with 203 additions and 450 deletions (+203 −450).
paddle/fluid/framework/ir/graph_pattern_detector.cc                                            +14  −95
paddle/fluid/framework/ir/graph_pattern_detector.h                                              +6  −78
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc                           +14  −47
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h                             +3   −5
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc                                   +26  −53
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h                                     +3   −5
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc                                     +6   −9
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc                       +25  −46
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h                         +2   −4
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h                               +1  −16
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc                                                 +1  −17
paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h                                              +1  −35
paddle/fluid/platform/mkldnn_reuse.h                                                           +87  −26
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py  +14  −14
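Across these files, the commit replaces per-pass activation lists and per-pass attribute maps with two shared helpers in platform (GetSupportedActivations and GetAttributeMap) and a single generic operator+activation graph pattern. The standalone C++ sketch below only illustrates the resulting driver shape; GetSupportedActivations mirrors the helper added in mkldnn_reuse.h, while FuseOpAct and the op_types list are invented stand-ins for the per-pass Fuse* methods, not Paddle API.

    // Minimal sketch of the unified driver loop, under the assumptions above.
    #include <iostream>
    #include <string>
    #include <vector>

    static std::vector<std::string> GetSupportedActivations() {
      // Same activation set that the commit centralizes in platform.
      return {"abs",        "clip",  "gelu",  "hard_sigmoid", "hard_swish",
              "leaky_relu", "mish",  "relu",  "relu6",        "sigmoid",
              "sqrt",       "swish", "tanh"};
    }

    // Stand-in for ConvActivationMkldnnFusePass::FuseConvAct and friends:
    // one routine per pass handles every (operator, activation) combination.
    static void FuseOpAct(const std::string& op_type, const std::string& act_type) {
      std::cout << "would fuse " << op_type << " + " << act_type << "\n";
    }

    int main() {
      const std::vector<std::string> op_types = {"conv2d", "elementwise_add",
                                                 "fc", "softplus"};
      for (const auto& op_type : op_types)
        for (const auto& act_type : GetSupportedActivations())
          FuseOpAct(op_type, act_type);  // no per-activation special cases here
      return 0;
    }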
paddle/fluid/framework/ir/graph_pattern_detector.cc  (+14 −95)

@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   return bn_out_var;
 }
 
-PDNode *patterns::ConvActivation::operator()(
-    paddle::framework::ir::PDNode *conv_input,
-    std::string conv_type,
-    std::string activation_type) {
-  // Create Operators
-  conv_input->assert_is_op_input(conv_type, "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
-  auto *activation_op =
-      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // Create variables
-  // Filter
-  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
-                              ->AsInput()
-                              ->assert_is_persistable_var()
-                              ->assert_is_op_input(conv_type, "Filter");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *conv_out_var = pattern->NewNode(conv_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op(conv_type)
-                           ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out_var = pattern->NewNode(activation_out_repr())
-                                 ->AsOutput()
-                                 ->assert_is_op_output(activation_type);
-  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
-  activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
-  return activation_out_var;
-}
-
-PDNode *patterns::ElementwiseActivation::operator()(
-    paddle::framework::ir::PDNode *elementwise_a,
-    const std::string &elementwise_type,
-    const std::string &activation_type) {
-  // Create Operators
-  elementwise_a->assert_is_op_input(elementwise_type, "X");
-  auto *elementwise_op =
-      pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type);
-  auto *activation_op =
-      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // Create variables
-  auto *elementwise_b = pattern->NewNode(elementwise_b_repr())
-                            ->AsInput()
-                            ->assert_is_op_input(elementwise_type, "Y");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
-                                  ->AsIntermediate()
-                                  ->assert_is_only_output_of_op(elementwise_type)
-                                  ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out_var = pattern->NewNode(activation_out_repr())
-                                 ->AsOutput()
-                                 ->assert_is_op_output(activation_type);
-  elementwise_op->LinksFrom({elementwise_a, elementwise_b})
-      .LinksTo({elementwise_out_var});
-  activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var});
-  return activation_out_var;
-}
+PDNode *patterns::OperatorActivation::operator()(
+    const std::string &operator_type, const std::string &activation_type) {
+  auto *preceding_op =
+      pattern->NewNode(preceding_op_repr())->assert_is_op(operator_type);
+  auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
+                               ->AsIntermediate()
+                               ->assert_is_only_output_of_op(operator_type)
+                               ->assert_is_op_input(activation_type);
+  auto *activation_op =
+      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
+  auto *activation_out = pattern->NewNode(activation_out_repr())
+                             ->AsOutput()
+                             ->assert_is_op_output(activation_type);
+  preceding_op->LinksTo({preceding_op_out});
+  activation_op->LinksFrom({preceding_op_out}).LinksTo({activation_out});
+  return activation_out;
+}
 
 PDNode *patterns::SeqConvEltAddRelu::operator()(
 ...
@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
   return fc_out_var;
 }
 
-PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) {
-  auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc");
-  auto *fc_out = pattern->NewNode(fc_out_repr())
-                     ->assert_is_op_output("fc", "Out")
-                     ->assert_is_op_input(act_type);
-  auto *act =
-      pattern->NewNode(act_repr())->assert_is_op(act_type)->AsIntermediate();
-  auto *act_out = pattern->NewNode(act_out_repr())
-                      ->assert_is_op_output(act_type, "Out")
-                      ->AsOutput();
-  fc->LinksTo({fc_out});
-  act->LinksFrom({fc_out}).LinksTo({act_out});
-  return act_out;
-}
-
-PDNode *patterns::SoftplusActivation::operator()(std::string activation_type) {
-  // Create Operators
-  auto *softplus_op =
-      pattern->NewNode(softplus_repr())->assert_is_op("softplus");
-  auto *activation_op =
-      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *softplus_out = pattern->NewNode(softplus_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op("softplus")
-                           ->assert_is_op_input(activation_type);
-  // output
-  auto *activation_out = pattern->NewNode(activation_out_repr())
-                             ->AsOutput()
-                             ->assert_is_op_output(activation_type);
-  softplus_op->LinksTo({softplus_out});
-  activation_op->LinksFrom({softplus_out}).LinksTo({activation_out});
-  return activation_out;
-}
-
 PDNode *patterns::Embedding::operator()(PDNode *x) {
   x->assert_is_op_input("lookup_table", "Ids");
   auto *lookup_table_op =
 ...
paddle/fluid/framework/ir/graph_pattern_detector.h  (+6 −78)

@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase {
   PATTERN_DECL_NODE(bn_saved_variance);
 };
 
-// Conv with Activation
-// op: conv + activation
-// named nodes:
-// conv_input, conv_weight,
-// conv_out, conv,
-// activation_out, activation
-struct ConvActivation : public PatternBase {
-  ConvActivation(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_activation") {}
-
-  PDNode* operator()(PDNode* conv_input,
-                     std::string conv_type = "conv2d",
-                     std::string activation_type = "relu");
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(conv);
-  PATTERN_DECL_NODE(activation);
-  // declare variable node's name
-  PATTERN_DECL_NODE(conv_weight);
-  PATTERN_DECL_NODE(conv_out);
-  PATTERN_DECL_NODE(activation_out);
-};
-
-// Elementwise with Activation
-// op: elementwise + activation
-// named nodes:
-// elementwise_a, elementwise_b,
-// elementwise_out, elementwise,
-// activation_out, activation
-struct ElementwiseActivation : public PatternBase {
-  ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "elementwise_add_activation") {}
-
-  PDNode* operator()(PDNode* elementwise_a,
-                     const std::string& elementwise_type,
-                     const std::string& activation_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(elementwise);
-  PATTERN_DECL_NODE(activation);
-  // declare variable node's name
-  PATTERN_DECL_NODE(elementwise_b);
-  PATTERN_DECL_NODE(elementwise_out);
-  PATTERN_DECL_NODE(activation_out);
-};
+struct OperatorActivation : public PatternBase {
+  OperatorActivation(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "operator_activation") {}
+
+  PDNode* operator()(const std::string& operator_type,
+                     const std::string& activation_type);
+
+  PATTERN_DECL_NODE(preceding_op);
+  PATTERN_DECL_NODE(preceding_op_out);
+  PATTERN_DECL_NODE(activation);
+  PATTERN_DECL_NODE(activation_out);
+};
 ...
@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase {
   PATTERN_DECL_NODE(output);
 };
 
-//
-// \brief   Pattern looking for fc and a directly following activation
-// operator.
-//
-// \note    Currently only gelu and tanh are supported as an activation
-// function.
-//          Formula: act(fc(x))
-//          Op: fc + act
-struct FCActOneDNN : public PatternBase {
-  FCActOneDNN(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "fc_act_onednn") {}
-
-  PDNode* operator()(const std::string& act_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(fc);
-  PATTERN_DECL_NODE(act);
-  PATTERN_DECL_NODE(fc_out);
-  PATTERN_DECL_NODE(act_out);
-};
-
-// Fuse softplus with activation
-// ops: softplus + activation
-// nodes:
-// softplus, softplus_out,
-// activation, activation_out
-struct SoftplusActivation : public PatternBase {
-  SoftplusActivation(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "softplus_activation") {}
-
-  PDNode* operator()(std::string activation_type);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(softplus);
-  PATTERN_DECL_NODE(activation);
-  PATTERN_DECL_NODE(softplus_out);
-  PATTERN_DECL_NODE(activation_out);
-};
-
 // Embedding
 struct Embedding : public PatternBase {
   Embedding(PDPattern* pattern, const std::string& name_scope)
 ...
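The new OperatorActivation pattern parameterizes the preceding operator by name, so the same four node roles (preceding_op, preceding_op_out, activation, activation_out) serve the conv, elementwise, fc and softplus fusions below. As a rough illustration of that idea only — the OpNode/Match structures and FindOperatorActivation function here are invented for this sketch and are not Paddle's PDPattern API:

    // Toy matcher for "operator followed by activation", under the
    // assumptions stated above.
    #include <iostream>
    #include <string>
    #include <vector>

    struct OpNode {
      std::string type;            // e.g. "conv2d", "gelu"
      int output_id;               // id of the single output variable
      std::vector<int> input_ids;  // ids of the input variables
    };

    struct Match {
      int preceding_op, preceding_op_out, activation, activation_out;
    };

    // Every pair where `operator_type` produces a value consumed by
    // `activation_type`; the field names mirror the pattern's node roles.
    std::vector<Match> FindOperatorActivation(const std::vector<OpNode>& graph,
                                              const std::string& operator_type,
                                              const std::string& activation_type) {
      std::vector<Match> result;
      for (int i = 0; i < static_cast<int>(graph.size()); ++i) {
        if (graph[i].type != operator_type) continue;
        for (int j = 0; j < static_cast<int>(graph.size()); ++j) {
          if (graph[j].type != activation_type) continue;
          for (int in : graph[j].input_ids)
            if (in == graph[i].output_id)
              result.push_back({i, graph[i].output_id, j, graph[j].output_id});
        }
      }
      return result;
    }

    int main() {
      std::vector<OpNode> graph = {{"conv2d", 10, {1, 2}}, {"gelu", 11, {10}}};
      for (const auto& m : FindOperatorActivation(graph, "conv2d", "gelu"))
        std::cout << "match: op " << m.preceding_op << " -> act "
                  << m.activation << "\n";
      return 0;
    }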
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc  (+14 −47)

@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
 ...
@@ -24,61 +25,27 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
-  std::vector<std::string> act_types = {
-      "relu", "mish", "swish", "sqrt", "hard_swish", "sigmoid", "abs",
-      "gelu", "relu6", "clip", "tanh", "hard_sigmoid", "leaky_relu"};
+  auto act_types = paddle::platform::GetSupportedActivations();
   std::vector<std::string> conv_types = {"conv2d"};
 
   for (const auto& conv_type : conv_types)
     for (auto& act_type : act_types) {
-      std::unordered_map<std::string, std::string> attrs_map;
-
-      if (act_type == "swish")
-        attrs_map.emplace("beta", "fuse_alpha");
-      else if (act_type == "relu6")
-        attrs_map.emplace("threshold", "fuse_alpha");
-      else if (act_type == "hard_sigmoid") {
-        attrs_map.emplace("slope", "fuse_alpha");
-        attrs_map.emplace("offset", "fuse_beta");
-      } else if (act_type == "clip") {
-        attrs_map.emplace("min", "fuse_alpha");
-        attrs_map.emplace("max", "fuse_beta");
-      } else {
-        attrs_map.emplace("alpha", "fuse_alpha");
-        attrs_map.emplace("beta", "fuse_beta");
-      }
-      FuseConvAct(graph, conv_type, act_type, attrs_map);
+      FuseConvAct(graph, conv_type, act_type);
     }
 }
 
 void ConvActivationMkldnnFusePass::FuseConvAct(
     Graph* graph,
     const std::string& conv_type,
-    std::string& act_type,
-    const std::unordered_map<std::string, std::string>& attrs_map) const {
+    std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
 
   GraphPatternDetector gpd;
-  auto* conv_input = gpd.mutable_pattern()
-                         ->NewNode("conv_activation_mkldnn_fuse/conv_input")
-                         ->AsInput()
-                         ->assert_is_op_input(conv_type, "Input");
-  patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(),
-                                            "conv_activation_mkldnn_fuse");
-  conv_act_pattern(conv_input, conv_type, act_type);
+  patterns::OperatorActivation conv_act_pattern(gpd.mutable_pattern(),
+                                                "conv_activation_mkldnn_fuse");
+  conv_act_pattern(conv_type, act_type);
 
   int found_conv_activation_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
 ...
@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct(
       return;
     }
 
-    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv, preceding_op, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, preceding_op_out, conv_act_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
 
     OpDesc* conv_op = conv->Op();
     OpDesc* act_op = activation->Op();
 
-    for (const auto& attrs : attrs_map) {
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto& attrs : attr_map) {
       if (act_op->HasAttr(attrs.first)) {
         conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
       }
 ...
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h  (+3 −5)

@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
  protected:
   void ApplyImpl(Graph *graph) const override;
 
-  void FuseConvAct(Graph *graph,
-                   const std::string& conv_type,
-                   std::string& act_type,
-                   const std::unordered_map<std::string, std::string>&
-                       attrs_map) const;
+  void FuseConvAct(Graph *graph,
+                   const std::string& conv_type,
+                   std::string& act_type) const;
 };
 
 }  // namespace ir
 ...
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc  (+26 −53)

@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
 ...
@@ -26,71 +27,40 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "relu", "tanh", "leaky_relu", "swish", "hard_swish", "sqrt",
-      "abs",  "clip", "gelu",       "relu6", "sigmoid"};
+  auto act_types = paddle::platform::GetSupportedActivations();
   std::vector<std::string> elt_types = {
       "elementwise_add", "elementwise_sub", "elementwise_mul"};
 
   for (const auto &elt_type : elt_types)
     for (const auto &act_type : act_types) {
-      std::unordered_map<std::string, std::string> attr_map;
-
-      if (act_type == "swish")
-        attr_map.emplace("beta", "activation_alpha");
-      else if (act_type == "relu6")
-        attr_map.emplace("threshold", "activation_alpha");
-      else if (act_type == "clip") {
-        attr_map.emplace("min", "activation_alpha");
-        attr_map.emplace("max", "activation_beta");
-      } else {
-        attr_map.emplace("alpha", "activation_alpha");
-        attr_map.emplace("beta", "activation_beta");
-      }
-      FuseElementwiseAct(graph, elt_type, act_type, attr_map);
+      FuseElementwiseAct(graph, elt_type, act_type);
     }
 }
 
 void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     Graph *graph,
     const std::string &elt_type,
-    const std::string &act_type,
-    const std::unordered_map<std::string, std::string> &attr_map) const {
+    const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
 
   GraphPatternDetector gpd;
-  auto *elementwise_input = gpd.mutable_pattern()
-                                ->NewNode(elt_type + "_act/elementwise_input")
-                                ->AsInput()
-                                ->assert_is_op_input(elt_type, "X");
-  patterns::ElementwiseActivation elementwise_act_pattern(
-      gpd.mutable_pattern(), elt_type + "_act");
-  elementwise_act_pattern(elementwise_input, elt_type, act_type);
+  patterns::OperatorActivation elementwise_act_pattern(gpd.mutable_pattern(),
+                                                       elt_type + "_act");
+  elementwise_act_pattern(elt_type, act_type);
 
   int found_elementwise_activation_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
     VLOG(4) << "Fuse " << elt_type << " with activation op.";
 
-    // Elementwise output
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise_out, elementwise_out, elementwise_act_pattern);
-    // ACT output
-    GET_IR_NODE_FROM_SUBGRAPH(
-        activation_out, activation_out, elementwise_act_pattern);
-    // ops
-    GET_IR_NODE_FROM_SUBGRAPH(
-        elementwise, elementwise, elementwise_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise, preceding_op, elementwise_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise_out, preceding_op_out, elementwise_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        activation_out, activation_out, elementwise_act_pattern);
 
     auto *elementwise_op = elementwise->Op();
 ...
@@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     }
 
     auto *activation_op = activation->Op();
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
     for (const auto &attr : attr_map) {
       if (activation_op->HasAttr(attr.first)) {
         elementwise_op->SetAttr(attr.second,
 ...
@@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
         BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
-      elementwise_op->SetAttr("activation_type", std::string("gelu_tanh"));
+      elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
     else
-      elementwise_op->SetAttr("activation_type", act_type);
+      elementwise_op->SetAttr("fuse_activation", act_type);
 
     elementwise_op->SetOutput("Out", {activation_out->Name()});
 ...
@@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass)
             .LE("elementwise_add", 1)
             .LE("elementwise_sub", 1)
             .LE("elementwise_mul", 1)
-            .LE("relu", 0)
-            .LE("tanh", 0)
-            .LE("leaky_relu", 1)
-            .LE("swish", 0)
-            .LE("hard_swish", 0)
-            .LE("sqrt", 0)
-            .LE("abs", 0)
+            .EQ("abs", 0)
             .LE("clip", 1)
-            .LE("gelu", 0)
-            .LE("relu6", 0)
-            .LE("sigmoid", 0));
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
+            .LE("hard_swish", 0)
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h  (+3 −5)

@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase {
  protected:
   void ApplyImpl(Graph *graph) const override;
 
-  void FuseElementwiseAct(Graph *graph,
-                          const std::string &elt_types,
-                          const std::string &act_types,
-                          const std::unordered_map<std::string, std::string>
-                              &attr_map) const;
+  void FuseElementwiseAct(Graph *graph,
+                          const std::string &elt_types,
+                          const std::string &act_types) const;
 };
 
 }  // namespace ir
 ...
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc  (+6 −9)

@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
   FusePassBase::Init("fc_act", graph);
 
   GraphPatternDetector gpd;
-  patterns::FCActOneDNN fc_act_pattern(gpd.mutable_pattern(), "fc_act");
-  fc_act_pattern(act_type);
+  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  fc_act_pattern("fc", act_type);
 
   int found_fc_act_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
     VLOG(4) << "Fuse fc with activation op.";
-    // FC output
-    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc_act_pattern);
-    // ACT output
-    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fc_act_pattern);
-    // ops
-    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_act_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(act, act, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc, preceding_op, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, preceding_op_out, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
 
     auto *fc_op = fc->Op();
     auto *act_op = act->Op();
 ...
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc  (+25 −46)

@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
 ...
@@ -26,59 +27,34 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt",
-      "abs",  "clip", "gelu",       "relu6", "sigmoid"};
+  auto act_types = paddle::platform::GetSupportedActivations();
 
   for (const auto &act_type : act_types) {
-    std::unordered_map<std::string, std::string> attr_map;
-
-    if (act_type == "swish")
-      attr_map.emplace("beta", "fuse_activation_alpha");
-    else if (act_type == "relu6")
-      attr_map.emplace("threshold", "fuse_activation_alpha");
-    else if (act_type == "clip") {
-      attr_map.emplace("min", "fuse_activation_alpha");
-      attr_map.emplace("max", "fuse_activation_beta");
-    } else {
-      attr_map.emplace("alpha", "fuse_activation_alpha");
-      attr_map.emplace("beta", "fuse_activation_beta");
-    }
-    FuseSoftplusActivation(graph, act_type, attr_map);
+    FuseSoftplusActivation(graph, act_type);
   }
 }
 
 void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
     Graph *graph,
-    const std::string &fuse_activation_type,
-    const std::unordered_map<std::string, std::string> &attr_map) const {
+    const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("softplus_activation", graph);
 
   GraphPatternDetector gpd;
-  patterns::SoftplusActivation softplus_activation_pattern(
-      gpd.mutable_pattern(), "softplus_activation");
-  softplus_activation_pattern(fuse_activation_type);
+  patterns::OperatorActivation softplus_activation_pattern(
+      gpd.mutable_pattern(), "softplus_activation");
+  softplus_activation_pattern("softplus", act_type);
 
   int found_softplus_activation_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *g) {
     VLOG(4) << "Fuse softplus with activation op.";
-    GET_IR_NODE_FROM_SUBGRAPH(
-        softplus_out, softplus_out, softplus_activation_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        softplus_out, preceding_op_out, softplus_activation_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         activation_out, activation_out, softplus_activation_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(
-        softplus, softplus, softplus_activation_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        softplus, preceding_op, softplus_activation_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         activation, activation, softplus_activation_pattern);
 ...
@@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
     }
 
     auto *activation_op = activation->Op();
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
     for (const auto &attr : attr_map) {
       if (activation_op->HasAttr(attr.first)) {
         softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
       }
     }
 
-    if (fuse_activation_type == "gelu" &&
-        activation_op->HasAttr("approximate") &&
+    if (act_type == "gelu" &&
+        activation_op->HasAttr("approximate") &&
         BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
-      softplus_op->SetAttr("fuse_activation_type", std::string("gelu_tanh"));
+      softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
     else
-      softplus_op->SetAttr("fuse_activation_type", fuse_activation_type);
+      softplus_op->SetAttr("fuse_activation", act_type);
 
     softplus_op->SetAttr("use_mkldnn", true);
 ...
@@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
   if (!Has("disable_logs") || !Get<bool>("disable_logs"))
     PrettyLogDetail("---    fused %d softplus with %s activation",
                     found_softplus_activation_count,
-                    fuse_activation_type);
+                    act_type);
 }
 
 }  // namespace ir
 }  // namespace framework
 ...
@@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
         .AddCombination(
             paddle::framework::compatible::OpVersionComparatorCombination()
                 .LE("softplus", 1)
-                .EQ("relu", 0)
-                .EQ("tanh", 0)
-                .LE("leaky_relu", 1)
-                .EQ("swish", 0)
-                .EQ("hard_swish", 0)
-                .EQ("sqrt", 0)
                 .EQ("abs", 0)
-                .LE("relu6", 1)
                 .LE("clip", 1)
-                .EQ("gelu", 0));
+                .EQ("gelu", 0)
+                .EQ("hard_sigmoid", 0)
+                .LE("hard_swish", 0)
+                .LE("leaky_relu", 1)
+                .LE("mish", 1)
+                .EQ("relu", 0)
+                .EQ("relu6", 0)
+                .EQ("sigmoid", 0)
+                .EQ("sqrt", 0)
+                .EQ("swish", 0)
+                .EQ("tanh", 0));
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h  (+2 −4)

@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase {
  protected:
   void ApplyImpl(ir::Graph *graph) const override;
 
-  void FuseSoftplusActivation(ir::Graph *graph,
-                              const std::string &fuse_activation_type,
-                              const std::unordered_map<std::string, std::string>
-                                  &attr_map) const;
+  void FuseSoftplusActivation(ir::Graph *graph,
+                              const std::string &act_type) const;
 };
 
 }  // namespace ir
 ...
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h  (+1 −16)

@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
  private:
   dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const {
     dnnl::post_ops post_operations;
-    if (ctx.HasAttr("activation_type")) {
-      const float scale = ctx.HasAttr("activation_scale")
-                              ? ctx.Attr<float>("activation_scale")
-                              : 1.0f;
-      const float alpha = ctx.HasAttr("activation_alpha")
-                              ? ctx.Attr<float>("activation_alpha")
-                              : 0.0f;
-      const float beta = ctx.HasAttr("activation_beta")
-                             ? ctx.Attr<float>("activation_beta")
-                             : 0.0f;
-      const auto activation_algorithm = platform::AcquireActivationAlgorithm(
-          ctx.Attr<std::string>("activation_type"));
-      post_operations.append_eltwise(scale, activation_algorithm, alpha, beta);
-    }
+    platform::AppendActivation(ctx, post_operations);
     return post_operations;
   }
 ...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc  (+1 −17)

@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT
     dnnl::primitive_attr conv_attr;
     dnnl::post_ops post_operations;
 
-    const std::string fuse_activation =
-        ctx.Attr<std::string>("fuse_activation");
-    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    const float fuse_beta = ctx.Attr<float>("fuse_beta");
     const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
 
     float sum_scale = 1.0f;
 ...
@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT
       post_operations.append_sum(sum_scale);
     }
 
-    if (fuse_activation == "hard_sigmoid") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_linear,
-                                     fuse_alpha,
-                                     fuse_beta);
-      post_operations.append_eltwise(
-          activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
-    } else if (fuse_activation != "") {
-      const auto activation_algorithm =
-          platform::AcquireActivationAlgorithm(fuse_activation);
-      post_operations.append_eltwise(
-          activation_scale, activation_algorithm, fuse_alpha, fuse_beta);
-    }
+    platform::AppendActivation(ctx, post_operations, activation_scale);
 
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
 ...
paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h  (+1 −35)

@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler
           1.0f, dnnl::algorithm::eltwise_linear, 1.0f / beta, 0.0f);
     }
 
-    AppendFusedActivationIfExists(ctx, &post_ops);
+    platform::AppendActivation(ctx, post_ops);
 
     dnnl::primitive_attr attrs;
     attrs.set_post_ops(post_ops);
 ...
@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler
     return this->AcquireMemoryFromPrimitive(
         this->fwd_pd_->src1_desc(), platform::to_void_cast<float>(beta));
   }
-
- private:
-  void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx,
-                                     dnnl::post_ops* post_ops) {
-    const auto& fused_activation_type =
-        algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));
-
-    if (fused_activation_type != algo_map.end()) {
-      auto scale_out =
-          ctx.Attr<float>("fuse_activation_scale");  // for future int8 support
-      post_ops->append_eltwise(scale_out,
-                               fused_activation_type->second,
-                               ctx.Attr<float>("fuse_activation_alpha"),
-                               ctx.Attr<float>("fuse_activation_beta"));
-    }
-  }
-
-  static const std::unordered_map<std::string, dnnl::algorithm> algo_map;
 };
 
-template <typename T>
-const std::unordered_map<std::string, dnnl::algorithm>
-    SoftplusMKLDNNHandler<T>::algo_map = {
-        {"relu", dnnl::algorithm::eltwise_relu},
-        {"tanh", dnnl::algorithm::eltwise_tanh},
-        {"leaky_relu", dnnl::algorithm::eltwise_relu},
-        {"swish", dnnl::algorithm::eltwise_swish},
-        {"hardswish", dnnl::algorithm::eltwise_hardswish},
-        {"sqrt", dnnl::algorithm::eltwise_sqrt},
-        {"abs", dnnl::algorithm::eltwise_abs},
-        {"clip", dnnl::algorithm::eltwise_clip},
-        {"gelu", dnnl::algorithm::eltwise_gelu_erf},
-        {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-        {"relu6", dnnl::algorithm::eltwise_bounded_relu},
-        {"sigmoid", dnnl::algorithm::eltwise_logistic}};
-
 template <typename T>
 void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
   const auto& dev_ctx =
 ...
paddle/fluid/platform/mkldnn_reuse.h  (+87 −26)

@@ -1013,32 +1013,93 @@ class ActivationMKLDNNHandler
   }
 };
 
-static const dnnl::algorithm AcquireActivationAlgorithm(
-    std::string activation_name) {
-  std::unordered_map<std::string, dnnl::algorithm> activation_map = {
-      {"abs", dnnl::algorithm::eltwise_abs},
-      {"clip", dnnl::algorithm::eltwise_clip},
-      {"gelu", dnnl::algorithm::eltwise_gelu_erf},
-      {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
-      {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-      {"hard_swish", dnnl::algorithm::eltwise_hardswish},
-      {"leaky_relu", dnnl::algorithm::eltwise_relu},
-      {"mish", dnnl::algorithm::eltwise_mish},
-      {"relu", dnnl::algorithm::eltwise_relu},
-      {"relu6", dnnl::algorithm::eltwise_bounded_relu},
-      {"sigmoid", dnnl::algorithm::eltwise_logistic},
-      {"sqrt", dnnl::algorithm::eltwise_sqrt},
-      {"swish", dnnl::algorithm::eltwise_swish},
-      {"tanh", dnnl::algorithm::eltwise_tanh}};
-
-  const auto& activation_type = activation_map.find(activation_name);
-
-  PADDLE_ENFORCE_NE(
-      activation_type,
-      activation_map.end(),
-      platform::errors::InvalidArgument(
-          "Activation '%s' not found in oneDNN algorithms mapper",
-          activation_name));
-  return activation_type->second;
+static void AppendActivation(const framework::ExecutionContext& ctx,
+                             dnnl::post_ops& post_ops,
+                             float activation_scale = 1.0f) {
+  const auto invalid_attribute =
+      ctx.HasAttr("fuse_activation")
+          ? ctx.Attr<std::string>("fuse_activation").empty()
+          : true;
+  if (invalid_attribute) return;
+
+  const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
+  const auto fuse_alpha =
+      ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
+  const auto fuse_beta =
+      ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
+
+  if (fuse_activation == "hard_sigmoid") {
+    post_ops.append_eltwise(activation_scale,
+                            dnnl::algorithm::eltwise_linear,
+                            fuse_alpha,
+                            fuse_beta);
+    post_ops.append_eltwise(
+        activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
+  } else {
+    const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
+        {"abs", dnnl::algorithm::eltwise_abs},
+        {"clip", dnnl::algorithm::eltwise_clip},
+        {"gelu", dnnl::algorithm::eltwise_gelu_erf},
+        {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
+        {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
+        {"hard_swish", dnnl::algorithm::eltwise_hardswish},
+        {"leaky_relu", dnnl::algorithm::eltwise_relu},
+        {"mish", dnnl::algorithm::eltwise_mish},
+        {"relu", dnnl::algorithm::eltwise_relu},
+        {"relu6", dnnl::algorithm::eltwise_bounded_relu},
+        {"sigmoid", dnnl::algorithm::eltwise_logistic},
+        {"sqrt", dnnl::algorithm::eltwise_sqrt},
+        {"swish", dnnl::algorithm::eltwise_swish},
+        {"tanh", dnnl::algorithm::eltwise_tanh}};
+
+    const auto& activation_type = activation_map.find(fuse_activation);
+
+    PADDLE_ENFORCE_NE(
+        activation_type,
+        activation_map.end(),
+        platform::errors::InvalidArgument(
+            "Activation '%s' not found in oneDNN algorithms mapper",
+            fuse_activation));
+
+    post_ops.append_eltwise(
+        activation_scale, activation_type->second, fuse_alpha, fuse_beta);
+  }
+}
+
+static std::unordered_map<std::string, std::string> GetAttributeMap(
+    std::string act_type) {
+  std::unordered_map<std::string, std::string> attr_map;
+  if (act_type == "swish")
+    attr_map.emplace("beta", "fuse_alpha");
+  else if (act_type == "relu6")
+    attr_map.emplace("threshold", "fuse_alpha");
+  else if (act_type == "hard_sigmoid") {
+    attr_map.emplace("slope", "fuse_alpha");
+    attr_map.emplace("offset", "fuse_beta");
+  } else if (act_type == "clip") {
+    attr_map.emplace("min", "fuse_alpha");
+    attr_map.emplace("max", "fuse_beta");
+  } else {
+    attr_map.emplace("alpha", "fuse_alpha");
+    attr_map.emplace("beta", "fuse_beta");
+  }
+  return attr_map;
+}
+
+static std::vector<std::string> GetSupportedActivations() {
+  return std::vector<std::string>{"abs",        "clip",  "gelu",
+                                  "hard_sigmoid", "hard_swish", "leaky_relu",
+                                  "mish",       "relu",  "relu6",
+                                  "sigmoid",    "sqrt",  "swish",
+                                  "tanh"};
 }
 
 class ReorderMKLDNNHandler {
 ...
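The new platform helpers centralize two things: mapping each activation's framework attributes (beta, threshold, slope/offset, min/max, alpha/beta) onto the uniform fuse_alpha/fuse_beta names, and translating the fuse_activation string into a oneDNN eltwise post-op, with hard_sigmoid expanded into a linear + clip pair. The standalone sketch below mirrors that dispatch under stated assumptions: a plain enum stands in for dnnl::algorithm, a small struct for dnnl::post_ops, and the activation map is abbreviated; it is an illustration, not the Paddle or oneDNN API.

    // Sketch of the GetAttributeMap / AppendActivation dispatch.
    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>
    #include <vector>

    enum class Algo { linear, clip, relu, gelu_erf, gelu_tanh, swish, tanh_fn, logistic };

    struct PostOps {  // stand-in for dnnl::post_ops
      struct Eltwise { float scale; Algo algo; float alpha; float beta; };
      std::vector<Eltwise> ops;
      void append_eltwise(float s, Algo a, float al, float be) {
        ops.push_back({s, a, al, be});
      }
    };

    // Mirrors GetAttributeMap: which op attribute feeds fuse_alpha / fuse_beta.
    std::unordered_map<std::string, std::string> GetAttributeMap(const std::string& act) {
      if (act == "swish") return {{"beta", "fuse_alpha"}};
      if (act == "relu6") return {{"threshold", "fuse_alpha"}};
      if (act == "hard_sigmoid") return {{"slope", "fuse_alpha"}, {"offset", "fuse_beta"}};
      if (act == "clip") return {{"min", "fuse_alpha"}, {"max", "fuse_beta"}};
      return {{"alpha", "fuse_alpha"}, {"beta", "fuse_beta"}};
    }

    // Mirrors AppendActivation: hard_sigmoid becomes linear + clip(0, 1);
    // anything else goes through one name -> algorithm lookup.
    void AppendActivation(const std::string& fuse_activation, PostOps& post_ops,
                          float scale = 1.0f, float alpha = 0.0f, float beta = 0.0f) {
      if (fuse_activation.empty()) return;
      if (fuse_activation == "hard_sigmoid") {
        post_ops.append_eltwise(scale, Algo::linear, alpha, beta);
        post_ops.append_eltwise(scale, Algo::clip, 0.0f, 1.0f);
        return;
      }
      static const std::unordered_map<std::string, Algo> activation_map = {
          {"relu", Algo::relu},           {"gelu", Algo::gelu_erf},
          {"gelu_tanh", Algo::gelu_tanh}, {"swish", Algo::swish},
          {"tanh", Algo::tanh_fn},        {"sigmoid", Algo::logistic},
          {"clip", Algo::clip}};
      auto it = activation_map.find(fuse_activation);
      if (it == activation_map.end())
        throw std::invalid_argument("activation not supported: " + fuse_activation);
      post_ops.append_eltwise(scale, it->second, alpha, beta);
    }

    int main() {
      PostOps ops;
      AppendActivation("hard_sigmoid", ops, 1.0f, 0.2f, 0.5f);
      std::cout << "post-ops appended: " << ops.ops.size() << "\n";  // prints 2
      return 0;
    }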
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py  (+14 −14)

@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker
 class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
-    fuse_activation_alpha = None
-    fuse_activation_beta = None
+    fuse_alpha = None
+    fuse_beta = None
     pass_name = 'softplus_activation_mkldnn_fuse_pass'
 
     def setUp(self):
 ...
@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
                 shape=[-1, 3, 100, 100],
                 dtype="float32")
             softplus_out = fluid.layers.softplus(data)
-            if self.fuse_activation_beta is not None:
+            if self.fuse_beta is not None:
                 activation_out = self.fuse_activation(
                     softplus_out,
-                    self.fuse_activation_alpha,
-                    self.fuse_activation_beta)
-            elif self.fuse_activation_alpha is not None:
+                    self.fuse_alpha,
+                    self.fuse_beta)
+            elif self.fuse_alpha is not None:
                 activation_out = self.fuse_activation(
                     softplus_out,
-                    self.fuse_activation_alpha)
+                    self.fuse_alpha)
             else:
                 activation_out = self.fuse_activation(softplus_out)
 ...
@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.leaky_relu
-        self.fuse_activation_alpha = 0.3
+        self.fuse_alpha = 0.3
 
 
 class SoftplusActivationSwishOneDNNFusePassTest(
 ...
@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.swish
-        self.fuse_activation_alpha = 3
+        self.fuse_alpha = 3
 
 
 class SoftplusActivationHardSwishOneDNNFusePassTest(
 ...
@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.clip
-        self.fuse_activation_alpha = 1.1
-        self.fuse_activation_beta = 5.2
+        self.fuse_alpha = 1.1
+        self.fuse_beta = 5.2
 
 
 class SoftplusActivationGeluErfOneDNNFusePassTest(
 ...
@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest(
     def set_params(self):
         self.fuse_activation = fluid.layers.gelu
-        self.fuse_activation_alpha = True  # simulated "Approximate" attr
+        self.fuse_alpha = True  # simulated "Approximate" attr
 
 
 class SoftplusActivationRelu6OneDNNFusePassTest(
 ...