Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ba90e052
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2297
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
ba90e052
编写于
2月 27, 2019
作者:
T
Tao Luo
提交者:
GitHub
2月 27, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15917 from jczaja/prv-tensor-mkldnn-ops
[MKL-DNN] Adjusting ops to Tensor modifications
上级
7d8f6398
c63f6b20
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
48 addition
and
86 deletion
+48
-86
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
...operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
+6
-13
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+7
-17
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+11
-25
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
+1
-7
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+2
-4
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
+1
-2
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+8
-13
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+8
-0
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+4
-5
未找到文件。
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -77,8 +77,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
}
else
{
functor
.
RunMidWise
(
n
,
pre
,
post
);
}
z
->
set_layout
(
DataLayout
::
kMKLDNN
);
z
->
set_format
(
x
->
format
());
z
->
set_mkldnn_prim_desc
(
x
->
get_mkldnn_prim_desc
());
}
else
{
PADDLE_ENFORCE
(
x
->
layout
()
==
DataLayout
::
kMKLDNN
&&
x
->
format
()
!=
memory
::
format
::
format_undef
,
...
...
@@ -116,7 +115,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto
sum_pd
=
sum
::
primitive_desc
(
dst_md
,
scales
,
srcs_pd
);
// create mkldnn memory for dst
memory
dst_memory
=
memory
(
sum_pd
.
dst_primitive_desc
(),
z_data
);
auto
dst_mem_pd
=
sum_pd
.
dst_primitive_desc
();
memory
dst_memory
=
memory
(
dst_mem_pd
,
z_data
);
std
::
vector
<
primitive
::
at
>
inputs
;
inputs
.
push_back
(
srcs
[
0
]);
...
...
@@ -129,9 +129,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
pipeline
.
push_back
(
sum_prim
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
z
->
set_layout
(
DataLayout
::
kMKLDNN
);
z
->
set_format
(
(
memory
::
format
)
dst_memory
.
get_primitive_desc
().
desc
().
data
.
format
);
z
->
set_mkldnn_prim_desc
(
dst_mem_pd
);
}
}
};
...
...
@@ -152,24 +150,19 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto
*
out
=
dout
;
auto
*
x
=
dout
,
*
y
=
dout
;
auto
set_mkldnn_format
=
[](
Tensor
*
in
,
const
Tensor
*
out
)
{
in
->
set_layout
(
DataLayout
::
kMKLDNN
);
in
->
set_format
(
out
->
format
());
};
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
dx
->
dims
()
==
dy
->
dims
())
{
if
(
dx
->
dims
()
==
dy
->
dims
())
{
auto
blas
=
math
::
GetBlas
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
if
(
dx
)
{
blas
.
VCOPY
(
dout
->
numel
(),
dout
->
data
<
T
>
(),
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
set_mkldnn_format
(
dx
,
dout
);
dx
->
set_mkldnn_prim_desc
(
dout
->
get_mkldnn_prim_desc
()
);
}
if
(
dy
)
{
blas
.
VCOPY
(
dout
->
numel
(),
dout
->
data
<
T
>
(),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
set_mkldnn_format
(
dy
,
dout
);
dy
->
set_mkldnn_prim_desc
(
dout
->
get_mkldnn_prim_desc
()
);
}
}
}
else
{
...
...
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -96,8 +96,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
x
->
dims
());
auto
src_format
=
src_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
x
->
format
();
auto
src_format
=
x
->
format
();
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_data
=
...
...
@@ -127,10 +126,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
if
(
p_fwd
==
nullptr
)
{
// create mkldnn memory for input X
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
(
{
src_md
,
mkldnn_engine
}
,
to_void_cast
(
x_data
)));
new
memory
(
x
->
get_mkldnn_prim_desc
()
,
to_void_cast
(
x_data
)));
// save src_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_src_mem
,
src_memory
);
...
...
@@ -177,8 +174,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
pipeline
.
push_back
(
*
p_fwd
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
));
y
->
set_mkldnn_prim_desc
(
dst_memory
->
get_primitive_desc
());
}
template
<
typename
T
>
...
...
@@ -196,9 +192,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std
::
vector
<
int
>
diff_dst_tz
=
framework
::
vectorize2int
(
diff_y
->
dims
());
auto
diff_y_format
=
diff_dst_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
diff_y
->
format
();
const
std
::
string
key
=
gethash
(
diff_dst_tz
,
algorithm
);
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
...
...
@@ -210,8 +203,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_fwd_pd
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_pd"
;
const
std
::
string
key_with_layouts
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"-"
+
std
::
to_string
(
diff_y_format
);
const
std
::
string
key_with_layouts
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"-"
+
std
::
to_string
(
diff_y
->
format
()
);
const
std
::
string
key_diff_src_mem
=
key_with_layouts
+
"@eltwise_diff_src_mem"
;
const
std
::
string
key_diff_dst_mem
=
...
...
@@ -234,10 +227,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
if
(
p_grad
==
nullptr
)
{
// create mkldnn memory for input diff_y
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_y_format
);
auto
diff_dst_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
(
{
diff_dst_md
,
mkldnn_engine
}
,
to_void_cast
(
diff_y_data
)));
new
memory
(
diff_y
->
get_mkldnn_prim_desc
()
,
to_void_cast
(
diff_y_data
)));
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
diff_dst_memory
);
// retrieve eltwise primitive desc from device context
...
...
@@ -281,8 +272,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
pipeline
.
push_back
(
*
p_grad
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory
));
diff_x
->
set_mkldnn_prim_desc
(
diff_src_memory
->
get_primitive_desc
());
}
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
...
...
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if
(
fuse_with_relu
)
flags
|=
mkldnn
::
fuse_bn_relu
;
// create mkldnn memory from input x tensor
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
// keys for backward pass
const
std
::
string
key
=
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
global_stats
,
input_format
,
src_tz
,
epsilon
,
flags
,
global_stats
,
x
->
format
()
,
ctx
.
op
().
Output
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
);
auto
user_src_md
=
x
->
get_mkldnn_prim_desc
().
desc
();
// create primitive descriptor for batch norm forward
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
...
...
@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
BatchNormMKLDNNHandler
handler
(
batch_norm_fwd_pd
,
dev_ctx
,
mkldnn_engine
,
key
);
auto
src_memory
=
handler
.
AcquireSrcMemory
(
user_src_md
,
to_void_cast
(
x_data
));
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
->
get_mkldnn_prim_desc
(),
to_void_cast
(
x_data
));
// crate mkldnn memory for weights(scale/shift)
auto
scaleshift_memory
=
...
...
@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
variance_memory
,
false
);
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory
));
y
->
set_mkldnn_prim_desc
(
dst_memory
->
get_primitive_desc
());
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
pipeline
.
push_back
(
*
batch_norm_p
);
...
...
@@ -336,9 +332,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
mkldnn
::
memory
::
format
dst_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
diff_y
->
format
());
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
...
...
@@ -346,14 +339,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// keys from forward pass
const
std
::
string
key
=
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
,
src_tz
,
epsilon
,
flags
,
false
,
x
->
format
()
,
ctx
.
op
().
Input
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
// keys for primitives reuse
const
std
::
string
key_with_hash
=
key
+
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
);
x
->
format
()
);
const
std
::
string
key_batch_norm_bwd_p
=
key_with_hash
+
"@batch_norm_bwd_p"
;
const
std
::
string
key_batch_norm_src_mem_p
=
...
...
@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
primitive
reorder_diff_dst
;
bool
is_diff_dst_reordered
=
false
;
auto
user_diff_dst_memory
=
memory
(
{{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
dst_format
},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
auto
user_diff_dst_memory
=
memory
(
diff_y
->
get_mkldnn_prim_desc
(),
to_void_cast
(
diff_y_data
));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
...
...
@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx
.
SetBlob
(
key_batch_norm_diff_dst_mem_p
,
diff_dst_memory
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
diff_x
->
set_mkldnn_prim_desc
(
diff_src_memory
->
get_primitive_desc
());
}
else
{
// primitives already exist
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_src_mem_p
,
to_void_cast
(
x_data
));
...
...
@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
}
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
diff_x
->
set_mkldnn_prim_desc
(
diff_src_memory
->
get_primitive_desc
());
}
// execute optional reorder and batch_norm backward primitive
...
...
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
return
mem_prim_desc
;
}
static
mkldnn
::
memory
::
format
GetDstMemFormat
(
const
concat
::
primitive_desc
&
concat_pd
)
{
return
(
memory
::
format
)
concat_pd
.
dst_primitive_desc
().
desc
().
data
.
format
;
}
static
platform
::
CPUPlace
GetCpuPlace
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
auto
place
=
ctx
.
GetPlace
();
...
...
@@ -139,8 +134,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
concat
=
prim_creator
.
CreateConcatPrimitive
(
concat_pd
,
output
,
place
);
stream
(
stream
::
kind
::
eager
).
submit
({
concat
}).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
GetDstMemFormat
(
concat_pd
));
output
->
set_mkldnn_prim_desc
(
concat_pd
.
dst_primitive_desc
());
}
};
}
// namespace operators
...
...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -282,8 +282,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline
.
push_back
(
*
conv_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
auto
dst_mpd
=
dst_memory_p
->
get_primitive_desc
();
output
->
set_mkldnn_prim_desc
(
dst_mpd
);
output
->
set_mkldnn_prim_desc
(
dst_memory_p
->
get_primitive_desc
());
}
void
ComputeINT8
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
{
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
...
...
@@ -972,8 +971,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline
.
push_back
(
*
conv_bwd_data_p
);
input_grad
->
set_layout
(
DataLayout
::
kMKLDNN
);
input_grad
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory_p
));
input_grad
->
set_mkldnn_prim_desc
(
diff_src_memory_p
->
get_primitive_desc
());
}
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
...
...
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -221,8 +221,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline
.
push_back
(
*
conv_p
);
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory_p
));
output
->
set_mkldnn_prim_desc
(
dst_memory_p
->
get_primitive_desc
());
}
private:
...
...
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -81,10 +81,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
e_mid
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
mid
);
e_mid
=
e_mid
.
constant
(
k
);
auto
dims
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
auto
src_md
=
paddle
::
platform
::
MKLDNNMemDesc
(
dims
,
mkldnn
::
memory
::
data_type
::
f32
,
x
->
format
());
auto
src_md
=
x
->
get_mkldnn_prim_desc
().
desc
();
auto
forward_desc
=
mkldnn
::
lrn_forward
::
desc
{
mkldnn
::
prop_kind
::
forward
,
mkldnn
::
lrn_across_channels
,
...
...
@@ -94,7 +91,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta
,
k
};
auto
src_memory_pd
=
mkldnn
::
memory
::
primitive_desc
{
src_md
,
mkldnn_engine
}
;
auto
src_memory_pd
=
x
->
get_mkldnn_prim_desc
()
;
if
(
!
is_test
)
{
const
std
::
string
key
=
ctx
.
op
().
Output
(
"Out"
);
...
...
@@ -111,16 +108,15 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory
->
set_data_handle
(
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
auto
dst_memory
=
mkldnn
::
memory
(
forward_pd
->
dst_primitive_desc
(),
static_cast
<
void
*>
(
output_data
));
auto
dst_memory_pd
=
forward_pd
->
dst_primitive_desc
();
auto
dst_memory
=
mkldnn
::
memory
(
dst_memory_pd
,
static_cast
<
void
*>
(
output_data
));
auto
workspace_memory
=
insert_to_context
<
mkldnn
::
memory
>
(
key_workspace_memory
,
dev_ctx
,
forward_pd
->
workspace_primitive_desc
());
run_primitive
(
*
forward_pd
,
*
src_memory
,
*
workspace_memory
,
dst_memory
);
out
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
out
->
set_format
(
platform
::
GetMKLDNNFormat
(
dst_memory
));
out
->
set_mkldnn_prim_desc
(
dst_memory_pd
);
}
else
{
auto
forward_pd
=
mkldnn
::
lrn_forward
::
primitive_desc
{
forward_desc
,
mkldnn_engine
};
...
...
@@ -128,13 +124,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd
,
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
))};
auto
workspace_memory
=
mkldnn
::
memory
{
forward_pd
.
workspace_primitive_desc
()};
auto
dst_memory_pd
=
forward_pd
.
dst_primitive_desc
();
auto
dst_memory
=
mkldnn
::
memory
(
forward_pd
.
dst_primitive_desc
(),
static_cast
<
void
*>
(
output_data
));
run_primitive
(
forward_pd
,
src_memory
,
workspace_memory
,
dst_memory
);
out
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
out
->
set_format
(
platform
::
GetMKLDNNFormat
(
dst_memory
));
out
->
set_mkldnn_prim_desc
(
dst_memory_pd
);
}
}
};
...
...
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -158,6 +158,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto
softmax_p
=
handler
.
AcquireSoftmax
(
softmax_dst_memory_p
,
softmax_src_memory_p
);
// We cannot use softmax_dst_memory_p to get prim desc as
// it contains flattened dims (2D) while output tensor can
// have 2,3,4+ dims
auto
output_mem_pd
=
paddle
::
platform
::
create_prim_desc_from_dims
(
paddle
::
framework
::
vectorize2int
(
output
->
dims
()),
mkldnn
::
memory
::
format
::
blocked
);
output
->
set_mkldnn_prim_desc
(
output_mem_pd
);
std
::
vector
<
primitive
>
pipeline
{
*
(
static_cast
<
softmax_forward
::
primitive
*>
(
softmax_p
.
get
()))};
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
...
...
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory
::
desc
(
dst_tz
,
memory
::
data_type
::
f32
,
memory
::
format
::
any
);
auto
sum_pd
=
sum
::
primitive_desc
(
dst_md
,
scales
,
srcs_mpd
);
auto
dst_mem_pd
=
sum_pd
.
dst_primitive_desc
();
std
::
shared_ptr
<
memory
>
dst_mem
;
if
(
in_place
)
{
dst_mem
.
reset
(
new
memory
(
sum_pd
.
dst_primitive_desc
()
));
dst_mem
.
reset
(
new
memory
(
dst_mem_pd
));
}
else
{
dst_mem
.
reset
(
new
memory
(
sum_pd
.
dst_primitive_desc
()
,
output_data
));
dst_mem
.
reset
(
new
memory
(
dst_mem_pd
,
output_data
));
}
std
::
vector
<
mkldnn
::
primitive
::
at
>
inputs
;
for
(
size_t
i
=
0
;
i
<
srcs_mem
.
size
();
++
i
)
{
...
...
@@ -136,8 +136,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if
(
in_place
)
pipeline
.
push_back
(
reorder_prim
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
output_format
);
output
->
set_mkldnn_prim_desc
(
dst_mem_pd
);
}
else
{
// Fallback to naive version
// TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
SumKernel
<
CPUDeviceContext
,
T
>
reference_kernel
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录