Commit dd33d28d (unverified)
Authored Jul 06, 2021 by Wangzheee; committed via GitHub on Jul 06, 2021.
[pass_enhance] conv_elementwise_add_mkldnn_fuse_pass (#33931)
Parent: ae74c404
Showing 6 changed files with 178 additions and 43 deletions (+178 / -43):
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc  +5 -1
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc  +76 -7
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h  +7 -3
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc  +84 -32
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc  +2 -0
paddle/fluid/operators/compat/elementwise_add.pbtxt  +4 -0
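For orientation: the pass touched here folds an elementwise_add that consumes a conv2d output into the conv2d itself, which then receives the other addend through its ResidualData input; this commit adds op-compat definitions and guards around that fusion. The toy scalar program below is only a sketch of the algebra the fusion preserves (array "tensors" and a 1x1 "conv"; nothing in it is Paddle API):

```cpp
#include <array>
#include <iostream>

// Toy 1x1 "convolution" over a 4-element "tensor": out[i] = w * x[i].
std::array<float, 4> Conv(const std::array<float, 4>& x, float w) {
  std::array<float, 4> out{};
  for (size_t i = 0; i < 4; ++i) out[i] = w * x[i];
  return out;
}

// Fused form: the conv consumes the residual directly, standing in for the
// ResidualData input that the real pass wires up.
std::array<float, 4> ConvWithResidual(const std::array<float, 4>& x, float w,
                                      const std::array<float, 4>& residual) {
  std::array<float, 4> out = Conv(x, w);
  for (size_t i = 0; i < 4; ++i) out[i] += residual[i];  // fused add
  return out;
}

int main() {
  std::array<float, 4> x{1, 2, 3, 4}, r{10, 10, 10, 10};
  // Unfused: elementwise_add(conv(x), r); fused: one op, same result.
  std::array<float, 4> conv = Conv(x, 2.0f), unfused{};
  for (size_t i = 0; i < 4; ++i) unfused[i] = conv[i] + r[i];
  std::array<float, 4> fused = ConvWithResidual(x, 2.0f, r);
  std::cout << (unfused == fused ? "match" : "mismatch") << "\n";  // match
}
```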
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -51,7 +51,7 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
     VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
     if (!IsCompat(subgraph, g)) {
-      LOG(WARNING) << "Pass op compat failed.";
+      LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed.";
       return;
     }
     GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
@@ -114,6 +114,10 @@ ConvActivationFusePass::ConvActivationFusePass() {
       .IsOptional()
       .IsTensor()
       .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
       .AddOutput("Output")
       .IsTensor()
       .End()
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -81,16 +81,72 @@ boost::optional<T> HasAttribute(const Node& op, const std::string& attr) {
   return boost::none;
 }
 
+ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({-1, 0})
+      .End();
+}
+
 ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
     const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
     const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
         get_node_from_conv_op,
     const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
-        get_node_from_elementwise_add_op)
+        get_node_from_elementwise_add_op,
+    const ResidualConnectionMKLDNNFusePass* pass)
     : fusion_stats{std::make_shared<int>(0)},
       can_fuse_func{can_fuse_func},
       get_node_from_conv_op{get_node_from_conv_op},
-      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
+      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
+      pass_{pass} {}
 
 void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
     const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
@@ -102,6 +158,11 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
   Node* elementwise_add_op;
   Node* elementwise_add_identity;
   Node* elementwise_add_out;
 
+  if (!pass_->IsCompat(subgraph, graph)) {
+    LOG(WARNING) << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
+    return;
+  }
+
   std::tie(conv_op, conv_input, conv_filter, conv_output) =
       get_node_from_conv_op(subgraph);
@@ -133,12 +194,14 @@ ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle(
     const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
         get_node_from_conv_y_op,
     const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
-        get_node_from_elementwise_add_op)
+        get_node_from_elementwise_add_op,
+    const ResidualConnectionMKLDNNFusePass* pass)
     : fusion_stats{std::make_shared<int>(0)},
       can_fuse_func{can_fuse_func},
       get_node_from_conv_x_op{get_node_from_conv_x_op},
       get_node_from_conv_y_op{get_node_from_conv_y_op},
-      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
+      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
+      pass_{pass} {}
 
 void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
     const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
@@ -155,6 +218,12 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
   Node* elementwise_add_op;
   Node* elementwise_add_out;
 
+  if (!pass_->IsCompat(subgraph, graph)) {
+    LOG(WARNING) << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
+    return;
+  }
+
   std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) =
       get_node_from_conv_x_op(subgraph);
   std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) =
@@ -247,7 +316,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
       [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
         return GetNodesFromConv(conv_pattern, subgraph);
       },
-      get_node_from_elementwise_add);
+      get_node_from_elementwise_add, this);
 }
 
 GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
@@ -284,7 +353,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
       [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
         return GetNodesFromConv(conv_pattern, subgraph);
       },
-      get_node_from_elementwise_add);
+      get_node_from_elementwise_add, this);
 }
 
 GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
@@ -325,7 +394,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
       &conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
         return GetNodesFromConv(conv_y_pattern, subgraph);
       },
-      get_node_from_elementwise_add);
+      get_node_from_elementwise_add, this);
 }
 
 void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
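Most of the lines added in this file declare, through the chained OpCompat builder, which inputs, outputs, and attribute values the pass will accept, and both fuse handlers now bail out with a warning when pass_->IsCompat(subgraph, graph) fails. As a self-contained illustration of that builder-style validation (MiniOpCompat is a made-up miniature for string attributes only, not Paddle's real OpCompat class):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Miniature of the AddAttr(...).IsStringIn({...}).End() chain used above.
class MiniOpCompat {
 public:
  explicit MiniOpCompat(std::string op_type) : op_type_(std::move(op_type)) {}

  // Select the attribute that the following constraints apply to.
  MiniOpCompat& AddAttr(const std::string& name) {
    current_ = name;
    return *this;
  }
  // Constrain the current attribute to a fixed set of strings.
  MiniOpCompat& IsStringIn(std::set<std::string> allowed) {
    checks_[current_].push_back(
        [allowed](const std::string& v) { return allowed.count(v) > 0; });
    return *this;
  }
  // Absence of the current attribute is not an error.
  MiniOpCompat& IsOptional() {
    optional_.insert(current_);
    return *this;
  }
  MiniOpCompat& End() { return *this; }

  // True when every declared attribute is present (unless optional) and
  // satisfies all of its checks.
  bool Judge(const std::map<std::string, std::string>& attrs) const {
    for (const auto& kv : checks_) {
      auto it = attrs.find(kv.first);
      if (it == attrs.end()) {
        if (optional_.count(kv.first)) continue;
        return false;
      }
      for (const auto& check : kv.second)
        if (!check(it->second)) return false;
    }
    return true;
  }

 private:
  std::string op_type_;
  std::string current_;
  std::set<std::string> optional_;
  std::map<std::string,
           std::vector<std::function<bool(const std::string&)>>> checks_;
};

int main() {
  MiniOpCompat conv("conv2d");
  conv.AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  std::cout << conv.Judge({{"data_format", "NCHW"}}) << "\n";   // 1: optional attr absent
  std::cout << conv.Judge({{"data_format", "NDHWC"}}) << "\n";  // 0: disallowed value
  std::cout << conv.Judge({{"padding_algorithm", "SAME"},
                           {"data_format", "NHWC"}}) << "\n";   // 1
}
```

Judge plays the role the real IsCompat check plays in the handlers above: an op missing a mandatory attribute, or carrying a value outside the declared set, causes the candidate subgraph to be skipped rather than fused.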
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
@@ -84,7 +84,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
     auto can_fuse = [this](Node* op1, Node* op2) -> bool {
       return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
     };
-
     auto fuse_handle = HandleType{can_fuse, std::forward<OpFuncs>(op_funcs)...};
 
     (*gpd)(graph, fuse_handle);
@@ -96,7 +95,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
     IdentityFuseHandle(
         const CanFuseFunc& can_fuse_func,
         const IdentityConvFunc& get_node_from_conv_op,
-        const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op);
+        const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op,
+        const ResidualConnectionMKLDNNFusePass* pass);
 
     void operator()(const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph);
@@ -107,6 +107,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
     CanFuseFunc can_fuse_func;
     IdentityConvFunc get_node_from_conv_op;
    IdentityElementwiseAddFunc get_node_from_elementwise_add_op;
+    const ResidualConnectionMKLDNNFusePass* pass_;
   };
 
   struct ProjectionFuseHandle {
@@ -114,7 +115,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
         const CanFuseFunc& can_fuse_func,
         const ProjectionConvFunc& get_node_from_conv_x_op,
         const ProjectionConvFunc& get_node_from_conv_y_op,
-        const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op);
+        const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op,
+        const ResidualConnectionMKLDNNFusePass* pass);
 
     void operator()(const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph);
@@ -126,9 +128,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
     ProjectionConvFunc get_node_from_conv_x_op;
     ProjectionConvFunc get_node_from_conv_y_op;
     ProjectionElementwiseAddFunc get_node_from_elementwise_add_op;
+    const ResidualConnectionMKLDNNFusePass* pass_;
   };
 
  public:
+  ResidualConnectionMKLDNNFusePass();
   virtual ~ResidualConnectionMKLDNNFusePass() {}
 
  protected:
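The header change gives each handle a non-owning const ResidualConnectionMKLDNNFusePass* pass_ member, which is what lets its operator() consult the owning pass's compat table; FuseConvAsX, FuseConvAsY, and FuseProjectionConv now pass this when constructing the handles. A minimal sketch of that back-pointer arrangement (the Mini* names are hypothetical, not Paddle's types):

```cpp
#include <iostream>

class MiniPass;

// A callable fuse handle that keeps a non-owning pointer back to the pass
// that created it, mirroring the pass_ member added in this header.
struct MiniFuseHandle {
  const MiniPass* pass_;  // set at construction, never owned
  explicit MiniFuseHandle(const MiniPass* pass) : pass_(pass) {}
  void operator()(int subgraph_id) const;
};

class MiniPass {
 public:
  // Stand-in for IsCompat(subgraph, graph): accept even ids only.
  bool IsCompat(int subgraph_id) const { return subgraph_id % 2 == 0; }

  void Run() const {
    MiniFuseHandle handle{this};  // analogous to passing `this` in FuseConvAsX
    for (int id = 0; id < 4; ++id) handle(id);
  }
};

void MiniFuseHandle::operator()(int subgraph_id) const {
  if (!pass_->IsCompat(subgraph_id)) {
    std::cout << "skip subgraph " << subgraph_id << " (compat failed)\n";
    return;  // same early-out the real handlers perform
  }
  std::cout << "fuse subgraph " << subgraph_id << "\n";
}

int main() { MiniPass{}.Run(); }
```

Because the handles only observe the pass, and are created and consumed within a single ApplyImpl run that the pass outlives, a raw const pointer is sufficient here.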
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
@@ -16,6 +16,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/ir/pass_test_util.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
@@ -25,15 +26,66 @@ namespace ir {
 constexpr int nodes_removed = 3;
 constexpr int nodes_added = 1;
 
+OpDesc* Create_Op_con2d(ProgramDesc* prog, const std::string& op_type_name,
+                        const std::vector<test::InOutVarNamePair>& inputs,
+                        const std::vector<test::InOutVarNamePair>& outputs,
+                        const bool use_mkldnn = true) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  const std::vector<int> strides({1, 1});
+  const std::vector<int> paddings({0, 0});
+  const std::vector<int> dilations({1, 1});
+
+  op->SetType(op_type_name);
+  op->SetAttr("use_mkldnn", use_mkldnn);
+  op->SetAttr("strides", strides);
+  op->SetAttr("groups", 1);
+  op->SetAttr("paddings", paddings);
+  op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
+  op->SetAttr("dilations", dilations);
+  op->SetAttr("data_format", std::string("NCHW"));
+
+  for (const auto& input : inputs) {
+    op->SetInput(input.first, {input.second});
+  }
+  for (const auto& output : outputs) {
+    op->SetOutput(output.first, {output.second});
+  }
+
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
+  return op;
+}
+
+OpDesc* Create_Op_elemntwise_add(
+    ProgramDesc* prog, const std::string& op_type_name,
+    const std::vector<test::InOutVarNamePair>& inputs,
+    const std::vector<test::InOutVarNamePair>& outputs,
+    bool use_mkldnn = true) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  op->SetType(op_type_name);
+  op->SetAttr("use_mkldnn", use_mkldnn);
+  op->SetAttr("axis", -1);
+
+  for (const auto& input : inputs) {
+    op->SetInput(input.first, {input.second});
+  }
+  for (const auto& output : outputs) {
+    op->SetOutput(output.first, {output.second});
+  }
+
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
+  return op;
+}
+
 TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
   auto prog =
       test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
 
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d",
+  Create_Op_con2d(&prog, "conv2d",
                  {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
                  {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
                  {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
@@ -53,16 +105,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
   // right branch
-  test::CreateOp(&prog, "conv2d",
+  Create_Op_con2d(&prog, "conv2d",
                  {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
                  {{"Output", "c"}});
   // left branch
-  test::CreateOp(&prog, "conv2d",
+  Create_Op_con2d(&prog, "conv2d",
                  {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
                  {{"Output", "f"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}},
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}},
                  {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
@@ -80,9 +132,9 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
   auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
 
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
                  {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
                  {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
@@ -100,11 +152,11 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
       test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
 
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d",
+  Create_Op_con2d(&prog, "conv2d",
                  {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
                  {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
                  {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
@@ -122,9 +174,9 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
   auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
 
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
                  {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
                  {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
@@ -142,13 +194,13 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
       test::BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
 
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
                  {{"Output", "c"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
                  {{"Output", "e"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}},
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}},
                  {{"Out", "f"}});
   test::CreateOp(&prog, "relu", {{"X", "f"}}, {{"Out", "g"}});
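The tests replace generic test::CreateOp calls for conv2d and elementwise_add with helpers that set the full attribute set (strides, paddings, dilations, groups, padding_algorithm, data_format, axis), presumably so the test programs satisfy the compat checks the pass now enforces. The standalone sketch below models the shape of that refactoring with a stand-in FakeOpDesc; only the attribute names and default values are taken from the diff:

```cpp
#include <iostream>
#include <map>
#include <string>

// Stand-in for paddle's OpDesc, used only to make the example runnable.
struct FakeOpDesc {
  std::string type;
  std::map<std::string, std::string> attrs;  // attribute values as strings
};

// Centralizes the defaults Create_Op_con2d sets in the real tester, so each
// test no longer builds a bare conv2d that would fail the compat checks.
FakeOpDesc CreateConv2d(bool use_mkldnn = true) {
  FakeOpDesc op;
  op.type = "conv2d";
  op.attrs["use_mkldnn"] = use_mkldnn ? "true" : "false";
  op.attrs["strides"] = "[1, 1]";
  op.attrs["paddings"] = "[0, 0]";
  op.attrs["dilations"] = "[1, 1]";
  op.attrs["groups"] = "1";
  op.attrs["padding_algorithm"] = "EXPLICIT";
  op.attrs["data_format"] = "NCHW";
  return op;
}

int main() {
  FakeOpDesc op = CreateConv2d();
  for (const auto& kv : op.attrs)
    std::cout << kv.first << " = " << kv.second << "\n";
}
```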
paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -67,6 +67,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
       .AddAttr("paddings")
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -75,6 +76,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
       .AddAttr("dilations")
       .End()
       .AddAttr("data_format")
+      .IsOptional()
       .IsStringIn({"NCHW", "NHWC"})
       .End();
 }
paddle/fluid/operators/compat/elementwise_add.pbtxt
@@ -15,6 +15,10 @@ def {
   }
 }
 extra {
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
   attrs {
     name: "out_threshold"
     type: FLOAT