Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ba90e052
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ba90e052
编写于
2月 27, 2019
作者:
T
Tao Luo
提交者:
GitHub
2月 27, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15917 from jczaja/prv-tensor-mkldnn-ops
[MKL-DNN] Adjusting ops to Tensor modifications
上级
7d8f6398
c63f6b20
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
48 addition
and
86 deletion
+48
-86
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
...operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
+6
-13
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+7
-17
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+11
-25
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
+1
-7
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+2
-4
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
+1
-2
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+8
-13
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+8
-0
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+4
-5
未找到文件。
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -77,8 +77,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
}
else
{
functor
.
RunMidWise
(
n
,
pre
,
post
);
}
z
->
set_layout
(
DataLayout
::
kMKLDNN
);
z
->
set_format
(
x
->
format
());
z
->
set_mkldnn_prim_desc
(
x
->
get_mkldnn_prim_desc
());
}
else
{
PADDLE_ENFORCE
(
x
->
layout
()
==
DataLayout
::
kMKLDNN
&&
x
->
format
()
!=
memory
::
format
::
format_undef
,
...
...
@@ -116,7 +115,8 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto
sum_pd
=
sum
::
primitive_desc
(
dst_md
,
scales
,
srcs_pd
);
// create mkldnn memory for dst
memory
dst_memory
=
memory
(
sum_pd
.
dst_primitive_desc
(),
z_data
);
auto
dst_mem_pd
=
sum_pd
.
dst_primitive_desc
();
memory
dst_memory
=
memory
(
dst_mem_pd
,
z_data
);
std
::
vector
<
primitive
::
at
>
inputs
;
inputs
.
push_back
(
srcs
[
0
]);
...
...
@@ -129,9 +129,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
pipeline
.
push_back
(
sum_prim
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
z
->
set_layout
(
DataLayout
::
kMKLDNN
);
z
->
set_format
(
(
memory
::
format
)
dst_memory
.
get_primitive_desc
().
desc
().
data
.
format
);
z
->
set_mkldnn_prim_desc
(
dst_mem_pd
);
}
}
};
...
...
@@ -152,24 +150,19 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
auto
*
out
=
dout
;
auto
*
x
=
dout
,
*
y
=
dout
;
auto
set_mkldnn_format
=
[](
Tensor
*
in
,
const
Tensor
*
out
)
{
in
->
set_layout
(
DataLayout
::
kMKLDNN
);
in
->
set_format
(
out
->
format
());
};
if
(
dx
!=
nullptr
&&
dy
!=
nullptr
&&
dx
->
dims
()
==
dy
->
dims
())
{
if
(
dx
->
dims
()
==
dy
->
dims
())
{
auto
blas
=
math
::
GetBlas
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
if
(
dx
)
{
blas
.
VCOPY
(
dout
->
numel
(),
dout
->
data
<
T
>
(),
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
set_mkldnn_format
(
dx
,
dout
);
dx
->
set_mkldnn_prim_desc
(
dout
->
get_mkldnn_prim_desc
()
);
}
if
(
dy
)
{
blas
.
VCOPY
(
dout
->
numel
(),
dout
->
data
<
T
>
(),
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
set_mkldnn_format
(
dy
,
dout
);
dy
->
set_mkldnn_prim_desc
(
dout
->
get_mkldnn_prim_desc
()
);
}
}
}
else
{
...
...
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -96,8 +96,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
x
->
dims
());
auto
src_format
=
src_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
x
->
format
();
auto
src_format
=
x
->
format
();
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_data
=
...
...
@@ -127,10 +126,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
if
(
p_fwd
==
nullptr
)
{
// create mkldnn memory for input X
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
(
{
src_md
,
mkldnn_engine
}
,
to_void_cast
(
x_data
)));
new
memory
(
x
->
get_mkldnn_prim_desc
()
,
to_void_cast
(
x_data
)));
// save src_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_src_mem
,
src_memory
);
...
...
@@ -177,8 +174,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
pipeline
.
push_back
(
*
p_fwd
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
));
y
->
set_mkldnn_prim_desc
(
dst_memory
->
get_primitive_desc
());
}
template
<
typename
T
>
...
...
@@ -196,9 +192,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
std
::
vector
<
int
>
diff_dst_tz
=
framework
::
vectorize2int
(
diff_y
->
dims
());
auto
diff_y_format
=
diff_dst_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
diff_y
->
format
();
const
std
::
string
key
=
gethash
(
diff_dst_tz
,
algorithm
);
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
...
...
@@ -210,8 +203,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_fwd_pd
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_pd"
;
const
std
::
string
key_with_layouts
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"-"
+
std
::
to_string
(
diff_y_format
);
const
std
::
string
key_with_layouts
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"-"
+
std
::
to_string
(
diff_y
->
format
()
);
const
std
::
string
key_diff_src_mem
=
key_with_layouts
+
"@eltwise_diff_src_mem"
;
const
std
::
string
key_diff_dst_mem
=
...
...
@@ -234,10 +227,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
if
(
p_grad
==
nullptr
)
{
// create mkldnn memory for input diff_y
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_y_format
);
auto
diff_dst_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
(
{
diff_dst_md
,
mkldnn_engine
}
,
to_void_cast
(
diff_y_data
)));
new
memory
(
diff_y
->
get_mkldnn_prim_desc
()
,
to_void_cast
(
diff_y_data
)));
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
diff_dst_memory
);
// retrieve eltwise primitive desc from device context
...
...
@@ -281,8 +272,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
pipeline
.
push_back
(
*
p_grad
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory
));
diff_x
->
set_mkldnn_prim_desc
(
diff_src_memory
->
get_primitive_desc
());
}
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
...
...
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -206,17 +206,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if
(
fuse_with_relu
)
flags
|=
mkldnn
::
fuse_bn_relu
;
// create mkldnn memory from input x tensor
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
// keys for backward pass
const
std
::
string
key
=
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
global_stats
,
input_format
,
src_tz
,
epsilon
,
flags
,
global_stats
,
x
->
format
()
,
ctx
.
op
().
Output
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
);
auto
user_src_md
=
x
->
get_mkldnn_prim_desc
().
desc
();
// create primitive descriptor for batch norm forward
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
...
...
@@ -230,8 +227,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
BatchNormMKLDNNHandler
handler
(
batch_norm_fwd_pd
,
dev_ctx
,
mkldnn_engine
,
key
);
auto
src_memory
=
handler
.
AcquireSrcMemory
(
user_src_md
,
to_void_cast
(
x_data
));
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
->
get_mkldnn_prim_desc
(),
to_void_cast
(
x_data
));
// crate mkldnn memory for weights(scale/shift)
auto
scaleshift_memory
=
...
...
@@ -265,8 +262,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
variance_memory
,
false
);
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory
));
y
->
set_mkldnn_prim_desc
(
dst_memory
->
get_primitive_desc
());
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
pipeline
.
push_back
(
*
batch_norm_p
);
...
...
@@ -336,9 +332,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
mkldnn
::
memory
::
format
dst_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
diff_y
->
format
());
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
...
...
@@ -346,14 +339,14 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// keys from forward pass
const
std
::
string
key
=
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
,
src_tz
,
epsilon
,
flags
,
false
,
x
->
format
()
,
ctx
.
op
().
Input
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
// keys for primitives reuse
const
std
::
string
key_with_hash
=
key
+
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
);
x
->
format
()
);
const
std
::
string
key_batch_norm_bwd_p
=
key_with_hash
+
"@batch_norm_bwd_p"
;
const
std
::
string
key_batch_norm_src_mem_p
=
...
...
@@ -373,9 +366,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
primitive
reorder_diff_dst
;
bool
is_diff_dst_reordered
=
false
;
auto
user_diff_dst_memory
=
memory
(
{{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
dst_format
},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
auto
user_diff_dst_memory
=
memory
(
diff_y
->
get_mkldnn_prim_desc
(),
to_void_cast
(
diff_y_data
));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
...
...
@@ -459,10 +451,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx
.
SetBlob
(
key_batch_norm_diff_dst_mem_p
,
diff_dst_memory
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
diff_x
->
set_mkldnn_prim_desc
(
diff_src_memory
->
get_primitive_desc
());
}
else
{
// primitives already exist
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_src_mem_p
,
to_void_cast
(
x_data
));
...
...
@@ -487,10 +476,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
}
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
diff_x
->
set_mkldnn_prim_desc
(
diff_src_memory
->
get_primitive_desc
());
}
// execute optional reorder and batch_norm backward primitive
...
...
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
return
mem_prim_desc
;
}
static
mkldnn
::
memory
::
format
GetDstMemFormat
(
const
concat
::
primitive_desc
&
concat_pd
)
{
return
(
memory
::
format
)
concat_pd
.
dst_primitive_desc
().
desc
().
data
.
format
;
}
static
platform
::
CPUPlace
GetCpuPlace
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
auto
place
=
ctx
.
GetPlace
();
...
...
@@ -139,8 +134,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
concat
=
prim_creator
.
CreateConcatPrimitive
(
concat_pd
,
output
,
place
);
stream
(
stream
::
kind
::
eager
).
submit
({
concat
}).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
GetDstMemFormat
(
concat_pd
));
output
->
set_mkldnn_prim_desc
(
concat_pd
.
dst_primitive_desc
());
}
};
}
// namespace operators
...
...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -282,8 +282,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline
.
push_back
(
*
conv_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
auto
dst_mpd
=
dst_memory_p
->
get_primitive_desc
();
output
->
set_mkldnn_prim_desc
(
dst_mpd
);
output
->
set_mkldnn_prim_desc
(
dst_memory_p
->
get_primitive_desc
());
}
void
ComputeINT8
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
{
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
...
...
@@ -972,8 +971,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline
.
push_back
(
*
conv_bwd_data_p
);
input_grad
->
set_layout
(
DataLayout
::
kMKLDNN
);
input_grad
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory_p
));
input_grad
->
set_mkldnn_prim_desc
(
diff_src_memory_p
->
get_primitive_desc
());
}
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
...
...
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -221,8 +221,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
pipeline
.
push_back
(
*
conv_p
);
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory_p
));
output
->
set_mkldnn_prim_desc
(
dst_memory_p
->
get_primitive_desc
());
}
private:
...
...
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -81,10 +81,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
e_mid
=
framework
::
EigenTensor
<
T
,
4
>::
From
(
*
mid
);
e_mid
=
e_mid
.
constant
(
k
);
auto
dims
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
auto
src_md
=
paddle
::
platform
::
MKLDNNMemDesc
(
dims
,
mkldnn
::
memory
::
data_type
::
f32
,
x
->
format
());
auto
src_md
=
x
->
get_mkldnn_prim_desc
().
desc
();
auto
forward_desc
=
mkldnn
::
lrn_forward
::
desc
{
mkldnn
::
prop_kind
::
forward
,
mkldnn
::
lrn_across_channels
,
...
...
@@ -94,7 +91,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta
,
k
};
auto
src_memory_pd
=
mkldnn
::
memory
::
primitive_desc
{
src_md
,
mkldnn_engine
}
;
auto
src_memory_pd
=
x
->
get_mkldnn_prim_desc
()
;
if
(
!
is_test
)
{
const
std
::
string
key
=
ctx
.
op
().
Output
(
"Out"
);
...
...
@@ -111,16 +108,15 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory
->
set_data_handle
(
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
auto
dst_memory
=
mkldnn
::
memory
(
forward_pd
->
dst_primitive_desc
(),
static_cast
<
void
*>
(
output_data
));
auto
dst_memory_pd
=
forward_pd
->
dst_primitive_desc
();
auto
dst_memory
=
mkldnn
::
memory
(
dst_memory_pd
,
static_cast
<
void
*>
(
output_data
));
auto
workspace_memory
=
insert_to_context
<
mkldnn
::
memory
>
(
key_workspace_memory
,
dev_ctx
,
forward_pd
->
workspace_primitive_desc
());
run_primitive
(
*
forward_pd
,
*
src_memory
,
*
workspace_memory
,
dst_memory
);
out
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
out
->
set_format
(
platform
::
GetMKLDNNFormat
(
dst_memory
));
out
->
set_mkldnn_prim_desc
(
dst_memory_pd
);
}
else
{
auto
forward_pd
=
mkldnn
::
lrn_forward
::
primitive_desc
{
forward_desc
,
mkldnn_engine
};
...
...
@@ -128,13 +124,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd
,
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
))};
auto
workspace_memory
=
mkldnn
::
memory
{
forward_pd
.
workspace_primitive_desc
()};
auto
dst_memory_pd
=
forward_pd
.
dst_primitive_desc
();
auto
dst_memory
=
mkldnn
::
memory
(
forward_pd
.
dst_primitive_desc
(),
static_cast
<
void
*>
(
output_data
));
run_primitive
(
forward_pd
,
src_memory
,
workspace_memory
,
dst_memory
);
out
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
out
->
set_format
(
platform
::
GetMKLDNNFormat
(
dst_memory
));
out
->
set_mkldnn_prim_desc
(
dst_memory_pd
);
}
}
};
...
...
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -158,6 +158,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto
softmax_p
=
handler
.
AcquireSoftmax
(
softmax_dst_memory_p
,
softmax_src_memory_p
);
// We cannot use softmax_dst_memory_p to get prim desc as
// it contains flattened dims (2D) while output tensor can
// have 2,3,4+ dims
auto
output_mem_pd
=
paddle
::
platform
::
create_prim_desc_from_dims
(
paddle
::
framework
::
vectorize2int
(
output
->
dims
()),
mkldnn
::
memory
::
format
::
blocked
);
output
->
set_mkldnn_prim_desc
(
output_mem_pd
);
std
::
vector
<
primitive
>
pipeline
{
*
(
static_cast
<
softmax_forward
::
primitive
*>
(
softmax_p
.
get
()))};
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
...
...
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
浏览文件 @
ba90e052
...
...
@@ -106,12 +106,12 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
memory
::
desc
(
dst_tz
,
memory
::
data_type
::
f32
,
memory
::
format
::
any
);
auto
sum_pd
=
sum
::
primitive_desc
(
dst_md
,
scales
,
srcs_mpd
);
auto
dst_mem_pd
=
sum_pd
.
dst_primitive_desc
();
std
::
shared_ptr
<
memory
>
dst_mem
;
if
(
in_place
)
{
dst_mem
.
reset
(
new
memory
(
sum_pd
.
dst_primitive_desc
()
));
dst_mem
.
reset
(
new
memory
(
dst_mem_pd
));
}
else
{
dst_mem
.
reset
(
new
memory
(
sum_pd
.
dst_primitive_desc
()
,
output_data
));
dst_mem
.
reset
(
new
memory
(
dst_mem_pd
,
output_data
));
}
std
::
vector
<
mkldnn
::
primitive
::
at
>
inputs
;
for
(
size_t
i
=
0
;
i
<
srcs_mem
.
size
();
++
i
)
{
...
...
@@ -136,8 +136,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if
(
in_place
)
pipeline
.
push_back
(
reorder_prim
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
output_format
);
output
->
set_mkldnn_prim_desc
(
dst_mem_pd
);
}
else
{
// Fallback to naive version
// TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
SumKernel
<
CPUDeviceContext
,
T
>
reference_kernel
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录