Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e5e0b726
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e5e0b726
编写于
4月 04, 2022
作者:
S
Sławomir Siwek
提交者:
GitHub
4月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
conv + elementwise_add refactor (#41286)
* DRY * change nodes names * add const prefix * change asX to as_x in all files
上级
75a17cdb
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
113 addition
and
295 deletion
+113
-295
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+23
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+16
-0
paddle/fluid/framework/ir/graph_traits.cc
paddle/fluid/framework/ir/graph_traits.cc
+48
-0
paddle/fluid/framework/ir/graph_traits.h
paddle/fluid/framework/ir/graph_traits.h
+3
-0
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
...mework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
+19
-147
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
...amework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
+3
-13
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py
...r/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py
+1
-135
未找到文件。
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
e5e0b726
...
...
@@ -2069,6 +2069,29 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
return
out_var
;
}
PDNode
*
patterns
::
ResidualElementwise
::
operator
()(
PDNode
*
op_var
,
PDNode
*
residual_var
,
const
std
::
string
elementwise_type
,
bool
as_x
)
{
auto
elementwise_op
=
pattern
->
NewNode
(
elementwise_op_repr
())
->
assert_is_op
(
elementwise_type
);
if
(
as_x
)
{
op_var
->
AsInput
()
->
assert_is_op_input
(
elementwise_type
,
"X"
);
residual_var
->
AsInput
()
->
assert_is_op_input
(
elementwise_type
,
"Y"
);
}
else
{
op_var
->
AsInput
()
->
assert_is_op_input
(
elementwise_type
,
"Y"
);
residual_var
->
AsInput
()
->
assert_is_op_input
(
elementwise_type
,
"X"
);
}
auto
out_var
=
pattern
->
NewNode
(
elementwise_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
elementwise_type
,
"Out"
);
elementwise_op
->
LinksFrom
({
op_var
,
residual_var
});
elementwise_op
->
LinksTo
({
out_var
});
return
out_var
;
}
PDNode
*
patterns
::
Concat
::
operator
()()
{
auto
concat_op
=
pattern
->
NewNode
(
concat_op_repr
())
->
assert_is_op
(
"concat"
);
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
e5e0b726
...
...
@@ -1032,6 +1032,22 @@ struct Elementwise : public PatternBase {
PATTERN_DECL_NODE
(
elementwise_out
);
};
// Residual Elementwise ops
// This pattern allows operator output to be X or Y
// and residual data Y or X, based on as_x flag
struct
ResidualElementwise
:
public
PatternBase
{
ResidualElementwise
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
bool
as_x
)
:
PatternBase
(
pattern
,
name_scope
,
"residual_elementwise"
)
{}
PDNode
*
operator
()(
PDNode
*
op_var
,
PDNode
*
residual_var
,
const
std
::
string
elementwise_type
,
bool
as_x
);
PATTERN_DECL_NODE
(
operator_output
);
PATTERN_DECL_NODE
(
residual_data
);
PATTERN_DECL_NODE
(
elementwise_op
);
PATTERN_DECL_NODE
(
elementwise_out
);
};
// Transpose op
// Forward pass for transpose.
// transpose_out is a result of the operator.
...
...
paddle/fluid/framework/ir/graph_traits.cc
浏览文件 @
e5e0b726
...
...
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <list>
#include <map>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace
paddle
{
...
...
@@ -23,6 +26,51 @@ namespace ir {
//
class
Node
;
bool
IsReachable
(
ir
::
Graph
*
graph
,
Node
*
from
,
Node
*
to
)
{
if
(
from
==
to
)
{
return
true
;
}
std
::
map
<
Node
*
,
bool
>
visited
;
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
visited
[
&
node
]
=
false
;
}
visited
[
from
]
=
true
;
std
::
list
<
Node
*>
queue
;
queue
.
push_back
(
from
);
while
(
!
queue
.
empty
())
{
auto
cur
=
FindNode
(
graph
,
queue
.
front
());
queue
.
pop_front
();
if
(
!
cur
)
return
false
;
for
(
const
auto
&
n
:
cur
->
outputs
)
{
if
(
n
==
to
)
{
return
true
;
}
if
(
!
visited
[
n
])
{
visited
[
n
]
=
true
;
queue
.
push_back
(
n
);
}
}
}
return
false
;
}
Node
*
FindNode
(
ir
::
Graph
*
graph
,
const
Node
*
node
)
{
for
(
const
auto
&
n
:
graph
->
Nodes
())
{
if
(
n
==
node
)
{
return
n
;
}
}
return
nullptr
;
}
NodesDFSIterator
::
NodesDFSIterator
(
const
std
::
vector
<
Node
*>
&
source
)
{
for
(
auto
*
x
:
source
)
stack_
.
push
(
x
);
}
...
...
paddle/fluid/framework/ir/graph_traits.h
浏览文件 @
e5e0b726
...
...
@@ -29,6 +29,9 @@ namespace ir {
class
Graph
;
class
Node
;
bool
IsReachable
(
ir
::
Graph
*
graph
,
Node
*
from
,
Node
*
to
);
Node
*
FindNode
(
ir
::
Graph
*
graph
,
const
Node
*
node
);
template
<
typename
IteratorT
>
class
iterator_range
{
IteratorT
begin_
,
end_
;
...
...
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
浏览文件 @
e5e0b726
...
...
@@ -14,12 +14,6 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
...
...
@@ -28,60 +22,6 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
bool
IsReachable
(
ir
::
Graph
*
graph
,
Node
*
from
,
Node
*
to
)
{
auto
find_node
=
[](
ir
::
Graph
*
graph
,
const
Node
*
node
)
->
Node
*
{
for
(
auto
n
:
graph
->
Nodes
())
{
if
(
n
==
node
)
{
return
n
;
}
}
return
nullptr
;
};
if
(
from
==
to
)
{
return
true
;
}
std
::
map
<
Node
*
,
bool
>
visited
;
for
(
auto
&
node
:
GraphTraits
::
DFS
(
*
graph
))
{
visited
[
&
node
]
=
false
;
}
visited
[
from
]
=
true
;
std
::
list
<
Node
*>
queue
;
queue
.
push_back
(
from
);
while
(
!
queue
.
empty
())
{
auto
cur
=
find_node
(
graph
,
queue
.
front
());
queue
.
pop_front
();
if
(
!
cur
)
return
false
;
for
(
auto
n
:
cur
->
outputs
)
{
if
(
n
==
to
)
{
return
true
;
}
if
(
!
visited
[
n
])
{
visited
[
n
]
=
true
;
queue
.
push_back
(
n
);
}
}
}
return
false
;
}
template
<
typename
T
>
paddle
::
optional
<
T
>
HasAttribute
(
const
Node
&
op
,
const
std
::
string
&
attr
)
{
if
(
op
.
Op
()
->
HasAttr
(
attr
))
return
BOOST_GET_CONST
(
T
,
op
.
Op
()
->
GetAttr
(
attr
));
else
return
paddle
::
none
;
}
ResidualConnectionMKLDNNFusePass
::
ResidualConnectionMKLDNNFusePass
()
{
AddOpCompat
(
OpCompat
(
"conv2d"
))
.
AddInput
(
"Input"
)
...
...
@@ -136,22 +76,22 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.
End
();
}
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseConv
AsX
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
{
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseConv
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
,
bool
as_x
)
const
{
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
patterns
::
Conv
conv_pattern
{
pattern
,
name_scope
};
auto
conv_output
=
conv_pattern
();
patterns
::
Elementwise
elementwise_pattern
{
pattern
,
name_scope
};
patterns
::
ResidualElementwise
elementwise_pattern
{
pattern
,
name_scope
,
as_x
};
elementwise_pattern
(
conv_output
,
pattern
->
NewNode
(
elementwise_pattern
.
elementwise_y
_repr
()),
"elementwise_add"
);
conv_output
,
pattern
->
NewNode
(
elementwise_pattern
.
residual_data
_repr
()),
"elementwise_add"
,
as_x
);
conv_output
->
AsIntermediate
();
int
found_conv_
as_x_
count
=
0
;
int
found_conv_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
...
...
@@ -162,15 +102,13 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_op
,
elementwise_op
,
elementwise_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_identity
,
elementwise_y
,
GET_IR_NODE_FROM_SUBGRAPH
(
residual_data
,
residual_data
,
elementwise_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_out
,
elementwise_out
,
elementwise_pattern
);
if
(
FindFuseOption
(
*
conv_op
,
*
elementwise_op
)
!=
FUSE_MKLDNN
)
return
;
if
(
!
IsReachable
(
g
,
elementwise_identity
,
conv_output
))
return
;
if
(
!
IsReachable
(
g
,
residual_data
,
conv_output
))
return
;
if
(
HasFusedActivation
(
conv_op
))
return
;
if
(
!
IsCompat
(
subgraph
,
g
))
{
...
...
@@ -179,95 +117,29 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
return
;
}
conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
elementwise_identity
->
Name
()});
conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
residual_data
->
Name
()});
conv_op
->
Op
()
->
SetOutput
(
"Output"
,
{
elementwise_out
->
Name
()});
conv_op
->
Op
()
->
SetAttr
(
"fuse_residual_connection"
,
true
);
GraphSafeRemoveNodes
(
g
,
{
conv_output
,
elementwise_op
});
IR_NODE_LINK_TO
(
elementwise_identity
,
conv_op
);
IR_NODE_LINK_TO
(
residual_data
,
conv_op
);
IR_NODE_LINK_TO
(
conv_op
,
elementwise_out
);
found_conv_
as_x_
count
++
;
found_conv_count
++
;
};
gpd
(
graph_with_stats
.
first
,
handler
);
if
(
!
Has
(
"disable_logs"
)
||
!
Get
<
bool
>
(
"disable_logs"
))
{
std
::
stringstream
msg_ss
;
msg_ss
<<
"--- Fused "
<<
found_conv_as_x_count
<<
" conv (as x) + elementwise_add patterns"
;
std
::
string
fusionMode
=
as_x
?
"x"
:
"y"
;
msg_ss
<<
"--- Fused "
<<
found_conv_count
<<
" conv (as "
<<
fusionMode
<<
") + elementwise_add patterns"
;
paddle
::
string
::
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
return
std
::
make_pair
(
graph_with_stats
.
first
,
found_conv_as_x_count
+
graph_with_stats
.
second
);
}
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseConvAsY
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
{
GraphPatternDetector
gpd
;
auto
pattern
=
gpd
.
mutable_pattern
();
patterns
::
Conv
conv_pattern
{
pattern
,
name_scope
};
auto
conv_output
=
conv_pattern
();
patterns
::
Elementwise
elementwise_pattern
{
pattern
,
name_scope
};
elementwise_pattern
(
pattern
->
NewNode
(
elementwise_pattern
.
elementwise_x_repr
()),
conv_output
,
"elementwise_add"
);
conv_output
->
AsIntermediate
();
int
found_conv_as_y_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
conv_op
,
conv_op
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_input
,
conv_input
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_filter
,
conv_filter
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
conv_output
,
conv_output
,
conv_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_op
,
elementwise_op
,
elementwise_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_x
,
elementwise_x
,
elementwise_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_out
,
elementwise_out
,
elementwise_pattern
);
if
(
FindFuseOption
(
*
conv_op
,
*
elementwise_op
)
!=
FUSE_MKLDNN
)
return
;
if
(
!
IsReachable
(
g
,
elementwise_x
,
conv_output
))
return
;
if
(
HasFusedActivation
(
conv_op
))
return
;
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"conv_elementwise_add_mkldnn_fuse_pass in op compat failed."
;
return
;
}
conv_op
->
Op
()
->
SetInput
(
"ResidualData"
,
{
elementwise_x
->
Name
()});
conv_op
->
Op
()
->
SetOutput
(
"Output"
,
{
elementwise_out
->
Name
()});
conv_op
->
Op
()
->
SetAttr
(
"fuse_residual_connection"
,
true
);
GraphSafeRemoveNodes
(
g
,
{
conv_output
,
elementwise_op
});
IR_NODE_LINK_TO
(
elementwise_x
,
conv_op
);
IR_NODE_LINK_TO
(
conv_op
,
elementwise_out
);
found_conv_as_y_count
++
;
};
gpd
(
graph_with_stats
.
first
,
handler
);
if
(
!
Has
(
"disable_logs"
)
||
!
Get
<
bool
>
(
"disable_logs"
))
{
std
::
stringstream
msg_ss
;
msg_ss
<<
"--- Fused "
<<
found_conv_as_y_count
<<
" conv (as y) + elementwise_add patterns"
;
paddle
::
string
::
PrettyLogDetail
(
msg_ss
.
str
().
c_str
());
}
return
std
::
make_pair
(
graph_with_stats
.
first
,
found_conv_as_y_count
+
graph_with_stats
.
second
);
found_conv_count
+
graph_with_stats
.
second
);
}
GraphWithStats
ResidualConnectionMKLDNNFusePass
::
FuseProjectionConv
(
...
...
@@ -308,7 +180,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if
(
!
IsCompat
(
subgraph
,
g
))
{
LOG
(
WARNING
)
<<
"
conv_elementwise_add_mkldnn_fuse_pass in op compat
failed."
;
<<
"
op compat for conv_elementwise_add_mkldnn_fuse_pass
failed."
;
return
;
}
...
...
@@ -361,8 +233,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase
::
Init
(
name_scope_
,
graph
);
auto
graph_with_stats
=
FuseProjectionConv
(
name_scope_
,
std
::
make_pair
(
graph
,
0
));
graph_with_stats
=
FuseConv
AsX
(
name_scope_
,
graph_with_stats
);
graph_with_stats
=
FuseConv
AsY
(
name_scope_
,
graph_with_stats
);
graph_with_stats
=
FuseConv
(
name_scope_
,
graph_with_stats
,
true
);
graph_with_stats
=
FuseConv
(
name_scope_
,
graph_with_stats
,
false
);
AddStatis
(
graph_with_stats
.
second
);
}
...
...
paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
浏览文件 @
e5e0b726
...
...
@@ -14,30 +14,20 @@
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include <boost/optional.hpp>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
using
GraphWithStats
=
std
::
pair
<
ir
::
Graph
*
,
int
>
;
bool
IsReachable
(
ir
::
Graph
*
graph
,
Node
*
from
,
Node
*
to
);
class
ResidualConnectionMKLDNNFusePass
:
public
FusePassBase
{
private:
GraphWithStats
FuseConvAsX
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
;
GraphWithStats
FuseConvAsY
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
;
GraphWithStats
FuseConv
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
,
bool
as_x
)
const
;
GraphWithStats
FuseProjectionConv
(
const
std
::
string
&
name_scope
,
const
GraphWithStats
&
graph_with_stats
)
const
;
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py
浏览文件 @
e5e0b726
...
...
@@ -26,7 +26,7 @@ import hypothesis.strategies as st
# the two inputs of elementwise_add are tensor
class
TestConvElementwiseAddMkldnnFusePass
1
(
PassAutoScanTest
):
class
TestConvElementwiseAddMkldnnFusePass
(
PassAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
attrs
=
[
program_config
.
ops
[
i
].
attrs
...
...
@@ -125,139 +125,5 @@ class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest):
quant
=
False
,
passes
=
[
"conv_elementwise_add_mkldnn_fuse_pass"
])
'''
class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
if "elementwise_weight" in program_config.weights:
if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[1]:
if attrs[2]['axis'] != 1:
return False
if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[3]:
if attrs[2]['axis'] != -1:
return False
return True
def sample_program_config(self, draw):
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.sampled_from([1, 2, 4]))
paddings = draw(st.sampled_from([[0, 3], [1, 1], [1, 2, 3, 4]]))
strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
axis = draw(st.sampled_from([-1, 0, 1]))
batch_size = draw(st.integers(min_value=1, max_value=4))
def generate_input1():
if data_format == "NCHW":
return np.random.random(
[batch_size, 48, 64, 64]).astype(np.float32)
else:
return np.random.random(
[batch_size, 64, 64, 48]).astype(np.float32)
def generate_weight1():
return np.random.random(
[48, int(48 / groups), 3, 3]).astype(np.float32)
def compute_out_shape(padding_alg):
import paddle
import paddle.nn as nn
x_var = paddle.uniform(
(batch_size, 48, 64, 64), dtype='float32', min=-1., max=1.)
if padding_alg == "EXPLICIT":
conv = nn.Conv2D(48, 48, (3, 3), strides, paddings, dilations,
1)
else:
conv = nn.Conv2D(48, 48, (3, 3), strides, padding_alg,
dilations, 1)
y_var = conv(x_var)
return y_var.shape
def generate_weight2():
return np.random.random([48]).astype(np.float32)
if compute_out_shape(padding_algorithm) != (batch_size, 48, 64, 64):
axis = 1
relu_op = OpConfig(
type="relu",
inputs={"X": ["input_data1"]},
outputs={"Out": ["sigmoid_out"]},
attrs={})
conv2d_op = OpConfig(
type="conv2d",
inputs={"Input": ["sigmoid_out"],
"Filter": ["conv_weight"]},
outputs={"Output": ["conv_output"]},
attrs={
"data_format": data_format,
"dilations": dilations,
"padding_algorithm": padding_algorithm,
"groups": groups,
"paddings": paddings,
"strides": strides
})
if axis == 0:
elt_op = OpConfig(
type="elementwise_add",
inputs={"X": ["input_data1"],
"Y": ["conv_output"]},
outputs={"Out": ["elementwise_output"]},
attrs={'axis': axis})
else:
elt_op = OpConfig(
type="elementwise_add",
inputs={"X": ["conv_output"],
"Y": ["elementwise_weight"]},
outputs={"Out": ["elementwise_output"]},
attrs={'axis': axis})
model_net = [relu_op, conv2d_op, elt_op]
if axis == 0:
program_config = ProgramConfig(
ops=model_net,
weights={
"conv_weight":
TensorConfig(data_gen=partial(generate_weight1))
},
inputs={
"input_data1":
TensorConfig(data_gen=partial(generate_input1))
},
outputs=["elementwise_output"])
else:
program_config = ProgramConfig(
ops=model_net,
weights={
"conv_weight":
TensorConfig(data_gen=partial(generate_weight1)),
"elementwise_weight":
TensorConfig(data_gen=partial(generate_weight2))
},
inputs={
"input_data1":
TensorConfig(data_gen=partial(generate_input1))
},
outputs=["elementwise_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["relu", "conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])
'''
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录