Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c6c65c65
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c6c65c65
编写于
4月 22, 2020
作者:
J
Jacek Czaja
提交者:
GitHub
4月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[DNNL] Added elementwise_add mkl-dnn inplace (#23477)
上级
9ff558a4
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
259 addition
and
111 deletion
+259
-111
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-1
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+10
-10
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+2
-1
paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
+106
-40
paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc
...e/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc
+18
-9
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
...operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
+19
-24
paddle/fluid/operators/mkldnn/inplace_op_tests.cmake
paddle/fluid/operators/mkldnn/inplace_op_tests.cmake
+1
-1
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+6
-5
paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc
paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc
+58
-17
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+5
-0
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+33
-3
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
c6c65c65
...
@@ -86,7 +86,7 @@ endif()
...
@@ -86,7 +86,7 @@ endif()
if
(
WITH_MKLDNN
)
if
(
WITH_MKLDNN
)
pass_library
(
mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn
)
pass_library
(
mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn
)
pass_library
(
mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry softmax_op softmax DIR mkldnn
)
pass_library
(
mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry
elementwise_add_op activation_op
softmax_op softmax DIR mkldnn
)
pass_library
(
depthwise_conv_mkldnn_pass base DIR mkldnn
)
pass_library
(
depthwise_conv_mkldnn_pass base DIR mkldnn
)
pass_library
(
conv_bias_mkldnn_fuse_pass inference DIR mkldnn
)
pass_library
(
conv_bias_mkldnn_fuse_pass inference DIR mkldnn
)
pass_library
(
conv_activation_mkldnn_fuse_pass inference DIR mkldnn
)
pass_library
(
conv_activation_mkldnn_fuse_pass inference DIR mkldnn
)
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
c6c65c65
...
@@ -1892,30 +1892,30 @@ PDNode *patterns::MultipleQuantize::operator()() {
...
@@ -1892,30 +1892,30 @@ PDNode *patterns::MultipleQuantize::operator()() {
}
}
PDNode
*
patterns
::
MKLDNNInPlace
::
operator
()()
{
PDNode
*
patterns
::
MKLDNNInPlace
::
operator
()()
{
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
// batch_norm....
auto
possible_inplace_op
=
auto
possible_inplace_op
=
pattern
->
NewNode
(
inplace_to_be_op_repr
())
->
assert_is_ops
({
"softmax"
});
pattern
->
NewNode
(
inplace_to_be_op_repr
())
->
assert_is_ops
({
"elementwise_add"
,
"softmax"
});
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
// batch_norm....
auto
input
=
pattern
->
NewNode
(
inplace_to_be_op_in_repr
())
auto
input
=
pattern
->
NewNode
(
inplace_to_be_op_in_repr
())
->
assert_is_ops_input
({
"softmax"
})
->
assert_is_ops_input
({
"
elementwise_add"
,
"
softmax"
})
->
AsInput
();
->
AsInput
();
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
// batch_norm....
auto
output
=
pattern
->
NewNode
(
inplace_to_be_op_out_repr
())
auto
output
=
pattern
->
NewNode
(
inplace_to_be_op_out_repr
())
->
assert_is_ops_output
({
"softmax"
})
->
assert_is_ops_output
({
"
elementwise_add"
,
"
softmax"
})
->
As
Intermediate
();
->
As
Output
();
auto
next_op
=
pattern
->
NewNode
(
next_op_repr
())
->
assert_is_op
();
auto
next_op
=
pattern
->
NewNode
(
next_op_repr
())
->
assert_is_op
();
auto
next_output
=
pattern
->
NewNode
(
next_op_out_repr
())
->
AsOutput
();
// Check if op is MKL-DNN enabled
// Check if op is MKL-DNN enabled
possible_inplace_op
->
assert_op_attr
(
"use_mkldnn"
,
true
);
possible_inplace_op
->
assert_op_attr
(
"use_mkldnn"
,
true
);
// linked structure
possible_inplace_op
->
LinksTo
({
output
});
possible_inplace_op
->
LinksTo
({
output
});
possible_inplace_op
->
LinksFrom
({
input
});
possible_inplace_op
->
LinksFrom
({
input
});
next_op
->
LinksFrom
({
output
});
next_op
->
LinksFrom
({
output
});
next_op
->
LinksTo
({
next_output
});
return
possible_inplace_op
;
return
possible_inplace_op
;
}
}
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
c6c65c65
...
@@ -1140,11 +1140,12 @@ struct MKLDNNInPlace : public PatternBase {
...
@@ -1140,11 +1140,12 @@ struct MKLDNNInPlace : public PatternBase {
:
PatternBase
(
pattern
,
name_scope
,
"mkldnn_inplace"
)
{}
:
PatternBase
(
pattern
,
name_scope
,
"mkldnn_inplace"
)
{}
PDNode
*
operator
()();
PDNode
*
operator
()();
// MKL-DNN's in-place ops: BatchNorm, Softmax,
Layer Norm
// MKL-DNN's in-place ops: BatchNorm, Softmax,
Elementwise_add
PATTERN_DECL_NODE
(
inplace_to_be_op
);
PATTERN_DECL_NODE
(
inplace_to_be_op
);
PATTERN_DECL_NODE
(
inplace_to_be_op_in
);
PATTERN_DECL_NODE
(
inplace_to_be_op_in
);
PATTERN_DECL_NODE
(
inplace_to_be_op_out
);
PATTERN_DECL_NODE
(
inplace_to_be_op_out
);
PATTERN_DECL_NODE
(
next_op
);
PATTERN_DECL_NODE
(
next_op
);
PATTERN_DECL_NODE
(
next_op_out
);
};
};
struct
TransposeFlattenConcat
:
public
PatternBase
{
struct
TransposeFlattenConcat
:
public
PatternBase
{
...
...
paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.cc
浏览文件 @
c6c65c65
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
@@ -30,6 +31,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -30,6 +31,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL
(
graph
,
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Pointer to graph argument should not be NULL."
));
"Pointer to graph argument should not be NULL."
));
std
::
unordered_map
<
std
::
string
,
std
::
string
>
original_output_names
;
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
patterns
::
MKLDNNInPlace
mkldnn_inplace
{
gpd
.
mutable_pattern
(),
patterns
::
MKLDNNInPlace
mkldnn_inplace
{
gpd
.
mutable_pattern
(),
"mkldnn_inplace"
};
"mkldnn_inplace"
};
...
@@ -40,72 +42,136 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
...
@@ -40,72 +42,136 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
Graph
*
g
)
{
Graph
*
g
)
{
VLOG
(
3
)
<<
"Start to handle MKL-DNN In-Place pass"
;
VLOG
(
3
)
<<
"Start to handle MKL-DNN In-Place pass"
;
GET_IR_NODE_FROM_SUBGRAPH
(
inplace_to_be_op
,
inplace_to_be_op
,
GET_IR_NODE_FROM_SUBGRAPH
(
current_op
,
inplace_to_be_op
,
mkldnn_inplace
);
GET_IR_NODE_FROM_SUBGRAPH
(
current_op_in
,
inplace_to_be_op_in
,
mkldnn_inplace
);
mkldnn_inplace
);
GET_IR_NODE_FROM_SUBGRAPH
(
inplace_to_be_op_in
,
inplace_to_be_op_in
,
GET_IR_NODE_FROM_SUBGRAPH
(
current_op_out
,
inplace_to_be_op_out
,
mkldnn_inplace
);
GET_IR_NODE_FROM_SUBGRAPH
(
inplace_to_be_op_out
,
inplace_to_be_op_out
,
mkldnn_inplace
);
mkldnn_inplace
);
GET_IR_NODE_FROM_SUBGRAPH
(
next_op
,
next_op
,
mkldnn_inplace
);
GET_IR_NODE_FROM_SUBGRAPH
(
next_op
,
next_op
,
mkldnn_inplace
);
GET_IR_NODE_FROM_SUBGRAPH
(
next_op_out
,
next_op_out
,
mkldnn_inplace
);
if
((
inplace_to_be_op
->
Op
()
->
HasAttr
(
"use_mkldnn"
)
==
false
)
||
if
((
current_op
->
Op
()
->
HasAttr
(
"use_mkldnn"
)
==
false
)
||
(
boost
::
get
<
bool
>
(
inplace_to_be_op
->
Op
()
->
GetAttr
(
"use_mkldnn"
))
==
(
boost
::
get
<
bool
>
(
current_op
->
Op
()
->
GetAttr
(
"use_mkldnn"
))
==
false
))
{
false
))
{
VLOG
(
3
)
<<
"do not perform mkl-dnn inplace: use_mkldnn missing or set to "
VLOG
(
3
)
<<
"do not perform mkl-dnn inplace: use_mkldnn missing or set to "
"false"
;
"false"
;
return
;
return
;
}
}
auto
&
infer_inplace
=
OpInfoMap
::
Instance
()
auto
&
infer_inplace
=
.
Get
(
inplace_to_be_op
->
Op
()
->
Type
())
OpInfoMap
::
Instance
().
Get
(
current_op
->
Op
()
->
Type
()).
infer_inplace_
;
.
infer_inplace_
;
if
(
!
infer_inplace
)
{
if
(
!
infer_inplace
)
{
VLOG
(
3
)
<<
"do not perform mkl-dnn inplace: missing InplaceInferer"
;
VLOG
(
3
)
<<
"do not perform mkl-dnn inplace: missing InplaceInferer"
;
return
;
return
;
}
}
// TODO(jczaja): Enable more ops
VLOG
(
3
)
<<
"DNNL Inplace op("
<<
current_op
->
id
()
<<
") "
if
(
inplace_to_be_op
->
Op
()
->
Type
()
!=
"softmax"
)
{
<<
"Curr Node In: "
<<
current_op_in
->
Name
()
VLOG
(
3
)
<<
" Curr Node out: "
<<
current_op_out
->
Name
();
<<
"Curently works for softmax only. TODO(jczaja): support other ops"
;
VLOG
(
3
)
<<
"DNNL Inplace next op("
<<
next_op
->
id
()
<<
") "
<<
" next Node out: "
<<
next_op_out
->
Name
();
auto
inputs
=
current_op
->
Op
()
->
Inputs
();
auto
outputs
=
current_op
->
Op
()
->
Outputs
();
auto
in_to_outs
=
infer_inplace
(
false
);
// strictly no CUDA for MKL-DNN
VLOG
(
3
)
<<
"DNNL InplaceInferer op("
<<
current_op
->
id
()
<<
") "
<<
in_to_outs
.
begin
()
->
first
<<
": "
<<
inputs
[
in_to_outs
.
begin
()
->
first
][
0
]
<<
" "
<<
in_to_outs
.
begin
()
->
second
<<
": "
<<
outputs
[
in_to_outs
.
begin
()
->
second
][
0
];
// If InferInplace pattern does not contain input node then skip
auto
inplace_input_vec
=
inputs
[
in_to_outs
.
begin
()
->
first
];
if
(
std
::
find
(
inplace_input_vec
.
begin
(),
inplace_input_vec
.
end
(),
current_op_in
->
Name
())
==
inplace_input_vec
.
end
())
{
VLOG
(
3
)
<<
"DNNL in-place pass SKIP pattern "
;
return
;
return
;
}
}
// Iterate over all nodes that are ops
// Checking if this particular node (to be inplaced, overwritten)
// and check if in-place to be var is part of inputs
// is used anywhere else apart from inplaced op
// if positive then do not perform inplace
auto
input_consumers
=
current_op_in
->
outputs
;
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
if
(
input_consumers
.
size
()
>
1
)
{
if
(
n
->
IsOp
())
{
VLOG
(
3
)
<<
"DNNL in-place pass FAIL: in-place var cannot "
// Avoid searchin in op that is to be inplace
"be an input to multiple operators"
;
if
((
n
->
id
()
!=
inplace_to_be_op
->
id
()))
{
return
;
auto
*
op
=
n
->
Op
();
}
auto
inputs
=
op
->
Inputs
();
auto
in_place_input
=
inplace_to_be_op_in
->
Name
();
// If this op was alrady inplaced in previous pass placements
for
(
auto
&
it
:
inputs
)
{
// then we need to update input of next op
for
(
auto
&
var_name
:
it
.
second
)
{
// but original name to be changed is gone, so we need to remember it
if
(
var_name
==
in_place_input
)
{
// on first time given op is to be inplaced
VLOG
(
3
)
<<
"MKL-DNN in-place pass: in-place var cannot be an "
if
(
current_op_in
->
Name
()
!=
current_op_out
->
Name
())
{
"input to more than one operator"
;
original_output_names
[
current_op
->
Name
()
+
current_op_in
->
Name
()]
=
return
;
current_op_out
->
Name
();
}
}
else
{
}
VLOG
(
3
)
<<
"DNNL Inplace: Current op already inplaced! "
;
}
// It may be that next op is reusing some of vars, we need to
// make sure that unwanted inplace is not created
// TODO(jczaja): Make UT for that one
for
(
auto
&
n
:
current_op_out
->
outputs
)
{
auto
&
n_op_infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
n
->
Op
()
->
Type
()).
infer_inplace_
;
if
((
n_op_infer_inplace
==
nullptr
))
{
for
(
auto
&
m
:
n
->
outputs
)
{
if
(
m
->
Name
()
==
current_op_in
->
Name
())
{
VLOG
(
3
)
<<
"DNNL in-place pass FAIL: in-place var cannot "
"be an output to non-inplaced next op"
;
return
;
}
}
}
}
}
}
}
}
auto
original_name
=
inplace_to_be_op_out
->
Name
();
auto
original_name
=
inplace_to_be_op_out
->
RenameVar
(
inplace_to_be_op_in
->
Name
());
original_output_names
[
current_op
->
Name
()
+
current_op_in
->
Name
()];
current_op_out
->
RenameVar
(
current_op_in
->
Name
());
// Get mapping of input to output
// Get mapping of input to output
auto
in_to_outs
=
infer_inplace
(
false
);
// strictly no CUDA for MKL-DNN
// TODO(jczaja): Support more complex situations
auto
out_name
=
in_to_outs
.
begin
()
->
second
;
auto
out_name
=
in_to_outs
.
begin
()
->
second
;
inplace_to_be_op
->
Op
()
->
SetOutput
(
current_op
->
Op
()
->
SetOutput
(
out_name
,
std
::
vector
<
std
::
string
>
({
inplace_to_be_op_out
->
Name
()}));
out_name
,
std
::
vector
<
std
::
string
>
({
current_op_out
->
Name
()}));
next_op
->
Op
()
->
RenameInput
(
original_name
,
inplace_to_be_op_out
->
Name
());
// If next op in a line is doing inplace
// then we need to update its output as well
// Get inferer of next op
// If no inferer then we are done
auto
&
next_op_infer_inplace
=
OpInfoMap
::
Instance
().
Get
(
next_op
->
Op
()
->
Type
()).
infer_inplace_
;
if
(
next_op_infer_inplace
)
{
auto
in_to_outs
=
next_op_infer_inplace
(
false
);
auto
out_name
=
in_to_outs
.
begin
()
->
second
;
auto
*
op
=
next_op
->
Op
();
auto
inputs
=
op
->
Inputs
();
auto
outputs
=
op
->
Outputs
();
// Check if in-place happened
// for variable we changed (original name)
// TODO(jczaja): make recursive propagation of inplace
auto
next_op_inplace_inputs
=
inputs
[
in_to_outs
.
begin
()
->
first
];
if
((
next_op_inplace_inputs
==
outputs
[
in_to_outs
.
begin
()
->
second
])
&&
(
std
::
find
(
next_op_inplace_inputs
.
begin
(),
next_op_inplace_inputs
.
end
(),
original_name
)
!=
next_op_inplace_inputs
.
end
()))
{
VLOG
(
3
)
<<
"DNNL InPlace: Next Op is in-placed , updating its "
"input "
"and output var!"
;
next_op
->
Op
()
->
SetOutput
(
out_name
,
std
::
vector
<
std
::
string
>
({
current_op_out
->
Name
()}));
next_op_out
->
RenameVar
(
current_op_in
->
Name
());
// Get ops that next_op_out is linked to and update their input
auto
next_op_out_consumers
=
next_op_out
->
outputs
;
// Has to be ops
for
(
auto
&
c
:
next_op_out_consumers
)
{
c
->
Op
()
->
RenameInput
(
original_name
,
current_op_out
->
Name
());
}
}
}
next_op
->
Op
()
->
RenameInput
(
original_name
,
current_op_out
->
Name
());
found_inplace_count
++
;
found_inplace_count
++
;
VLOG
(
3
)
<<
"
MKL-DNN
InPlace applied!"
;
VLOG
(
3
)
<<
"
DNNL
InPlace applied!"
;
};
};
gpd
(
graph
,
handler
);
gpd
(
graph
,
handler
);
...
...
paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass_tester.cc
浏览文件 @
c6c65c65
...
@@ -21,6 +21,9 @@
...
@@ -21,6 +21,9 @@
USE_OP
(
softmax
);
USE_OP
(
softmax
);
USE_OP_DEVICE_KERNEL
(
softmax
,
MKLDNN
);
USE_OP_DEVICE_KERNEL
(
softmax
,
MKLDNN
);
USE_OP
(
elementwise_add
);
USE_OP_DEVICE_KERNEL
(
elementwise_add
,
MKLDNN
);
USE_OP
(
relu
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -62,8 +65,9 @@ class MKLDNNInplacePassTest {
...
@@ -62,8 +65,9 @@ class MKLDNNInplacePassTest {
bool
branched
)
{
bool
branched
)
{
ProgramDesc
prog
;
ProgramDesc
prog
;
for
(
auto
&
v
:
std
::
vector
<
std
::
string
>
(
for
(
auto
&
v
:
{
"a"
,
"weights"
,
"bias"
,
"f"
,
"g"
,
"h"
,
"i"
,
"j"
,
"k"
}))
{
std
::
vector
<
std
::
string
>
({
"a"
,
"weights"
,
"bias"
,
"f"
,
"g"
,
"h"
,
"i"
,
"j"
,
"k"
,
"l"
,
"m"
,
"z"
}))
{
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
auto
*
var
=
prog
.
MutableBlock
(
0
)
->
Var
(
v
);
var
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
var
->
SetType
(
proto
::
VarType
::
SELECTED_ROWS
);
if
(
v
==
"weights"
||
v
==
"bias"
)
{
if
(
v
==
"weights"
||
v
==
"bias"
)
{
...
@@ -83,9 +87,12 @@ class MKLDNNInplacePassTest {
...
@@ -83,9 +87,12 @@ class MKLDNNInplacePassTest {
SetOp
(
&
prog
,
"elementwise_add"
,
"elementwise_add1"
,
SetOp
(
&
prog
,
"elementwise_add"
,
"elementwise_add1"
,
std
::
vector
<
std
::
string
>
({
"h"
,
"i"
}),
std
::
vector
<
std
::
string
>
({
"j"
}),
std
::
vector
<
std
::
string
>
({
"h"
,
"i"
}),
std
::
vector
<
std
::
string
>
({
"j"
}),
mkldnn_enabled_op
.
compare
(
"elementwise_add"
)
==
0
);
mkldnn_enabled_op
.
compare
(
"elementwise_add"
)
==
0
);
SetOp
(
&
prog
,
"relu"
,
"relu2"
,
std
::
vector
<
std
::
string
>
({
"j"
}),
std
::
vector
<
std
::
string
>
({
"k"
}),
mkldnn_enabled_op
.
compare
(
"softmax"
)
==
0
);
if
(
branched
==
true
)
{
if
(
branched
==
true
)
{
SetOp
(
&
prog
,
"softmax"
,
"softmax2"
,
std
::
vector
<
std
::
string
>
({
"g"
}),
SetOp
(
&
prog
,
"softmax"
,
"softmax2"
,
std
::
vector
<
std
::
string
>
({
"g"
}),
std
::
vector
<
std
::
string
>
({
"
k
"
}),
std
::
vector
<
std
::
string
>
({
"
z
"
}),
mkldnn_enabled_op
.
compare
(
"softmax"
)
==
0
);
mkldnn_enabled_op
.
compare
(
"softmax"
)
==
0
);
}
}
...
@@ -105,12 +112,11 @@ class MKLDNNInplacePassTest {
...
@@ -105,12 +112,11 @@ class MKLDNNInplacePassTest {
unsigned
use_mkldnn_true_count
=
0
;
unsigned
use_mkldnn_true_count
=
0
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
input_names
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
input_names
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
output_names
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
output_names
;
input_names
[
"softmax"
]
=
"X"
;
input_names
[
"softmax"
]
=
"X"
;
output_names
[
"softmax"
]
=
"Out"
;
output_names
[
"softmax"
]
=
"Out"
;
input_names
[
"batch_norm"
]
=
"X"
;
input_names
[
"elementwise_add"
]
=
"X"
;
output_names
[
"batch_norm"
]
=
"Y"
;
output_names
[
"elementwise_add"
]
=
"Out"
;
input_names
[
"layer_norm"
]
=
"X"
;
output_names
[
"layer_norm"
]
=
"Y"
;
VLOG
(
3
)
<<
DebugString
(
graph
);
VLOG
(
3
)
<<
DebugString
(
graph
);
...
@@ -135,15 +141,18 @@ class MKLDNNInplacePassTest {
...
@@ -135,15 +141,18 @@ class MKLDNNInplacePassTest {
TEST
(
MKLDNNInplacePass
,
inplace_softmax
)
{
TEST
(
MKLDNNInplacePass
,
inplace_softmax
)
{
// softmax to be mkl-dnn enabled and made in-place
// softmax to be mkl-dnn enabled and made in-place
MKLDNNInplacePassTest
().
MainTest
(
"softmax"
,
false
,
1
);
MKLDNNInplacePassTest
().
MainTest
(
"softmax"
,
false
,
1
);
}
}
TEST
(
MKLDNNInplacePass
,
inplace_softmax_branched
)
{
TEST
(
MKLDNNInplacePass
,
inplace_softmax_branched
)
{
// softmax
to be mkl-dnn enabled and made
in-place
// softmax
's input is shared by two branches. so no
in-place
MKLDNNInplacePassTest
().
MainTest
(
"softmax"
,
true
,
0
);
MKLDNNInplacePassTest
().
MainTest
(
"softmax"
,
true
,
0
);
}
}
TEST
(
MKLDNNInplacePass
,
inplace_elementwise_add
)
{
// Two elementwise_add mkl-dnn enabled op instances to be made inplace
MKLDNNInplacePassTest
().
MainTest
(
"elementwise_add"
,
false
,
1
);
}
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
浏览文件 @
c6c65c65
...
@@ -56,39 +56,34 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -56,39 +56,34 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
y
->
format
(),
MKLDNNMemoryFormat
::
undef
,
y
->
format
(),
MKLDNNMemoryFormat
::
undef
,
platform
::
errors
::
InvalidArgument
(
"Wrong format set for Y tensor"
));
platform
::
errors
::
InvalidArgument
(
"Wrong format set for Y tensor"
));
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
y_data
=
y
->
data
<
T
>
();
auto
src_x_tz
=
framework
::
vectorize
<
int64_t
>
(
x
->
dims
());
auto
src_x_tz
=
framework
::
vectorize
<
int64_t
>
(
x
->
dims
());
auto
src_y_tz
=
framework
::
vectorize
<
int64_t
>
(
y
->
dims
());
auto
src_y_tz
=
framework
::
vectorize
<
int64_t
>
(
y
->
dims
());
auto
dst_tz
=
framework
::
vectorize
<
int64_t
>
(
z
->
dims
());
auto
dst_tz
=
framework
::
vectorize
<
int64_t
>
(
z
->
dims
());
std
::
vector
<
float
>
scales
=
{
1.0
f
,
1.0
f
};
// Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
// TODO(jczaja): Binary primitive support broadcasting, so we can support
// this in kernel
platform
::
BinaryMKLDNNHandler
<
T
>
handler
(
dnnl
::
algorithm
::
binary_add
,
src_x_tz
,
x
->
format
(),
y
->
format
(),
dev_ctx
,
ctx
.
GetPlace
(),
ctx
.
OutputName
(
"Out"
));
const
std
::
string
key
=
auto
src_x_memory
=
handler
.
AcquireSrcMemory
(
x
);
platform
::
CreateKey
(
src_x_tz
,
ctx
.
OutputName
(
"Out"
)
);
auto
src_y_memory
=
handler
.
AcquireSecondSrcMemory
(
y
);
platform
::
SumMKLDNNHandler
handler
(
dev_ctx
,
mkldnn_engine
,
key
);
// For Inplace src and and dst are the same memory object
auto
dst_memory
=
x
->
IsSharedBufferWith
(
*
z
)
?
src_x_memory
:
handler
.
AcquireDstMemory
(
z
);
auto
src_x_memory
=
handler
.
AcquireSrcMemory
(
auto
binary_prim
=
handler
.
AcquireForwardPrimitive
();
{{
src_x_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
x
->
format
()},
paddle
::
platform
::
to_void_cast
(
x_data
));
auto
src_y_memory
=
handler
.
AcquireSecondSrcMemory
(
{{
src_y_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
y
->
format
()},
paddle
::
platform
::
to_void_cast
(
y_data
));
auto
dst_md
=
memory
::
desc
({
dst_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
MKLDNNMemoryFormat
::
any
);
auto
sum_pd
=
handler
.
AcquireSumPrimitiveDescriptor
(
{
src_x_memory
,
src_y_memory
},
scales
,
dst_md
);
T
*
z_data
=
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
(),
sum_pd
->
dst_desc
().
get_size
());
auto
dst_memory
=
handler
.
AcquireDstMemoryFromPrimitive
(
z_data
);
auto
sum_prim
=
handler
.
AcquireSum
();
mkldnn
::
stream
astream
(
mkldnn_engine
);
mkldnn
::
stream
astream
(
mkldnn_engine
);
sum_prim
->
execute
(
astream
,
{{
MKLDNN_ARG_MULTIPLE_SRC
,
*
src_x_memory
},
{
MKLDNN_ARG_MULTIPLE_SRC
+
1
,
*
src_y_memory
},
std
::
unordered_map
<
int
,
dnnl
::
memory
>
args
=
{
{
MKLDNN_ARG_DST
,
*
dst_memory
}});
{
DNNL_ARG_SRC_0
,
*
src_x_memory
},
{
DNNL_ARG_SRC_1
,
*
src_y_memory
},
{
DNNL_ARG_DST
,
*
dst_memory
}};
binary_prim
->
execute
(
astream
,
args
);
astream
.
wait
();
astream
.
wait
();
z
->
set_layout
(
DataLayout
::
kMKLDNN
);
z
->
set_layout
(
DataLayout
::
kMKLDNN
);
...
...
paddle/fluid/operators/mkldnn/inplace_op_tests.cmake
浏览文件 @
c6c65c65
cc_test
(
test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry softmax_op softmax scope device_context enforce executor
)
cc_test
(
test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry
elementwise_add_op
softmax_op softmax scope device_context enforce executor
)
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
c6c65c65
...
@@ -45,7 +45,8 @@ class SoftmaxMKLDNNHandler
...
@@ -45,7 +45,8 @@ class SoftmaxMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
(
mkldnn
::
softmax_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
uniq_name
))
{
// Softmax may be inplace then uniq_name is no longer unique
platform
::
CreateKey
(
dims
,
axis
,
uniq_name
))
{
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
this
->
AcquireForwardPrimitiveDescriptor
(
prop_kind
::
forward_scoring
,
md
,
this
->
AcquireForwardPrimitiveDescriptor
(
prop_kind
::
forward_scoring
,
md
,
...
@@ -60,7 +61,7 @@ class SoftmaxMKLDNNHandler
...
@@ -60,7 +61,7 @@ class SoftmaxMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
(
mkldnn
::
softmax_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
uniq_name
))
{
platform
::
CreateKey
(
dims
,
axis
,
uniq_name
))
{
auto
data_softmax_md
=
auto
data_softmax_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
diff_softmax_md
=
auto
diff_softmax_md
=
...
@@ -95,13 +96,13 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
...
@@ -95,13 +96,13 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto
softmax_src_memory_p
=
handler
.
AcquireSrcMemory
(
input
);
auto
softmax_src_memory_p
=
handler
.
AcquireSrcMemory
(
input
);
auto
softmax_p
=
handler
.
AcquireForwardPrimitive
();
auto
softmax_p
=
handler
.
AcquireForwardPrimitive
();
// For Inplace src and and dst are the same memory object
// For Inplace src and and dst are the same memory object
auto
softmax_dst_memory_p
=
input
->
Holder
()
==
output
->
Holder
(
)
auto
softmax_dst_memory_p
=
input
->
IsSharedBufferWith
(
*
output
)
?
softmax_src_memory_p
?
softmax_src_memory_p
:
handler
.
AcquireDstMemory
(
output
);
:
handler
.
AcquireDstMemory
(
output
);
mkldnn
::
stream
astream
(
dev_ctx
.
GetEngine
());
mkldnn
::
stream
astream
(
dev_ctx
.
GetEngine
());
softmax_p
->
execute
(
astream
,
{{
MKLDNN
_ARG_SRC
,
*
softmax_src_memory_p
},
softmax_p
->
execute
(
astream
,
{{
DNNL
_ARG_SRC
,
*
softmax_src_memory_p
},
{
MKLDNN
_ARG_DST
,
*
softmax_dst_memory_p
}});
{
DNNL
_ARG_DST
,
*
softmax_dst_memory_p
}});
astream
.
wait
();
astream
.
wait
();
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
...
...
paddle/fluid/operators/mkldnn/test_mkldnn_op_inplace.cc
浏览文件 @
c6c65c65
...
@@ -27,38 +27,68 @@
...
@@ -27,38 +27,68 @@
USE_OP
(
softmax
);
USE_OP
(
softmax
);
USE_OP_DEVICE_KERNEL
(
softmax
,
MKLDNN
);
USE_OP_DEVICE_KERNEL
(
softmax
,
MKLDNN
);
USE_OP
(
elementwise_add
);
USE_OP_DEVICE_KERNEL
(
elementwise_add
,
MKLDNN
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
struct
InputVars
{
std
::
string
name
;
framework
::
LoDTensor
*
tensor
;
};
template
<
typename
T
>
template
<
typename
T
>
bool
TestMain
(
const
platform
::
Place
&
place
,
const
framework
::
DDim
&
dims
)
{
bool
TestMain
(
const
platform
::
Place
&
place
,
const
std
::
string
&
op_type
,
const
framework
::
DDim
&
dims
,
const
int
num_inputs
)
{
framework
::
Scope
scope
;
framework
::
Scope
scope
;
auto
*
x
=
scope
.
Var
(
"x"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
y
=
scope
.
Var
(
"y"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
x
->
Resize
(
dims
);
std
::
vector
<
InputVars
>
input_names
=
{
y
->
Resize
(
dims
);
{
"x"
,
scope
.
Var
(
"x"
)
->
GetMutable
<
framework
::
LoDTensor
>
()},
{
"x1"
,
num_inputs
>
1
size_t
numel
=
static_cast
<
size_t
>
(
framework
::
product
(
dims
));
?
scope
.
Var
(
"x1"
)
->
GetMutable
<
framework
::
LoDTensor
>
()
:
nullptr
},
auto
x_ptr
=
x
->
mutable_data
<
T
>
(
place
);
{
"x2"
,
num_inputs
>
2
auto
y_ptr
=
y
->
mutable_data
<
T
>
(
place
);
?
scope
.
Var
(
"x2"
)
->
GetMutable
<
framework
::
LoDTensor
>
()
:
nullptr
},
{
"x3"
,
num_inputs
>
3
?
scope
.
Var
(
"x3"
)
->
GetMutable
<
framework
::
LoDTensor
>
()
:
nullptr
},
{
"x4"
,
num_inputs
>
4
?
scope
.
Var
(
"x4"
)
->
GetMutable
<
framework
::
LoDTensor
>
()
:
nullptr
}};
auto
*
y
=
scope
.
Var
(
"y"
)
->
GetMutable
<
framework
::
LoDTensor
>
();
// Initialize input data
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
10.0
),
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
10.0
),
static_cast
<
T
>
(
20.0
));
static_cast
<
T
>
(
20.0
));
std
::
mt19937
engine
;
std
::
mt19937
engine
;
size_t
numel
=
static_cast
<
size_t
>
(
framework
::
product
(
dims
));
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
input_names
[
i
].
tensor
->
Resize
(
dims
);
auto
data_ptr
=
input_names
[
i
].
tensor
->
mutable_data
<
T
>
(
place
);
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
data_ptr
[
i
]
=
dist
(
engine
);
}
}
// Initialize output
y
->
Resize
(
dims
);
auto
y_ptr
=
y
->
mutable_data
<
T
>
(
place
);
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
x_ptr
[
i
]
=
dist
(
engine
);
y_ptr
[
i
]
=
static_cast
<
T
>
(
0
);
y_ptr
[
i
]
=
static_cast
<
T
>
(
0
);
}
}
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
// Out of place (reference) computation
// Out of place (reference) computation
auto
op_ref
=
framework
::
OpRegistry
::
CreateOp
(
auto
op_ref
=
num_inputs
>
1
?
framework
::
OpRegistry
::
CreateOp
(
"softmax"
,
{{
"X"
,
{
"x"
}}},
{{
"Out"
,
{
"y"
}}},
{{
"use_mkldnn"
,
{
true
}}});
op_type
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"x1"
}}},
{{
"Out"
,
{
"y"
}}},
{{
"use_mkldnn"
,
{
true
}}})
:
framework
::
OpRegistry
::
CreateOp
(
op_type
,
{{
"X"
,
{
"x"
}}},
{{
"Out"
,
{
"y"
}}},
{{
"use_mkldnn"
,
{
true
}}});
op_ref
->
Run
(
scope
,
place
);
op_ref
->
Run
(
scope
,
place
);
pool
.
Get
(
place
)
->
Wait
();
pool
.
Get
(
place
)
->
Wait
();
...
@@ -66,15 +96,20 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
...
@@ -66,15 +96,20 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
auto
&
ref_tensor
=
scope
.
FindVar
(
"y"
)
->
Get
<
framework
::
LoDTensor
>
();
auto
&
ref_tensor
=
scope
.
FindVar
(
"y"
)
->
Get
<
framework
::
LoDTensor
>
();
// In-place (to be tested) computation
// In-place (to be tested) computation
auto
op
=
framework
::
OpRegistry
::
CreateOp
(
auto
op
=
num_inputs
>
1
?
framework
::
OpRegistry
::
CreateOp
(
"softmax"
,
{{
"X"
,
{
"x"
}}},
{{
"Out"
,
{
"x"
}}},
{{
"use_mkldnn"
,
{
true
}}});
op_type
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"x1"
}}},
{{
"Out"
,
{
"x"
}}},
{{
"use_mkldnn"
,
{
true
}}})
:
framework
::
OpRegistry
::
CreateOp
(
op_type
,
{{
"X"
,
{
"x"
}}},
{{
"Out"
,
{
"x"
}}},
{{
"use_mkldnn"
,
{
true
}}});
op
->
Run
(
scope
,
place
);
op
->
Run
(
scope
,
place
);
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
)
->
Wait
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
)
->
Wait
();
// Get in-place result
// Get in-place result
auto
&
out_tensor
=
scope
.
FindVar
(
"x"
)
->
Get
<
framework
::
LoDTensor
>
();
auto
&
out_tensor
=
scope
.
FindVar
(
"x"
)
->
Get
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
&
out_tensor
,
x
,
&
out_tensor
,
input_names
[
0
].
tensor
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Input and output vars should share tensor for In-place test"
));
"Input and output vars should share tensor for In-place test"
));
...
@@ -88,7 +123,13 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
...
@@ -88,7 +123,13 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
TEST
(
test_softmax_inplace
,
cpu_place
)
{
TEST
(
test_softmax_inplace
,
cpu_place
)
{
framework
::
DDim
dims
({
32
,
64
});
framework
::
DDim
dims
({
32
,
64
});
platform
::
CPUPlace
p
;
platform
::
CPUPlace
p
;
ASSERT_TRUE
(
TestMain
<
float
>
(
p
,
dims
));
ASSERT_TRUE
(
TestMain
<
float
>
(
p
,
"softmax"
,
dims
,
1
));
}
TEST
(
test_elementwise_add_inplace
,
cpu_place
)
{
framework
::
DDim
dims
({
1
,
12
,
20
,
20
});
platform
::
CPUPlace
p
;
ASSERT_TRUE
(
TestMain
<
float
>
(
p
,
"elementwise_add"
,
dims
,
2
));
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
c6c65c65
...
@@ -101,6 +101,11 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
...
@@ -101,6 +101,11 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
}
}
}
}
struct
mkldnn_dummy_primitive
{
struct
primitive_desc
{};
struct
desc
{};
};
inline
mkldnn
::
memory
::
desc
MKLDNNMemDesc
(
const
std
::
vector
<
int64_t
>&
dims
,
inline
mkldnn
::
memory
::
desc
MKLDNNMemDesc
(
const
std
::
vector
<
int64_t
>&
dims
,
mkldnn
::
memory
::
data_type
data_type
,
mkldnn
::
memory
::
data_type
data_type
,
MKLDNNMemoryFormat
format
)
{
MKLDNNMemoryFormat
format
)
{
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
c6c65c65
...
@@ -30,7 +30,8 @@ namespace platform {
...
@@ -30,7 +30,8 @@ namespace platform {
using
user_function
=
std
::
function
<
std
::
shared_ptr
<
float
>
(
const
float
*
)
>
;
using
user_function
=
std
::
function
<
std
::
shared_ptr
<
float
>
(
const
float
*
)
>
;
using
memory
=
mkldnn
::
memory
;
using
memory
=
mkldnn
::
memory
;
template
<
typename
T
,
typename
TForward
,
typename
TBackward
>
template
<
typename
T
,
typename
TForward
,
typename
TBackward
=
mkldnn_dummy_primitive
>
class
MKLDNNHandlerT
{
class
MKLDNNHandlerT
{
public:
public:
MKLDNNHandlerT
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
MKLDNNHandlerT
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
...
@@ -351,6 +352,35 @@ class MKLDNNHandler {
...
@@ -351,6 +352,35 @@ class MKLDNNHandler {
std
::
string
key_common_
;
std
::
string
key_common_
;
};
};
template
<
typename
T
>
class
BinaryMKLDNNHandler
:
public
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
binary
>
{
public:
BinaryMKLDNNHandler
(
const
dnnl
::
algorithm
algo
,
const
std
::
vector
<
int64_t
>&
dims
,
const
MKLDNNMemoryFormat
src0_fmt
,
const
MKLDNNMemoryFormat
src1_fmt
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
binary
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
uniq_name
))
{
// TODO(jczaja): Add function checking if data already exists
auto
src0_md
=
dnnl
::
memory
::
desc
(
dims
,
MKLDNNGetDataType
<
T
>
(),
src0_fmt
);
auto
src1_md
=
dnnl
::
memory
::
desc
(
dims
,
MKLDNNGetDataType
<
T
>
(),
src1_fmt
);
auto
dst_md
=
memory
::
desc
(
dims
,
MKLDNNGetDataType
<
T
>
(),
MKLDNNMemoryFormat
::
any
);
this
->
AcquireForwardPrimitiveDescriptor
(
algo
,
src0_md
,
src1_md
,
dst_md
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSecondSrcMemory
(
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
src_desc
(),
to_void_cast
<
T
>
(
input_data
),
"@src1_mem_p"
);
}
};
class
SumMKLDNNHandler
:
public
MKLDNNHandler
{
class
SumMKLDNNHandler
:
public
MKLDNNHandler
{
public:
public:
SumMKLDNNHandler
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
SumMKLDNNHandler
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
...
@@ -419,7 +449,7 @@ class ActivationMKLDNNHandler
...
@@ -419,7 +449,7 @@ class ActivationMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
(
mkldnn
::
eltwise_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
unique_name
))
{
platform
::
CreateKey
(
dims
,
"a"
,
algorithm
,
unique_name
))
{
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
...
@@ -437,7 +467,7 @@ class ActivationMKLDNNHandler
...
@@ -437,7 +467,7 @@ class ActivationMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
(
mkldnn
::
eltwise_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
unique_name
))
{
platform
::
CreateKey
(
dims
,
"a"
,
algorithm
,
unique_name
))
{
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
auto
src_md
=
auto
src_md
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录