Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
36031cb5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
36031cb5
编写于
6月 10, 2018
作者:
M
mozga-intel
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN layout: Support for pool operator
上级
b7c683b8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
128 addition
and
54 deletion
+128
-54
paddle/fluid/operators/pool_mkldnn_op.cc
paddle/fluid/operators/pool_mkldnn_op.cc
+128
-54
未找到文件。
paddle/fluid/operators/pool_mkldnn_op.cc
浏览文件 @
36031cb5
...
@@ -18,9 +18,14 @@ limitations under the License. */
...
@@ -18,9 +18,14 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
mkldnn
::
memory
;
// Note: paddle has also "memory" namespace
using
framework
::
DataLayout
;
using
mkldnn
::
pooling_forward
;
using
mkldnn
::
memory
;
using
mkldnn
::
pooling_backward
;
using
mkldnn
::
pooling_backward
;
using
mkldnn
::
pooling_forward
;
using
mkldnn
::
primitive
;
using
mkldnn
::
reorder
;
using
mkldnn
::
stream
;
using
platform
::
to_void_cast
;
// Generate keys for storing/retriving primitives for this operator
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
// TODO(jczaja): Make hashing function more optimial
...
@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
Tensor
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// Get an unique name from "argument" name of "Out" variable
PADDLE_ENFORCE
(
input
->
layout
()
==
DataLayout
::
kMKLDNN
&&
// This name will be used as key when saving info into device context
input
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input tensor"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
...
@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
src_tz
=
paddle
::
framework
::
vectorize2int
(
input
->
dims
());
std
::
vector
<
int
>
src_tz
=
paddle
::
framework
::
vectorize2int
(
input
->
dims
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
auto
input_format
=
input
->
format
();
memory
::
format
output_format
{
memory
::
format
::
format_undef
};
const
std
::
string
key
=
gethash
(
src_tz
,
pooling_type
,
ksize
,
strides
,
const
std
::
string
key
=
gethash
(
src_tz
,
pooling_type
,
ksize
,
strides
,
paddings
,
ctx
.
op
().
Output
(
"Out"
));
paddings
,
ctx
.
op
().
Output
(
"Out"
));
const
std
::
string
key_pool_p
=
key
+
"@pool_p"
;
const
std
::
string
key_pool_p
=
key
+
"@pool_p"
;
...
@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
pool_p
=
auto
pool_p
=
std
::
static_pointer_cast
<
pooling_forward
>
(
dev_ctx
.
GetBlob
(
key_pool_p
));
std
::
static_pointer_cast
<
pooling_forward
>
(
dev_ctx
.
GetBlob
(
key_pool_p
));
if
(
pool_p
==
nullptr
)
{
if
(
pool_p
==
nullptr
)
{
// TODO(pzelazko-intel): support more formats
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
);
auto
src_md
=
/* create memory descriptor for pooling without specified format
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
* ('any') which lets a primitive (pooling in this case) choose
mkldnn
::
memory
::
format
::
nchw
);
* the memory format preferred for best performance
auto
dst_md
=
*/
platform
::
MKLDNNMemDesc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
()
,
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
mkldnn
::
memory
::
f32
,
mkldnn
::
memory
::
format
::
nchw
);
mkldnn
::
memory
::
format
::
any
);
std
::
shared_ptr
<
pooling_forward
::
primitive_desc
>
pool_pd
=
std
::
shared_ptr
<
mkldnn
::
pooling_forward
::
primitive_desc
>
pool_pd
=
CreatePrimitiveDesc
(
src_md
,
dst_md
,
strides
,
paddings
,
ksize
,
CreatePrimitiveDesc
(
src_md
,
dst_md
,
strides
,
paddings
,
ksize
,
pooling_type
,
mkldnn_engine
);
pooling_type
,
mkldnn_engine
);
...
@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_workspace_memory to be referred in backward path
// save pool_workspace_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_pool_workspace_memory
,
workspace_memory
);
dev_ctx
.
SetBlob
(
key_pool_workspace_memory
,
workspace_memory
);
auto
pool_src_memory_p
=
std
::
make_shared
<
memory
>
(
auto
src_memory
=
std
::
make_shared
<
memory
>
(
pool_pd
->
src_primitive_desc
(),
memory
::
primitive_desc
{
src_md
,
mkldnn_engine
},
to_void_cast
<
T
>
(
input_data
));
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
auto
dst_memory
=
dev_ctx
.
SetBlob
(
key_pool_src_mem_p
,
pool_src_memory_p
);
std
::
make_shared
<
memory
>
(
pool_pd
->
dst_primitive_desc
(),
output_data
);
auto
pool_dst_memory_p
=
std
::
make_shared
<
memory
>
(
dev_ctx
.
SetBlob
(
key_pool_src_mem_p
,
src_memory
);
memory
::
primitive_desc
{
dst_md
,
mkldnn_engine
},
dev_ctx
.
SetBlob
(
key_pool_dst_mem_p
,
dst_memory
);
static_cast
<
void
*>
(
output_data
));
dev_ctx
.
SetBlob
(
key_pool_dst_mem_p
,
pool_dst_memory_p
);
pool_p
=
std
::
make_shared
<
pooling_forward
>
(
*
pool_pd
,
*
(
src_memory
.
get
()),
*
(
dst_memory
.
get
()),
*
workspace_memory
);
pool_p
=
std
::
make_shared
<
pooling_forward
>
(
*
pool_pd
,
*
(
pool_src_memory_p
.
get
()),
*
(
pool_dst_memory_p
.
get
()),
*
workspace_memory
);
dev_ctx
.
SetBlob
(
key_pool_p
,
pool_p
);
dev_ctx
.
SetBlob
(
key_pool_p
,
pool_p
);
output_format
=
(
memory
::
format
)
dst_memory
->
get_primitive_desc
().
desc
().
data
.
format
;
}
else
{
}
else
{
// Primitives already exist
// Primitives already exist
auto
pool_src_memory_p
=
auto
pool_src_memory_p
=
...
@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_dst_mem_p
));
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_dst_mem_p
));
PADDLE_ENFORCE
(
pool_dst_memory_p
!=
nullptr
,
PADDLE_ENFORCE
(
pool_dst_memory_p
!=
nullptr
,
"Fail to find pooling dst mem_p in device context"
);
"Fail to find pooling dst mem_p in device context"
);
pool_src_memory_p
->
set_data_handle
(
pool_src_memory_p
->
set_data_handle
(
to_void_cast
<
T
>
(
input_data
));
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
pool_dst_memory_p
->
set_data_handle
(
output_data
);
pool_dst_memory_p
->
set_data_handle
(
output_data
);
output_format
=
(
memory
::
format
)
pool_dst_memory_p
->
get_primitive_desc
()
.
desc
()
.
data
.
format
;
}
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
*
(
pool_p
.
get
())};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
*
(
pool_p
.
get
())};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
output_format
);
}
}
private:
private:
...
@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
Tensor
*
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
Tensor
*
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
Tensor
*
in_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
Tensor
*
in_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
PADDLE_ENFORCE
(
in_x
->
layout
()
==
DataLayout
::
kMKLDNN
&&
in_x
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input X tensor"
);
PADDLE_ENFORCE
(
out_grad
->
layout
()
==
DataLayout
::
kMKLDNN
&&
out_grad
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input output_grad tensor"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
...
@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
T
*
out_grad_data
=
out_grad
->
data
<
T
>
();
const
T
*
out_grad_data
=
out_grad
->
data
<
T
>
();
T
*
in_x_grad_data
=
in_x_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
in_x_grad_data
=
in_x_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
memory
::
format
in_x_grad_format
{
memory
::
format
::
format_undef
};
std
::
vector
<
int
>
diff_src_tz
=
std
::
vector
<
int
>
diff_src_tz
=
paddle
::
framework
::
vectorize2int
(
in_x_grad
->
dims
());
paddle
::
framework
::
vectorize2int
(
in_x_grad
->
dims
());
...
@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
std
::
string
key_pool_bwd_p
=
key
+
"@pool_bwd_p"
;
const
std
::
string
key_pool_bwd_p
=
key
+
"@pool_bwd_p"
;
const
std
::
string
key_pool_diff_src_mem_p
=
key
+
"@pool_diff_src_mem_p"
;
const
std
::
string
key_pool_diff_src_mem_p
=
key
+
"@pool_diff_src_mem_p"
;
const
std
::
string
key_pool_diff_dst_mem_p
=
key
+
"@pool_diff_dst_mem_p"
;
const
std
::
string
key_pool_diff_dst_mem_p
=
key
+
"@pool_diff_dst_mem_p"
;
const
std
::
string
key_pool_src_mem_p
=
key
+
"@pool_src_mem_p"
;
const
std
::
string
key_pool_dst_mem_p
=
key
+
"@pool_dst_mem_p"
;
const
std
::
string
key_pool_pd
=
key
+
"@pool_pd"
;
const
std
::
string
key_pool_pd
=
key
+
"@pool_pd"
;
const
std
::
string
key_pool_workspace_memory
=
const
std
::
string
key_pool_workspace_memory
=
key
+
"@pool_workspace_memory"
;
key
+
"@pool_workspace_memory"
;
auto
user_diff_dst_memory
=
memory
({{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
out_grad
->
format
()},
mkldnn_engine
},
to_void_cast
<
T
>
(
out_grad_data
));
std
::
shared_ptr
<
memory
>
diff_src_memory
;
std
::
shared_ptr
<
memory
>
diff_dst_memory
;
auto
dst_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_dst_mem_p
));
PADDLE_ENFORCE
(
dst_memory
!=
nullptr
,
"Fail to find dst_memory in device context"
);
primitive
reorder_diff_dst
;
bool
is_diff_dst_reordered
=
false
;
auto
pool_bwd_p
=
std
::
static_pointer_cast
<
pooling_backward
>
(
auto
pool_bwd_p
=
std
::
static_pointer_cast
<
pooling_backward
>
(
dev_ctx
.
GetBlob
(
key_pool_bwd_p
));
dev_ctx
.
GetBlob
(
key_pool_bwd_p
));
if
(
pool_bwd_p
==
nullptr
)
{
if
(
pool_bwd_p
==
nullptr
)
{
auto
diff_src_md
=
// Retrieve src_memory/dst_memory saved in forward pass
platform
::
MKLDNNMemDesc
(
diff_src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
auto
src_memory
=
mkldnn
::
memory
::
format
::
nchw
);
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_src_mem_p
));
auto
diff_dst_md
=
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
"Fail to find src_memory in device context"
);
mkldnn
::
memory
::
format
::
nchw
);
// Retrieve pool_pd/pool_workspace_memory from device context
// Retrieve pool_pd/pool_workspace_memory from device context
auto
pool_pd
=
auto
pool_pd
=
std
::
static_pointer_cast
<
mkldnn
::
pooling_forward
::
primitive_desc
>
(
std
::
static_pointer_cast
<
mkldnn
::
pooling_forward
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_pool_pd
));
dev_ctx
.
GetBlob
(
key_pool_pd
));
PADDLE_ENFORCE
(
pool_pd
!=
nullptr
,
PADDLE_ENFORCE
(
pool_pd
!=
nullptr
,
"Fail to find pool_pd in device context"
);
"Fail to find pool_pd in device context"
);
auto
workspace_memory
=
std
::
static_pointer_cast
<
memory
>
(
auto
workspace_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_workspace_memory
));
dev_ctx
.
GetBlob
(
key_pool_workspace_memory
));
PADDLE_ENFORCE
(
workspace_memory
!=
nullptr
,
PADDLE_ENFORCE
(
workspace_memory
!=
nullptr
,
"Fail to find workspace_memory in device context"
);
"Fail to find workspace_memory in device context"
);
auto
pool_diff_src_memory_p
=
std
::
make_shared
<
memory
>
(
memory
(
// create memory descriptors for pooling
{
diff_src_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
in_x_grad_data
)));
auto
diff_src_md
=
src_memory
.
get
()
->
get_primitive_desc
().
desc
();
dev_ctx
.
SetBlob
(
key_pool_diff_src_mem_p
,
pool_diff_src_memory_p
);
auto
diff_dst_md
=
dst_memory
.
get
()
->
get_primitive_desc
().
desc
();
auto
pool_diff_dst_memory_p
=
std
::
make_shared
<
memory
>
(
memory
({
diff_dst_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
out_grad_data
))));
dev_ctx
.
SetBlob
(
key_pool_diff_dst_mem_p
,
pool_diff_dst_memory_p
);
auto
pool_bwd_desc
=
mkldnn
::
pooling_backward
::
desc
(
auto
pool_bwd_desc
=
mkldnn
::
pooling_backward
::
desc
(
pooling_type
==
"max"
?
mkldnn
::
algorithm
::
pooling_max
pooling_type
==
"max"
?
mkldnn
::
algorithm
::
pooling_max
...
@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
pool_bwd_pd
=
mkldnn
::
pooling_backward
::
primitive_desc
(
auto
pool_bwd_pd
=
mkldnn
::
pooling_backward
::
primitive_desc
(
pool_bwd_desc
,
mkldnn_engine
,
*
pool_pd
);
pool_bwd_desc
,
mkldnn_engine
,
*
pool_pd
);
// reorder between user_diff_dst and pool diff_dst if needed
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
user_diff_dst_memory
);
if
(
memory
::
primitive_desc
(
dst_memory
->
get_primitive_desc
())
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
dst_memory
.
get
()
->
get_primitive_desc
());
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
diff_src_memory
=
std
::
make_shared
<
memory
>
(
pool_bwd_pd
.
diff_src_primitive_desc
(),
in_x_grad_data
);
dev_ctx
.
SetBlob
(
key_pool_diff_src_mem_p
,
diff_src_memory
);
dev_ctx
.
SetBlob
(
key_pool_diff_dst_mem_p
,
diff_dst_memory
);
pool_bwd_p
=
std
::
make_shared
<
pooling_backward
>
(
pool_bwd_p
=
std
::
make_shared
<
pooling_backward
>
(
pool_bwd_pd
,
*
(
pool_diff_dst_memory_p
.
get
()),
*
workspace_memory
,
pool_bwd_pd
,
*
(
diff_dst_memory
.
get
()),
*
workspace_memory
,
*
(
pool_diff_src_memory_p
));
*
(
diff_src_memory
));
dev_ctx
.
SetBlob
(
key_pool_bwd_p
,
pool_bwd_p
);
dev_ctx
.
SetBlob
(
key_pool_bwd_p
,
pool_bwd_p
);
}
else
{
}
else
{
// Primitives already exist
// Primitives already exist
auto
pool_diff_src_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
diff_src_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_diff_src_mem_p
));
dev_ctx
.
GetBlob
(
key_pool_diff_src_mem_p
));
PADDLE_ENFORCE
(
pool_diff_src_memory_p
!=
nullptr
,
PADDLE_ENFORCE
(
diff_src_memory
!=
nullptr
,
"Fail to find pooling src mem_p in device context"
);
"Fail to find pooling src mem_p in device context"
);
auto
pool_diff_dst_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
diff_dst_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_diff_dst_mem_p
));
dev_ctx
.
GetBlob
(
key_pool_diff_dst_mem_p
));
PADDLE_ENFORCE
(
pool_diff_dst_memory_p
!=
nullptr
,
PADDLE_ENFORCE
(
diff_dst_memory
!=
nullptr
,
"Fail to find pooling dst mem_p in device context"
);
"Fail to find pooling dst mem_p in device context"
);
pool_diff_src_memory_p
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
in_x_grad_data
));
diff_src_memory
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
in_x_grad_data
));
pool_diff_dst_memory_p
->
set_data_handle
(
const_cast
<
T
*>
(
out_grad_data
));
diff_dst_memory
->
set_data_handle
(
const_cast
<
T
*>
(
out_grad_data
));
// reorder between user_diff_dst and pool diff_dst if needed
if
(
memory
::
primitive_desc
(
dst_memory
->
get_primitive_desc
())
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
dst_memory
.
get
()
->
get_primitive_desc
());
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
}
}
in_x_grad_format
=
(
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
;
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
*
(
pool_bwd_p
.
get
())};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
if
(
is_diff_dst_reordered
)
{
pipeline
.
push_back
(
reorder_diff_dst
);
}
pipeline
.
push_back
(
*
(
pool_bwd_p
.
get
()));
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
in_x_grad
->
set_layout
(
DataLayout
::
kMKLDNN
);
in_x_grad
->
set_format
(
in_x_grad_format
);
}
// Compute()
}
// Compute()
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_KERNEL
(
pool2d
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL
(
pool2d
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
paddle
::
operator
s
::
PoolMKLDNNOpKernel
<
float
>
);
op
s
::
PoolMKLDNNOpKernel
<
float
>
);
REGISTER_OP_KERNEL
(
pool2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL
(
pool2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
paddle
::
operator
s
::
PoolMKLDNNGradOpKernel
<
float
>
);
op
s
::
PoolMKLDNNGradOpKernel
<
float
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录