Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
826e2781
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
826e2781
编写于
7月 11, 2022
作者:
S
Sławomir Siwek
提交者:
GitHub
7月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify and generalize activation fuse passes (#44185)
* reduce redundancy * python code style * fix int8 ut
上级
526be01a
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
203 addition
and
450 deletion
+203
-450
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+14
-95
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+6
-78
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
...d/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+14
-47
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
...id/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+3
-5
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
+26
-53
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
+3
-5
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
+6
-9
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
...amework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
+25
-46
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
...ramework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
+2
-4
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
...luid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
+1
-16
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+1
-17
paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h
paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h
+1
-35
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+87
-26
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py
...ir/inference/test_mkldnn_softplus_activation_fuse_pass.py
+14
-14
未找到文件。
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
826e2781
...
@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
...
@@ -931,65 +931,22 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
return
bn_out_var
;
return
bn_out_var
;
}
}
PDNode
*
patterns
::
ConvActivation
::
operator
()(
PDNode
*
patterns
::
OperatorActivation
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
conv_input
,
const
std
::
string
&
operator_type
,
const
std
::
string
&
activation_type
)
{
std
::
string
conv_type
,
auto
*
preceding_op
=
std
::
string
activation_type
)
{
pattern
->
NewNode
(
preceding_op_repr
())
->
assert_is_op
(
operator_type
);
// Create Operators
auto
*
preceding_op_out
=
pattern
->
NewNode
(
preceding_op_out_repr
())
conv_input
->
assert_is_op_input
(
conv_type
,
"Input"
);
auto
*
conv_op
=
pattern
->
NewNode
(
conv_repr
())
->
assert_is_op
(
conv_type
);
auto
*
activation_op
=
pattern
->
NewNode
(
activation_repr
())
->
assert_is_op
(
activation_type
);
// Create variables
// Filter
auto
*
conv_weight_var
=
pattern
->
NewNode
(
conv_weight_repr
())
->
AsInput
()
->
assert_is_persistable_var
()
->
assert_is_op_input
(
conv_type
,
"Filter"
);
// intermediate variable, will be removed in the IR after fuse.
auto
*
conv_out_var
=
pattern
->
NewNode
(
conv_out_repr
())
->
AsIntermediate
()
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
conv
_type
)
->
assert_is_only_output_of_op
(
operator
_type
)
->
assert_is_op_input
(
activation_type
);
->
assert_is_op_input
(
activation_type
);
// output
auto
*
activation_out_var
=
pattern
->
NewNode
(
activation_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
activation_type
);
conv_op
->
LinksFrom
({
conv_input
,
conv_weight_var
}).
LinksTo
({
conv_out_var
});
activation_op
->
LinksFrom
({
conv_out_var
}).
LinksTo
({
activation_out_var
});
return
activation_out_var
;
}
PDNode
*
patterns
::
ElementwiseActivation
::
operator
()(
paddle
::
framework
::
ir
::
PDNode
*
elementwise_a
,
const
std
::
string
&
elementwise_type
,
const
std
::
string
&
activation_type
)
{
// Create Operators
elementwise_a
->
assert_is_op_input
(
elementwise_type
,
"X"
);
auto
*
elementwise_op
=
pattern
->
NewNode
(
elementwise_repr
())
->
assert_is_op
(
elementwise_type
);
auto
*
activation_op
=
auto
*
activation_op
=
pattern
->
NewNode
(
activation_repr
())
->
assert_is_op
(
activation_type
);
pattern
->
NewNode
(
activation_repr
())
->
assert_is_op
(
activation_type
);
// Create variables
auto
*
activation_out
=
pattern
->
NewNode
(
activation_out_repr
())
auto
*
elementwise_b
=
pattern
->
NewNode
(
elementwise_b_repr
())
->
AsInput
()
->
assert_is_op_input
(
elementwise_type
,
"Y"
);
// intermediate variable, will be removed in the IR after fuse.
auto
*
elementwise_out_var
=
pattern
->
NewNode
(
elementwise_out_repr
())
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
elementwise_type
)
->
assert_is_op_input
(
activation_type
);
// output
auto
*
activation_out_var
=
pattern
->
NewNode
(
activation_out_repr
())
->
AsOutput
()
->
AsOutput
()
->
assert_is_op_output
(
activation_type
);
->
assert_is_op_output
(
activation_type
);
preceding_op
->
LinksTo
({
preceding_op_out
});
elementwise_op
->
LinksFrom
({
elementwise_a
,
elementwise_b
})
activation_op
->
LinksFrom
({
preceding_op_out
}).
LinksTo
({
activation_out
});
.
LinksTo
({
elementwise_out_var
});
return
activation_out
;
activation_op
->
LinksFrom
({
elementwise_out_var
}).
LinksTo
({
activation_out_var
});
return
activation_out_var
;
}
}
PDNode
*
patterns
::
SeqConvEltAddRelu
::
operator
()(
PDNode
*
patterns
::
SeqConvEltAddRelu
::
operator
()(
...
@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
...
@@ -1121,44 +1078,6 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
return
fc_out_var
;
return
fc_out_var
;
}
}
PDNode
*
patterns
::
FCActOneDNN
::
operator
()(
const
std
::
string
&
act_type
)
{
auto
*
fc
=
pattern
->
NewNode
(
fc_repr
())
->
assert_is_op
(
"fc"
);
auto
*
fc_out
=
pattern
->
NewNode
(
fc_out_repr
())
->
assert_is_op_output
(
"fc"
,
"Out"
)
->
assert_is_op_input
(
act_type
);
auto
*
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
act_type
)
->
AsIntermediate
();
auto
*
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
assert_is_op_output
(
act_type
,
"Out"
)
->
AsOutput
();
fc
->
LinksTo
({
fc_out
});
act
->
LinksFrom
({
fc_out
}).
LinksTo
({
act_out
});
return
act_out
;
}
PDNode
*
patterns
::
SoftplusActivation
::
operator
()(
std
::
string
activation_type
)
{
// Create Operators
auto
*
softplus_op
=
pattern
->
NewNode
(
softplus_repr
())
->
assert_is_op
(
"softplus"
);
auto
*
activation_op
=
pattern
->
NewNode
(
activation_repr
())
->
assert_is_op
(
activation_type
);
// intermediate variable, will be removed in the IR after fuse.
auto
*
softplus_out
=
pattern
->
NewNode
(
softplus_out_repr
())
->
AsIntermediate
()
->
assert_is_only_output_of_op
(
"softplus"
)
->
assert_is_op_input
(
activation_type
);
// output
auto
*
activation_out
=
pattern
->
NewNode
(
activation_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
activation_type
);
softplus_op
->
LinksTo
({
softplus_out
});
activation_op
->
LinksFrom
({
softplus_out
}).
LinksTo
({
activation_out
});
return
activation_out
;
}
PDNode
*
patterns
::
Embedding
::
operator
()(
PDNode
*
x
)
{
PDNode
*
patterns
::
Embedding
::
operator
()(
PDNode
*
x
)
{
x
->
assert_is_op_input
(
"lookup_table"
,
"Ids"
);
x
->
assert_is_op_input
(
"lookup_table"
,
"Ids"
);
auto
*
lookup_table_op
=
auto
*
lookup_table_op
=
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
826e2781
...
@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase {
...
@@ -524,49 +524,16 @@ struct ConvBN : public PatternBase {
PATTERN_DECL_NODE
(
bn_saved_variance
);
PATTERN_DECL_NODE
(
bn_saved_variance
);
};
};
// Conv with Activation
struct
OperatorActivation
:
public
PatternBase
{
// op: conv + activation
OperatorActivation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
// named nodes:
:
PatternBase
(
pattern
,
name_scope
,
"operator_activation"
)
{}
// conv_input, conv_weight,
// conv_out, conv,
// activation_out, activation
struct
ConvActivation
:
public
PatternBase
{
ConvActivation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"conv_activation"
)
{}
PDNode
*
operator
()(
PDNode
*
conv_input
,
std
::
string
conv_type
=
"conv2d"
,
std
::
string
activation_type
=
"relu"
);
// declare operator node's name
PDNode
*
operator
()(
const
std
::
string
&
operator_type
,
PATTERN_DECL_NODE
(
conv
);
PATTERN_DECL_NODE
(
activation
);
// declare variable node's name
PATTERN_DECL_NODE
(
conv_weight
);
PATTERN_DECL_NODE
(
conv_out
);
PATTERN_DECL_NODE
(
activation_out
);
};
// Elementwise with Activation
// op: elementwise + activation
// named nodes:
// elementwise_a, elementwise_b,
// elementwise_out, elementwise,
// activation_out, activation
struct
ElementwiseActivation
:
public
PatternBase
{
ElementwiseActivation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"elementwise_add_activation"
)
{}
PDNode
*
operator
()(
PDNode
*
elementwise_a
,
const
std
::
string
&
elementwise_type
,
const
std
::
string
&
activation_type
);
const
std
::
string
&
activation_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
preceding_op
);
PATTERN_DECL_NODE
(
elementwise
);
PATTERN_DECL_NODE
(
preceding_op_out
);
PATTERN_DECL_NODE
(
activation
);
PATTERN_DECL_NODE
(
activation
);
// declare variable node's name
PATTERN_DECL_NODE
(
elementwise_b
);
PATTERN_DECL_NODE
(
elementwise_out
);
PATTERN_DECL_NODE
(
activation_out
);
PATTERN_DECL_NODE
(
activation_out
);
};
};
...
@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase {
...
@@ -639,45 +606,6 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE
(
output
);
PATTERN_DECL_NODE
(
output
);
};
};
//
// \brief Pattern looking for fc and a directly following activation
// operator.
//
// \note Currently only gelu and tanh are supported as an activation
// function.
// Formula: act(fc(x))
// Op: fc + act
struct
FCActOneDNN
:
public
PatternBase
{
FCActOneDNN
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fc_act_onednn"
)
{}
PDNode
*
operator
()(
const
std
::
string
&
act_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
fc
);
PATTERN_DECL_NODE
(
act
);
PATTERN_DECL_NODE
(
fc_out
);
PATTERN_DECL_NODE
(
act_out
);
};
// Fuse softplus with activation
// ops: softplus + activation
// nodes:
// softplus, softplus_out,
// activation, activation_out
struct
SoftplusActivation
:
public
PatternBase
{
SoftplusActivation
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"softplus_activation"
)
{}
PDNode
*
operator
()(
std
::
string
activation_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
softplus
);
PATTERN_DECL_NODE
(
activation
);
PATTERN_DECL_NODE
(
softplus_out
);
PATTERN_DECL_NODE
(
activation_out
);
};
// Embedding
// Embedding
struct
Embedding
:
public
PatternBase
{
struct
Embedding
:
public
PatternBase
{
Embedding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
Embedding
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
...
...
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
浏览文件 @
826e2781
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -24,61 +25,27 @@ namespace ir {
...
@@ -24,61 +25,27 @@ namespace ir {
using
string
::
PrettyLogDetail
;
using
string
::
PrettyLogDetail
;
void
ConvActivationMkldnnFusePass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
ConvActivationMkldnnFusePass
::
ApplyImpl
(
Graph
*
graph
)
const
{
std
::
vector
<
std
::
string
>
act_types
=
{
"relu"
,
auto
act_types
=
paddle
::
platform
::
GetSupportedActivations
();
"mish"
,
"swish"
,
"sqrt"
,
"hard_swish"
,
"sigmoid"
,
"abs"
,
"gelu"
,
"relu6"
,
"clip"
,
"tanh"
,
"hard_sigmoid"
,
"leaky_relu"
};
std
::
vector
<
std
::
string
>
conv_types
=
{
"conv2d"
};
std
::
vector
<
std
::
string
>
conv_types
=
{
"conv2d"
};
for
(
const
auto
&
conv_type
:
conv_types
)
for
(
const
auto
&
conv_type
:
conv_types
)
for
(
auto
&
act_type
:
act_types
)
{
for
(
auto
&
act_type
:
act_types
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attrs_map
;
FuseConvAct
(
graph
,
conv_type
,
act_type
);
if
(
act_type
==
"swish"
)
attrs_map
.
emplace
(
"beta"
,
"fuse_alpha"
);
else
if
(
act_type
==
"relu6"
)
attrs_map
.
emplace
(
"threshold"
,
"fuse_alpha"
);
else
if
(
act_type
==
"hard_sigmoid"
)
{
attrs_map
.
emplace
(
"slope"
,
"fuse_alpha"
);
attrs_map
.
emplace
(
"offset"
,
"fuse_beta"
);
}
else
if
(
act_type
==
"clip"
)
{
attrs_map
.
emplace
(
"min"
,
"fuse_alpha"
);
attrs_map
.
emplace
(
"max"
,
"fuse_beta"
);
}
else
{
attrs_map
.
emplace
(
"alpha"
,
"fuse_alpha"
);
attrs_map
.
emplace
(
"beta"
,
"fuse_beta"
);
}
FuseConvAct
(
graph
,
conv_type
,
act_type
,
attrs_map
);
}
}
}
}
void
ConvActivationMkldnnFusePass
::
FuseConvAct
(
void
ConvActivationMkldnnFusePass
::
FuseConvAct
(
Graph
*
graph
,
Graph
*
graph
,
const
std
::
string
&
conv_type
,
const
std
::
string
&
conv_type
,
std
::
string
&
act_type
,
std
::
string
&
act_type
)
const
{
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
attrs_map
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
FusePassBase
::
Init
(
conv_type
+
"_"
+
act_type
+
"_mkldnn_fuse_pass"
,
graph
);
FusePassBase
::
Init
(
conv_type
+
"_"
+
act_type
+
"_mkldnn_fuse_pass"
,
graph
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
conv_input
=
gpd
.
mutable_pattern
()
patterns
::
OperatorActivation
conv_act_pattern
(
gpd
.
mutable_pattern
(),
->
NewNode
(
"conv_activation_mkldnn_fuse/conv_input"
)
->
AsInput
()
->
assert_is_op_input
(
conv_type
,
"Input"
);
patterns
::
ConvActivation
conv_act_pattern
(
gpd
.
mutable_pattern
(),
"conv_activation_mkldnn_fuse"
);
"conv_activation_mkldnn_fuse"
);
conv_act_pattern
(
conv_
input
,
conv_
type
,
act_type
);
conv_act_pattern
(
conv_type
,
act_type
);
int
found_conv_activation_count
=
0
;
int
found_conv_activation_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
...
@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct(
...
@@ -90,16 +57,16 @@ void ConvActivationMkldnnFusePass::FuseConvAct(
return
;
return
;
}
}
GET_IR_NODE_FROM_SUBGRAPH
(
conv_weight
,
conv_weight
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv
,
preceding_op
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_out
,
conv_out
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_out
,
preceding_op_out
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv
,
conv
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation_out
,
activation_out
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation
,
activation
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation
,
activation
,
conv_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation_out
,
activation_out
,
conv_act_pattern
);
OpDesc
*
conv_op
=
conv
->
Op
();
OpDesc
*
conv_op
=
conv
->
Op
();
OpDesc
*
act_op
=
activation
->
Op
();
OpDesc
*
act_op
=
activation
->
Op
();
for
(
const
auto
&
attrs
:
attrs_map
)
{
auto
attr_map
=
paddle
::
platform
::
GetAttributeMap
(
act_type
);
for
(
const
auto
&
attrs
:
attr_map
)
{
if
(
act_op
->
HasAttr
(
attrs
.
first
))
{
if
(
act_op
->
HasAttr
(
attrs
.
first
))
{
conv_op
->
SetAttr
(
attrs
.
second
,
act_op
->
GetAttr
(
attrs
.
first
));
conv_op
->
SetAttr
(
attrs
.
second
,
act_op
->
GetAttr
(
attrs
.
first
));
}
}
...
...
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
浏览文件 @
826e2781
...
@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
...
@@ -31,11 +31,9 @@ class ConvActivationMkldnnFusePass : public FusePassBase {
protected:
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
void
FuseConvAct
(
void
FuseConvAct
(
Graph
*
graph
,
Graph
*
graph
,
const
std
::
string
&
conv_type
,
const
std
::
string
&
conv_type
,
std
::
string
&
act_type
,
std
::
string
&
act_type
)
const
;
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
attrs_map
)
const
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
浏览文件 @
826e2781
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -26,71 +27,40 @@ namespace ir {
...
@@ -26,71 +27,40 @@ namespace ir {
using
string
::
PrettyLogDetail
;
using
string
::
PrettyLogDetail
;
void
ElementwiseActivationOneDNNPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
ElementwiseActivationOneDNNPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
std
::
vector
<
std
::
string
>
act_types
=
{
"relu"
,
auto
act_types
=
paddle
::
platform
::
GetSupportedActivations
();
"tanh"
,
"leaky_relu"
,
"swish"
,
"hard_swish"
,
"sqrt"
,
"abs"
,
"clip"
,
"gelu"
,
"relu6"
,
"sigmoid"
};
std
::
vector
<
std
::
string
>
elt_types
=
{
std
::
vector
<
std
::
string
>
elt_types
=
{
"elementwise_add"
,
"elementwise_sub"
,
"elementwise_mul"
};
"elementwise_add"
,
"elementwise_sub"
,
"elementwise_mul"
};
for
(
const
auto
&
elt_type
:
elt_types
)
for
(
const
auto
&
elt_type
:
elt_types
)
for
(
const
auto
&
act_type
:
act_types
)
{
for
(
const
auto
&
act_type
:
act_types
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attr_map
;
FuseElementwiseAct
(
graph
,
elt_type
,
act_type
);
if
(
act_type
==
"swish"
)
attr_map
.
emplace
(
"beta"
,
"activation_alpha"
);
else
if
(
act_type
==
"relu6"
)
attr_map
.
emplace
(
"threshold"
,
"activation_alpha"
);
else
if
(
act_type
==
"clip"
)
{
attr_map
.
emplace
(
"min"
,
"activation_alpha"
);
attr_map
.
emplace
(
"max"
,
"activation_beta"
);
}
else
{
attr_map
.
emplace
(
"alpha"
,
"activation_alpha"
);
attr_map
.
emplace
(
"beta"
,
"activation_beta"
);
}
FuseElementwiseAct
(
graph
,
elt_type
,
act_type
,
attr_map
);
}
}
}
}
void
ElementwiseActivationOneDNNPass
::
FuseElementwiseAct
(
void
ElementwiseActivationOneDNNPass
::
FuseElementwiseAct
(
Graph
*
graph
,
Graph
*
graph
,
const
std
::
string
&
elt_type
,
const
std
::
string
&
elt_type
,
const
std
::
string
&
act_type
,
const
std
::
string
&
act_type
)
const
{
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
attr_map
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
FusePassBase
::
Init
(
elt_type
+
"_"
+
act_type
+
"_mkldnn_fuse_pass"
,
graph
);
FusePassBase
::
Init
(
elt_type
+
"_"
+
act_type
+
"_mkldnn_fuse_pass"
,
graph
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
auto
*
elementwise_input
=
gpd
.
mutable_pattern
()
patterns
::
OperatorActivation
elementwise_act_pattern
(
gpd
.
mutable_pattern
(),
->
NewNode
(
elt_type
+
"_act/elementwise_input"
)
->
AsInput
()
->
assert_is_op_input
(
elt_type
,
"X"
);
patterns
::
ElementwiseActivation
elementwise_act_pattern
(
gpd
.
mutable_pattern
(),
elt_type
+
"_act"
);
elt_type
+
"_act"
);
elementwise_act_pattern
(
el
ementwise_input
,
el
t_type
,
act_type
);
elementwise_act_pattern
(
elt_type
,
act_type
);
int
found_elementwise_activation_count
=
0
;
int
found_elementwise_activation_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
4
)
<<
"Fuse "
<<
elt_type
<<
" with activation op."
;
VLOG
(
4
)
<<
"Fuse "
<<
elt_type
<<
" with activation op."
;
// Elementwise output
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_out
,
elementwise_out
,
elementwise_act_pattern
);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
activation_out
,
activation_out
,
elementwise_act_pattern
);
elementwise
,
preceding_op
,
elementwise_act_pattern
);
// ops
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise
,
elementwise
,
elementwise_act_pattern
);
elementwise
_out
,
preceding_op_out
,
elementwise_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation
,
activation
,
elementwise_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation
,
activation
,
elementwise_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
activation_out
,
activation_out
,
elementwise_act_pattern
);
auto
*
elementwise_op
=
elementwise
->
Op
();
auto
*
elementwise_op
=
elementwise
->
Op
();
...
@@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
...
@@ -106,6 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
}
}
auto
*
activation_op
=
activation
->
Op
();
auto
*
activation_op
=
activation
->
Op
();
auto
attr_map
=
paddle
::
platform
::
GetAttributeMap
(
act_type
);
for
(
const
auto
&
attr
:
attr_map
)
{
for
(
const
auto
&
attr
:
attr_map
)
{
if
(
activation_op
->
HasAttr
(
attr
.
first
))
{
if
(
activation_op
->
HasAttr
(
attr
.
first
))
{
elementwise_op
->
SetAttr
(
attr
.
second
,
elementwise_op
->
SetAttr
(
attr
.
second
,
...
@@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
...
@@ -115,9 +86,9 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
if
(
act_type
==
"gelu"
&&
activation_op
->
HasAttr
(
"approximate"
)
&&
if
(
act_type
==
"gelu"
&&
activation_op
->
HasAttr
(
"approximate"
)
&&
BOOST_GET_CONST
(
bool
,
activation_op
->
GetAttr
(
"approximate"
)))
BOOST_GET_CONST
(
bool
,
activation_op
->
GetAttr
(
"approximate"
)))
elementwise_op
->
SetAttr
(
"
activation_type
"
,
std
::
string
(
"gelu_tanh"
));
elementwise_op
->
SetAttr
(
"
fuse_activation
"
,
std
::
string
(
"gelu_tanh"
));
else
else
elementwise_op
->
SetAttr
(
"
activation_type
"
,
act_type
);
elementwise_op
->
SetAttr
(
"
fuse_activation
"
,
act_type
);
elementwise_op
->
SetOutput
(
"Out"
,
{
activation_out
->
Name
()});
elementwise_op
->
SetOutput
(
"Out"
,
{
activation_out
->
Name
()});
...
@@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass)
...
@@ -146,14 +117,16 @@ REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass)
.
LE
(
"elementwise_add"
,
1
)
.
LE
(
"elementwise_add"
,
1
)
.
LE
(
"elementwise_sub"
,
1
)
.
LE
(
"elementwise_sub"
,
1
)
.
LE
(
"elementwise_mul"
,
1
)
.
LE
(
"elementwise_mul"
,
1
)
.
LE
(
"relu"
,
0
)
.
EQ
(
"abs"
,
0
)
.
LE
(
"tanh"
,
0
)
.
LE
(
"leaky_relu"
,
1
)
.
LE
(
"swish"
,
0
)
.
LE
(
"hard_swish"
,
0
)
.
LE
(
"sqrt"
,
0
)
.
LE
(
"abs"
,
0
)
.
LE
(
"clip"
,
1
)
.
LE
(
"clip"
,
1
)
.
LE
(
"gelu"
,
0
)
.
EQ
(
"gelu"
,
0
)
.
LE
(
"relu6"
,
0
)
.
EQ
(
"hard_sigmoid"
,
0
)
.
LE
(
"sigmoid"
,
0
));
.
LE
(
"hard_swish"
,
0
)
.
LE
(
"leaky_relu"
,
1
)
.
LE
(
"mish"
,
1
)
.
EQ
(
"relu"
,
0
)
.
EQ
(
"relu6"
,
0
)
.
EQ
(
"sigmoid"
,
0
)
.
EQ
(
"sqrt"
,
0
)
.
EQ
(
"swish"
,
0
)
.
EQ
(
"tanh"
,
0
));
paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
浏览文件 @
826e2781
...
@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase {
...
@@ -34,11 +34,9 @@ class ElementwiseActivationOneDNNPass : public FusePassBase {
protected:
protected:
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
void
FuseElementwiseAct
(
void
FuseElementwiseAct
(
Graph
*
graph
,
Graph
*
graph
,
const
std
::
string
&
elt_types
,
const
std
::
string
&
elt_types
,
const
std
::
string
&
act_types
,
const
std
::
string
&
act_types
)
const
;
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
attr_map
)
const
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
浏览文件 @
826e2781
...
@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
...
@@ -39,20 +39,17 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
FusePassBase
::
Init
(
"fc_act"
,
graph
);
FusePassBase
::
Init
(
"fc_act"
,
graph
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
patterns
::
FCActOneDNN
fc_act_pattern
(
gpd
.
mutable_pattern
(),
"fc_act"
);
patterns
::
OperatorActivation
fc_act_pattern
(
gpd
.
mutable_pattern
(),
"fc_act"
);
fc_act_pattern
(
act_type
);
fc_act_pattern
(
"fc"
,
act_type
);
int
found_fc_act_count
=
0
;
int
found_fc_act_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
4
)
<<
"Fuse fc with activation op."
;
VLOG
(
4
)
<<
"Fuse fc with activation op."
;
// FC output
GET_IR_NODE_FROM_SUBGRAPH
(
fc
,
preceding_op
,
fc_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
fc_out
,
fc_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
fc_out
,
preceding_op_out
,
fc_act_pattern
);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH
(
act
,
activation
,
fc_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act_out
,
act_out
,
fc_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act_out
,
activation_out
,
fc_act_pattern
);
// ops
GET_IR_NODE_FROM_SUBGRAPH
(
fc
,
fc
,
fc_act_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act
,
act
,
fc_act_pattern
);
auto
*
fc_op
=
fc
->
Op
();
auto
*
fc_op
=
fc
->
Op
();
auto
*
act_op
=
act
->
Op
();
auto
*
act_op
=
act
->
Op
();
...
...
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc
浏览文件 @
826e2781
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/fluid/string/pretty_log.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -26,59 +27,34 @@ namespace ir {
...
@@ -26,59 +27,34 @@ namespace ir {
using
string
::
PrettyLogDetail
;
using
string
::
PrettyLogDetail
;
void
SoftplusActivationOneDNNPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
SoftplusActivationOneDNNPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
std
::
vector
<
std
::
string
>
act_types
=
{
"relu"
,
auto
act_types
=
paddle
::
platform
::
GetSupportedActivations
();
"tanh"
,
"leaky_relu"
,
"swish"
,
"hardswish"
,
"sqrt"
,
"abs"
,
"clip"
,
"gelu"
,
"relu6"
,
"sigmoid"
};
for
(
const
auto
&
act_type
:
act_types
)
{
for
(
const
auto
&
act_type
:
act_types
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attr_map
;
FuseSoftplusActivation
(
graph
,
act_type
);
if
(
act_type
==
"swish"
)
attr_map
.
emplace
(
"beta"
,
"fuse_activation_alpha"
);
else
if
(
act_type
==
"relu6"
)
attr_map
.
emplace
(
"threshold"
,
"fuse_activation_alpha"
);
else
if
(
act_type
==
"clip"
)
{
attr_map
.
emplace
(
"min"
,
"fuse_activation_alpha"
);
attr_map
.
emplace
(
"max"
,
"fuse_activation_beta"
);
}
else
{
attr_map
.
emplace
(
"alpha"
,
"fuse_activation_alpha"
);
attr_map
.
emplace
(
"beta"
,
"fuse_activation_beta"
);
}
FuseSoftplusActivation
(
graph
,
act_type
,
attr_map
);
}
}
}
}
void
SoftplusActivationOneDNNPass
::
FuseSoftplusActivation
(
void
SoftplusActivationOneDNNPass
::
FuseSoftplusActivation
(
Graph
*
graph
,
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
{
const
std
::
string
&
fuse_activation_type
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
attr_map
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
graph
,
platform
::
errors
::
InvalidArgument
(
"Graph cannot be nullptr."
));
FusePassBase
::
Init
(
"softplus_activation"
,
graph
);
FusePassBase
::
Init
(
"softplus_activation"
,
graph
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
patterns
::
Softplus
Activation
softplus_activation_pattern
(
patterns
::
Operator
Activation
softplus_activation_pattern
(
gpd
.
mutable_pattern
(),
"softplus_activation"
);
gpd
.
mutable_pattern
(),
"softplus_activation"
);
softplus_activation_pattern
(
fuse_activation
_type
);
softplus_activation_pattern
(
"softplus"
,
act
_type
);
int
found_softplus_activation_count
=
0
;
int
found_softplus_activation_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
4
)
<<
"Fuse softplus with activation op."
;
VLOG
(
4
)
<<
"Fuse softplus with activation op."
;
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
softplus_out
,
softplus
_out
,
softplus_activation_pattern
);
softplus_out
,
preceding_op
_out
,
softplus_activation_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
activation_out
,
activation_out
,
softplus_activation_pattern
);
activation_out
,
activation_out
,
softplus_activation_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
softplus
,
softplus
,
softplus_activation_pattern
);
softplus
,
preceding_op
,
softplus_activation_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
GET_IR_NODE_FROM_SUBGRAPH
(
activation
,
activation
,
softplus_activation_pattern
);
activation
,
activation
,
softplus_activation_pattern
);
...
@@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
...
@@ -94,18 +70,18 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
}
}
auto
*
activation_op
=
activation
->
Op
();
auto
*
activation_op
=
activation
->
Op
();
auto
attr_map
=
paddle
::
platform
::
GetAttributeMap
(
act_type
);
for
(
const
auto
&
attr
:
attr_map
)
{
for
(
const
auto
&
attr
:
attr_map
)
{
if
(
activation_op
->
HasAttr
(
attr
.
first
))
{
if
(
activation_op
->
HasAttr
(
attr
.
first
))
{
softplus_op
->
SetAttr
(
attr
.
second
,
activation_op
->
GetAttr
(
attr
.
first
));
softplus_op
->
SetAttr
(
attr
.
second
,
activation_op
->
GetAttr
(
attr
.
first
));
}
}
}
}
if
(
fuse_activation_type
==
"gelu"
&&
if
(
act_type
==
"gelu"
&&
activation_op
->
HasAttr
(
"approximate"
)
&&
activation_op
->
HasAttr
(
"approximate"
)
&&
BOOST_GET_CONST
(
bool
,
activation_op
->
GetAttr
(
"approximate"
)))
BOOST_GET_CONST
(
bool
,
activation_op
->
GetAttr
(
"approximate"
)))
softplus_op
->
SetAttr
(
"fuse_activation
_type
"
,
std
::
string
(
"gelu_tanh"
));
softplus_op
->
SetAttr
(
"fuse_activation"
,
std
::
string
(
"gelu_tanh"
));
else
else
softplus_op
->
SetAttr
(
"fuse_activation
_type"
,
fuse_activation
_type
);
softplus_op
->
SetAttr
(
"fuse_activation
"
,
act
_type
);
softplus_op
->
SetAttr
(
"use_mkldnn"
,
true
);
softplus_op
->
SetAttr
(
"use_mkldnn"
,
true
);
...
@@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
...
@@ -121,7 +97,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
if
(
!
Has
(
"disable_logs"
)
||
!
Get
<
bool
>
(
"disable_logs"
))
if
(
!
Has
(
"disable_logs"
)
||
!
Get
<
bool
>
(
"disable_logs"
))
PrettyLogDetail
(
"--- fused %d softplus with %s activation"
,
PrettyLogDetail
(
"--- fused %d softplus with %s activation"
,
found_softplus_activation_count
,
found_softplus_activation_count
,
fuse_activation
_type
);
act
_type
);
}
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
@@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
...
@@ -133,13 +109,16 @@ REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass)
.
AddCombination
(
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"softplus"
,
1
)
.
LE
(
"softplus"
,
1
)
.
EQ
(
"relu"
,
0
)
.
EQ
(
"tanh"
,
0
)
.
LE
(
"leaky_relu"
,
1
)
.
EQ
(
"swish"
,
0
)
.
EQ
(
"hard_swish"
,
0
)
.
EQ
(
"sqrt"
,
0
)
.
EQ
(
"abs"
,
0
)
.
EQ
(
"abs"
,
0
)
.
LE
(
"relu6"
,
1
)
.
LE
(
"clip"
,
1
)
.
LE
(
"clip"
,
1
)
.
EQ
(
"gelu"
,
0
));
.
EQ
(
"gelu"
,
0
)
.
EQ
(
"hard_sigmoid"
,
0
)
.
LE
(
"hard_swish"
,
0
)
.
LE
(
"leaky_relu"
,
1
)
.
LE
(
"mish"
,
1
)
.
EQ
(
"relu"
,
0
)
.
EQ
(
"relu6"
,
0
)
.
EQ
(
"sigmoid"
,
0
)
.
EQ
(
"sqrt"
,
0
)
.
EQ
(
"swish"
,
0
)
.
EQ
(
"tanh"
,
0
));
paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h
浏览文件 @
826e2781
...
@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase {
...
@@ -34,10 +34,8 @@ class SoftplusActivationOneDNNPass : public FusePassBase {
protected:
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
void
FuseSoftplusActivation
(
void
FuseSoftplusActivation
(
ir
::
Graph
*
graph
,
ir
::
Graph
*
graph
,
const
std
::
string
&
act_type
)
const
;
const
std
::
string
&
fuse_activation_type
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
attr_map
)
const
;
};
};
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
浏览文件 @
826e2781
...
@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -50,22 +50,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
private:
private:
dnnl
::
post_ops
get_post_ops
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
dnnl
::
post_ops
get_post_ops
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
dnnl
::
post_ops
post_operations
;
dnnl
::
post_ops
post_operations
;
if
(
ctx
.
HasAttr
(
"activation_type"
))
{
platform
::
AppendActivation
(
ctx
,
post_operations
);
const
float
scale
=
ctx
.
HasAttr
(
"activation_scale"
)
?
ctx
.
Attr
<
float
>
(
"activation_scale"
)
:
1.0
f
;
const
float
alpha
=
ctx
.
HasAttr
(
"activation_alpha"
)
?
ctx
.
Attr
<
float
>
(
"activation_alpha"
)
:
0.0
f
;
const
float
beta
=
ctx
.
HasAttr
(
"activation_beta"
)
?
ctx
.
Attr
<
float
>
(
"activation_beta"
)
:
0.0
f
;
const
auto
activation_algorithm
=
platform
::
AcquireActivationAlgorithm
(
ctx
.
Attr
<
std
::
string
>
(
"activation_type"
));
post_operations
.
append_eltwise
(
scale
,
activation_algorithm
,
alpha
,
beta
);
}
return
post_operations
;
return
post_operations
;
}
}
...
...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
826e2781
...
@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT
...
@@ -553,10 +553,6 @@ class ConvMKLDNNHandlerT
dnnl
::
primitive_attr
conv_attr
;
dnnl
::
primitive_attr
conv_attr
;
dnnl
::
post_ops
post_operations
;
dnnl
::
post_ops
post_operations
;
const
std
::
string
fuse_activation
=
ctx
.
Attr
<
std
::
string
>
(
"fuse_activation"
);
const
float
fuse_alpha
=
ctx
.
Attr
<
float
>
(
"fuse_alpha"
);
const
float
fuse_beta
=
ctx
.
Attr
<
float
>
(
"fuse_beta"
);
const
bool
fuse_residual_conn
=
ctx
.
Attr
<
bool
>
(
"fuse_residual_connection"
);
const
bool
fuse_residual_conn
=
ctx
.
Attr
<
bool
>
(
"fuse_residual_connection"
);
float
sum_scale
=
1.0
f
;
float
sum_scale
=
1.0
f
;
...
@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT
...
@@ -587,19 +583,7 @@ class ConvMKLDNNHandlerT
post_operations
.
append_sum
(
sum_scale
);
post_operations
.
append_sum
(
sum_scale
);
}
}
if
(
fuse_activation
==
"hard_sigmoid"
)
{
platform
::
AppendActivation
(
ctx
,
post_operations
,
activation_scale
);
post_operations
.
append_eltwise
(
activation_scale
,
dnnl
::
algorithm
::
eltwise_linear
,
fuse_alpha
,
fuse_beta
);
post_operations
.
append_eltwise
(
activation_scale
,
dnnl
::
algorithm
::
eltwise_clip
,
0.0
f
,
1.0
f
);
}
else
if
(
fuse_activation
!=
""
)
{
const
auto
activation_algorithm
=
platform
::
AcquireActivationAlgorithm
(
fuse_activation
);
post_operations
.
append_eltwise
(
activation_scale
,
activation_algorithm
,
fuse_alpha
,
fuse_beta
);
}
conv_attr
.
set_post_ops
(
post_operations
);
conv_attr
.
set_post_ops
(
post_operations
);
return
conv_attr
;
return
conv_attr
;
...
...
paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h
浏览文件 @
826e2781
...
@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler
...
@@ -46,7 +46,7 @@ class SoftplusMKLDNNHandler
1.0
f
,
dnnl
::
algorithm
::
eltwise_linear
,
1.0
f
/
beta
,
0.0
f
);
1.0
f
,
dnnl
::
algorithm
::
eltwise_linear
,
1.0
f
/
beta
,
0.0
f
);
}
}
AppendFusedActivationIfExists
(
ctx
,
&
post_ops
);
platform
::
AppendActivation
(
ctx
,
post_ops
);
dnnl
::
primitive_attr
attrs
;
dnnl
::
primitive_attr
attrs
;
attrs
.
set_post_ops
(
post_ops
);
attrs
.
set_post_ops
(
post_ops
);
...
@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler
...
@@ -62,42 +62,8 @@ class SoftplusMKLDNNHandler
return
this
->
AcquireMemoryFromPrimitive
(
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
src1_desc
(),
platform
::
to_void_cast
<
float
>
(
beta
));
this
->
fwd_pd_
->
src1_desc
(),
platform
::
to_void_cast
<
float
>
(
beta
));
}
}
private:
void
AppendFusedActivationIfExists
(
const
framework
::
ExecutionContext
&
ctx
,
dnnl
::
post_ops
*
post_ops
)
{
const
auto
&
fused_activation_type
=
algo_map
.
find
(
ctx
.
Attr
<
std
::
string
>
(
"fuse_activation_type"
));
if
(
fused_activation_type
!=
algo_map
.
end
())
{
auto
scale_out
=
ctx
.
Attr
<
float
>
(
"fuse_activation_scale"
);
// for future int8 support
post_ops
->
append_eltwise
(
scale_out
,
fused_activation_type
->
second
,
ctx
.
Attr
<
float
>
(
"fuse_activation_alpha"
),
ctx
.
Attr
<
float
>
(
"fuse_activation_beta"
));
}
}
static
const
std
::
unordered_map
<
std
::
string
,
dnnl
::
algorithm
>
algo_map
;
};
};
template
<
typename
T
>
const
std
::
unordered_map
<
std
::
string
,
dnnl
::
algorithm
>
SoftplusMKLDNNHandler
<
T
>::
algo_map
=
{
{
"relu"
,
dnnl
::
algorithm
::
eltwise_relu
},
{
"tanh"
,
dnnl
::
algorithm
::
eltwise_tanh
},
{
"leaky_relu"
,
dnnl
::
algorithm
::
eltwise_relu
},
{
"swish"
,
dnnl
::
algorithm
::
eltwise_swish
},
{
"hardswish"
,
dnnl
::
algorithm
::
eltwise_hardswish
},
{
"sqrt"
,
dnnl
::
algorithm
::
eltwise_sqrt
},
{
"abs"
,
dnnl
::
algorithm
::
eltwise_abs
},
{
"clip"
,
dnnl
::
algorithm
::
eltwise_clip
},
{
"gelu"
,
dnnl
::
algorithm
::
eltwise_gelu_erf
},
{
"gelu_tanh"
,
dnnl
::
algorithm
::
eltwise_gelu_tanh
},
{
"relu6"
,
dnnl
::
algorithm
::
eltwise_bounded_relu
},
{
"sigmoid"
,
dnnl
::
algorithm
::
eltwise_logistic
}};
template
<
typename
T
>
template
<
typename
T
>
void
custom_softplus_eltwise_forward
(
const
framework
::
ExecutionContext
&
ctx
)
{
void
custom_softplus_eltwise_forward
(
const
framework
::
ExecutionContext
&
ctx
)
{
const
auto
&
dev_ctx
=
const
auto
&
dev_ctx
=
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
826e2781
...
@@ -1013,9 +1013,30 @@ class ActivationMKLDNNHandler
...
@@ -1013,9 +1013,30 @@ class ActivationMKLDNNHandler
}
}
};
};
static
const
dnnl
::
algorithm
AcquireActivationAlgorithm
(
static
void
AppendActivation
(
const
framework
::
ExecutionContext
&
ctx
,
std
::
string
activation_name
)
{
dnnl
::
post_ops
&
post_ops
,
std
::
unordered_map
<
std
::
string
,
dnnl
::
algorithm
>
activation_map
=
{
float
activation_scale
=
1.0
f
)
{
const
auto
invalid_attribute
=
ctx
.
HasAttr
(
"fuse_activation"
)
?
ctx
.
Attr
<
std
::
string
>
(
"fuse_activation"
).
empty
()
:
true
;
if
(
invalid_attribute
)
return
;
const
auto
fuse_activation
=
ctx
.
Attr
<
std
::
string
>
(
"fuse_activation"
);
const
auto
fuse_alpha
=
ctx
.
HasAttr
(
"fuse_alpha"
)
?
ctx
.
Attr
<
float
>
(
"fuse_alpha"
)
:
0.0
f
;
const
auto
fuse_beta
=
ctx
.
HasAttr
(
"fuse_beta"
)
?
ctx
.
Attr
<
float
>
(
"fuse_beta"
)
:
0.0
f
;
if
(
fuse_activation
==
"hard_sigmoid"
)
{
post_ops
.
append_eltwise
(
activation_scale
,
dnnl
::
algorithm
::
eltwise_linear
,
fuse_alpha
,
fuse_beta
);
post_ops
.
append_eltwise
(
activation_scale
,
dnnl
::
algorithm
::
eltwise_clip
,
0.0
f
,
1.0
f
);
}
else
{
const
std
::
unordered_map
<
std
::
string
,
dnnl
::
algorithm
>
activation_map
=
{
{
"abs"
,
dnnl
::
algorithm
::
eltwise_abs
},
{
"abs"
,
dnnl
::
algorithm
::
eltwise_abs
},
{
"clip"
,
dnnl
::
algorithm
::
eltwise_clip
},
{
"clip"
,
dnnl
::
algorithm
::
eltwise_clip
},
{
"gelu"
,
dnnl
::
algorithm
::
eltwise_gelu_erf
},
{
"gelu"
,
dnnl
::
algorithm
::
eltwise_gelu_erf
},
...
@@ -1031,14 +1052,54 @@ static const dnnl::algorithm AcquireActivationAlgorithm(
...
@@ -1031,14 +1052,54 @@ static const dnnl::algorithm AcquireActivationAlgorithm(
{
"swish"
,
dnnl
::
algorithm
::
eltwise_swish
},
{
"swish"
,
dnnl
::
algorithm
::
eltwise_swish
},
{
"tanh"
,
dnnl
::
algorithm
::
eltwise_tanh
}};
{
"tanh"
,
dnnl
::
algorithm
::
eltwise_tanh
}};
const
auto
&
activation_type
=
activation_map
.
find
(
activation_name
);
const
auto
&
activation_type
=
activation_map
.
find
(
fuse_activation
);
PADDLE_ENFORCE_NE
(
activation_type
,
PADDLE_ENFORCE_NE
(
activation_type
,
activation_map
.
end
(),
activation_map
.
end
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Activation '%s' not found in oneDNN algorithms mapper"
,
"Activation '%s' not found in oneDNN algorithms mapper"
,
activation_name
));
fuse_activation
));
return
activation_type
->
second
;
post_ops
.
append_eltwise
(
activation_scale
,
activation_type
->
second
,
fuse_alpha
,
fuse_beta
);
}
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
GetAttributeMap
(
std
::
string
act_type
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attr_map
;
if
(
act_type
==
"swish"
)
attr_map
.
emplace
(
"beta"
,
"fuse_alpha"
);
else
if
(
act_type
==
"relu6"
)
attr_map
.
emplace
(
"threshold"
,
"fuse_alpha"
);
else
if
(
act_type
==
"hard_sigmoid"
)
{
attr_map
.
emplace
(
"slope"
,
"fuse_alpha"
);
attr_map
.
emplace
(
"offset"
,
"fuse_beta"
);
}
else
if
(
act_type
==
"clip"
)
{
attr_map
.
emplace
(
"min"
,
"fuse_alpha"
);
attr_map
.
emplace
(
"max"
,
"fuse_beta"
);
}
else
{
attr_map
.
emplace
(
"alpha"
,
"fuse_alpha"
);
attr_map
.
emplace
(
"beta"
,
"fuse_beta"
);
}
return
attr_map
;
}
static
std
::
vector
<
std
::
string
>
GetSupportedActivations
()
{
return
std
::
vector
<
std
::
string
>
{
"abs"
,
"clip"
,
"gelu"
,
"hard_sigmoid"
,
"hard_swish"
,
"leaky_relu"
,
"mish"
,
"relu"
,
"relu6"
,
"sigmoid"
,
"sqrt"
,
"swish"
,
"tanh"
};
}
}
class
ReorderMKLDNNHandler
{
class
ReorderMKLDNNHandler
{
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py
浏览文件 @
826e2781
...
@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker
...
@@ -23,8 +23,8 @@ from paddle.fluid.core import PassVersionChecker
class
SoftplusActivationReluOneDNNFusePassTest
(
InferencePassTest
):
class
SoftplusActivationReluOneDNNFusePassTest
(
InferencePassTest
):
fuse_a
ctivation_a
lpha
=
None
fuse_alpha
=
None
fuse_
activation_
beta
=
None
fuse_beta
=
None
pass_name
=
'softplus_activation_mkldnn_fuse_pass'
pass_name
=
'softplus_activation_mkldnn_fuse_pass'
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
...
@@ -34,13 +34,13 @@ class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest):
shape
=
[
-
1
,
3
,
100
,
100
],
shape
=
[
-
1
,
3
,
100
,
100
],
dtype
=
"float32"
)
dtype
=
"float32"
)
softplus_out
=
fluid
.
layers
.
softplus
(
data
)
softplus_out
=
fluid
.
layers
.
softplus
(
data
)
if
self
.
fuse_
activation_
beta
is
not
None
:
if
self
.
fuse_beta
is
not
None
:
activation_out
=
self
.
fuse_activation
(
activation_out
=
self
.
fuse_activation
(
softplus_out
,
softplus_out
,
self
.
fuse_activation
_alpha
,
self
.
fuse
_alpha
,
self
.
fuse_activation
_beta
)
self
.
fuse
_beta
)
elif
self
.
fuse_a
ctivation_a
lpha
is
not
None
:
elif
self
.
fuse_alpha
is
not
None
:
activation_out
=
self
.
fuse_activation
(
activation_out
=
self
.
fuse_activation
(
softplus_out
,
softplus_out
,
self
.
fuse_activation
_alpha
)
self
.
fuse
_alpha
)
else
:
else
:
activation_out
=
self
.
fuse_activation
(
softplus_out
)
activation_out
=
self
.
fuse_activation
(
softplus_out
)
...
@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest(
...
@@ -73,7 +73,7 @@ class SoftplusActivationLeakyReluOneDNNFusePassTest(
def
set_params
(
self
):
def
set_params
(
self
):
self
.
fuse_activation
=
fluid
.
layers
.
leaky_relu
self
.
fuse_activation
=
fluid
.
layers
.
leaky_relu
self
.
fuse_a
ctivation_a
lpha
=
0.3
self
.
fuse_alpha
=
0.3
class
SoftplusActivationSwishOneDNNFusePassTest
(
class
SoftplusActivationSwishOneDNNFusePassTest
(
...
@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest(
...
@@ -81,7 +81,7 @@ class SoftplusActivationSwishOneDNNFusePassTest(
def
set_params
(
self
):
def
set_params
(
self
):
self
.
fuse_activation
=
fluid
.
layers
.
swish
self
.
fuse_activation
=
fluid
.
layers
.
swish
self
.
fuse_a
ctivation_a
lpha
=
3
self
.
fuse_alpha
=
3
class
SoftplusActivationHardSwishOneDNNFusePassTest
(
class
SoftplusActivationHardSwishOneDNNFusePassTest
(
...
@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest(
...
@@ -110,8 +110,8 @@ class SoftplusActivationClipOneDNNFusePassTest(
def
set_params
(
self
):
def
set_params
(
self
):
self
.
fuse_activation
=
fluid
.
layers
.
clip
self
.
fuse_activation
=
fluid
.
layers
.
clip
self
.
fuse_a
ctivation_a
lpha
=
1.1
self
.
fuse_alpha
=
1.1
self
.
fuse_
activation_
beta
=
5.2
self
.
fuse_beta
=
5.2
class
SoftplusActivationGeluErfOneDNNFusePassTest
(
class
SoftplusActivationGeluErfOneDNNFusePassTest
(
...
@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest(
...
@@ -126,7 +126,7 @@ class SoftplusActivationGeluTanhOneDNNFusePassTest(
def
set_params
(
self
):
def
set_params
(
self
):
self
.
fuse_activation
=
fluid
.
layers
.
gelu
self
.
fuse_activation
=
fluid
.
layers
.
gelu
self
.
fuse_a
ctivation_a
lpha
=
True
# simulated "Approximate" attr
self
.
fuse_alpha
=
True
# simulated "Approximate" attr
class
SoftplusActivationRelu6OneDNNFusePassTest
(
class
SoftplusActivationRelu6OneDNNFusePassTest
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录