Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
36031cb5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
36031cb5
编写于
6月 10, 2018
作者:
M
mozga-intel
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN layout: Support for pool operator
上级
b7c683b8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
128 addition
and
54 deletion
+128
-54
paddle/fluid/operators/pool_mkldnn_op.cc
paddle/fluid/operators/pool_mkldnn_op.cc
+128
-54
未找到文件。
paddle/fluid/operators/pool_mkldnn_op.cc
浏览文件 @
36031cb5
...
...
@@ -18,9 +18,14 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
mkldnn
::
memory
;
// Note: paddle has also "memory" namespace
using
mkldnn
::
pooling_forward
;
using
framework
::
DataLayout
;
using
mkldnn
::
memory
;
using
mkldnn
::
pooling_backward
;
using
mkldnn
::
pooling_forward
;
using
mkldnn
::
primitive
;
using
mkldnn
::
reorder
;
using
mkldnn
::
stream
;
using
platform
::
to_void_cast
;
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
...
...
@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
Tensor
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when saving info into device context
PADDLE_ENFORCE
(
input
->
layout
()
==
DataLayout
::
kMKLDNN
&&
input
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input tensor"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
...
...
@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
src_tz
=
paddle
::
framework
::
vectorize2int
(
input
->
dims
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
auto
input_format
=
input
->
format
();
memory
::
format
output_format
{
memory
::
format
::
format_undef
};
const
std
::
string
key
=
gethash
(
src_tz
,
pooling_type
,
ksize
,
strides
,
paddings
,
ctx
.
op
().
Output
(
"Out"
));
const
std
::
string
key_pool_p
=
key
+
"@pool_p"
;
...
...
@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
pool_p
=
std
::
static_pointer_cast
<
pooling_forward
>
(
dev_ctx
.
GetBlob
(
key_pool_p
));
if
(
pool_p
==
nullptr
)
{
// TODO(pzelazko-intel): support more formats
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
);
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
);
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
()
,
mkldnn
::
memory
::
format
::
nchw
);
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
*/
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
mkldnn
::
memory
::
f32
,
mkldnn
::
memory
::
format
::
any
);
std
::
shared_ptr
<
pooling_forward
::
primitive_desc
>
pool_pd
=
std
::
shared_ptr
<
mkldnn
::
pooling_forward
::
primitive_desc
>
pool_pd
=
CreatePrimitiveDesc
(
src_md
,
dst_md
,
strides
,
paddings
,
ksize
,
pooling_type
,
mkldnn_engine
);
...
...
@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_workspace_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_pool_workspace_memory
,
workspace_memory
);
auto
pool_src_memory_p
=
std
::
make_shared
<
memory
>
(
memory
::
primitive_desc
{
src_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
dev_ctx
.
SetBlob
(
key_pool_src_mem_p
,
pool_src_memory_p
);
auto
src_memory
=
std
::
make_shared
<
memory
>
(
pool_pd
->
src_primitive_desc
(),
to_void_cast
<
T
>
(
input_data
));
auto
dst_memory
=
std
::
make_shared
<
memory
>
(
pool_pd
->
dst_primitive_desc
(),
output_data
);
auto
pool_dst_memory_p
=
std
::
make_shared
<
memory
>
(
memory
::
primitive_desc
{
dst_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
output_data
));
dev_ctx
.
SetBlob
(
key_pool_dst_mem_p
,
pool_dst_memory_p
);
dev_ctx
.
SetBlob
(
key_pool_src_mem_p
,
src_memory
);
dev_ctx
.
SetBlob
(
key_pool_dst_mem_p
,
dst_memory
);
pool_p
=
std
::
make_shared
<
pooling_forward
>
(
*
pool_pd
,
*
(
src_memory
.
get
()),
*
(
dst_memory
.
get
()),
*
workspace_memory
);
pool_p
=
std
::
make_shared
<
pooling_forward
>
(
*
pool_pd
,
*
(
pool_src_memory_p
.
get
()),
*
(
pool_dst_memory_p
.
get
()),
*
workspace_memory
);
dev_ctx
.
SetBlob
(
key_pool_p
,
pool_p
);
output_format
=
(
memory
::
format
)
dst_memory
->
get_primitive_desc
().
desc
().
data
.
format
;
}
else
{
// Primitives already exist
auto
pool_src_memory_p
=
...
...
@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_dst_mem_p
));
PADDLE_ENFORCE
(
pool_dst_memory_p
!=
nullptr
,
"Fail to find pooling dst mem_p in device context"
);
pool_src_memory_p
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
pool_src_memory_p
->
set_data_handle
(
to_void_cast
<
T
>
(
input_data
));
pool_dst_memory_p
->
set_data_handle
(
output_data
);
output_format
=
(
memory
::
format
)
pool_dst_memory_p
->
get_primitive_desc
()
.
desc
()
.
data
.
format
;
}
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
*
(
pool_p
.
get
())};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
output_format
);
}
private:
...
...
@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
Tensor
*
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
Tensor
*
in_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
PADDLE_ENFORCE
(
in_x
->
layout
()
==
DataLayout
::
kMKLDNN
&&
in_x
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input X tensor"
);
PADDLE_ENFORCE
(
out_grad
->
layout
()
==
DataLayout
::
kMKLDNN
&&
out_grad
->
format
()
!=
memory
::
format
::
format_undef
,
"Wrong layout/format set for Input output_grad tensor"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
...
...
@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
T
*
out_grad_data
=
out_grad
->
data
<
T
>
();
T
*
in_x_grad_data
=
in_x_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
memory
::
format
in_x_grad_format
{
memory
::
format
::
format_undef
};
std
::
vector
<
int
>
diff_src_tz
=
paddle
::
framework
::
vectorize2int
(
in_x_grad
->
dims
());
...
...
@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
std
::
string
key_pool_bwd_p
=
key
+
"@pool_bwd_p"
;
const
std
::
string
key_pool_diff_src_mem_p
=
key
+
"@pool_diff_src_mem_p"
;
const
std
::
string
key_pool_diff_dst_mem_p
=
key
+
"@pool_diff_dst_mem_p"
;
const
std
::
string
key_pool_src_mem_p
=
key
+
"@pool_src_mem_p"
;
const
std
::
string
key_pool_dst_mem_p
=
key
+
"@pool_dst_mem_p"
;
const
std
::
string
key_pool_pd
=
key
+
"@pool_pd"
;
const
std
::
string
key_pool_workspace_memory
=
key
+
"@pool_workspace_memory"
;
auto
user_diff_dst_memory
=
memory
({{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
out_grad
->
format
()},
mkldnn_engine
},
to_void_cast
<
T
>
(
out_grad_data
));
std
::
shared_ptr
<
memory
>
diff_src_memory
;
std
::
shared_ptr
<
memory
>
diff_dst_memory
;
auto
dst_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_dst_mem_p
));
PADDLE_ENFORCE
(
dst_memory
!=
nullptr
,
"Fail to find dst_memory in device context"
);
primitive
reorder_diff_dst
;
bool
is_diff_dst_reordered
=
false
;
auto
pool_bwd_p
=
std
::
static_pointer_cast
<
pooling_backward
>
(
dev_ctx
.
GetBlob
(
key_pool_bwd_p
));
if
(
pool_bwd_p
==
nullptr
)
{
auto
diff_src_md
=
platform
::
MKLDNNMemDesc
(
diff_src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
);
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
);
// Retrieve src_memory/dst_memory saved in forward pass
auto
src_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_src_mem_p
));
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
"Fail to find src_memory in device context"
);
// Retrieve pool_pd/pool_workspace_memory from device context
auto
pool_pd
=
std
::
static_pointer_cast
<
mkldnn
::
pooling_forward
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_pool_pd
));
PADDLE_ENFORCE
(
pool_pd
!=
nullptr
,
"Fail to find pool_pd in device context"
);
auto
workspace_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
auto
workspace_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_workspace_memory
));
PADDLE_ENFORCE
(
workspace_memory
!=
nullptr
,
"Fail to find workspace_memory in device context"
);
auto
pool_diff_src_memory_p
=
std
::
make_shared
<
memory
>
(
memory
(
{
diff_src_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
in_x_grad_data
)));
dev_ctx
.
SetBlob
(
key_pool_diff_src_mem_p
,
pool_diff_src_memory_p
);
auto
pool_diff_dst_memory_p
=
std
::
make_shared
<
memory
>
(
memory
({
diff_dst_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
out_grad_data
))));
dev_ctx
.
SetBlob
(
key_pool_diff_dst_mem_p
,
pool_diff_dst_memory_p
);
// create memory descriptors for pooling
auto
diff_src_md
=
src_memory
.
get
()
->
get_primitive_desc
().
desc
();
auto
diff_dst_md
=
dst_memory
.
get
()
->
get_primitive_desc
().
desc
();
auto
pool_bwd_desc
=
mkldnn
::
pooling_backward
::
desc
(
pooling_type
==
"max"
?
mkldnn
::
algorithm
::
pooling_max
...
...
@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
pool_bwd_pd
=
mkldnn
::
pooling_backward
::
primitive_desc
(
pool_bwd_desc
,
mkldnn_engine
,
*
pool_pd
);
// reorder between user_diff_dst and pool diff_dst if needed
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
user_diff_dst_memory
);
if
(
memory
::
primitive_desc
(
dst_memory
->
get_primitive_desc
())
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
dst_memory
.
get
()
->
get_primitive_desc
());
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
diff_src_memory
=
std
::
make_shared
<
memory
>
(
pool_bwd_pd
.
diff_src_primitive_desc
(),
in_x_grad_data
);
dev_ctx
.
SetBlob
(
key_pool_diff_src_mem_p
,
diff_src_memory
);
dev_ctx
.
SetBlob
(
key_pool_diff_dst_mem_p
,
diff_dst_memory
);
pool_bwd_p
=
std
::
make_shared
<
pooling_backward
>
(
pool_bwd_pd
,
*
(
pool_diff_dst_memory_p
.
get
()),
*
workspace_memory
,
*
(
pool_diff_src_memory_p
));
pool_bwd_pd
,
*
(
diff_dst_memory
.
get
()),
*
workspace_memory
,
*
(
diff_src_memory
));
dev_ctx
.
SetBlob
(
key_pool_bwd_p
,
pool_bwd_p
);
}
else
{
// Primitives already exist
auto
pool_diff_src_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
diff_src_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_diff_src_mem_p
));
PADDLE_ENFORCE
(
pool_diff_src_memory_p
!=
nullptr
,
PADDLE_ENFORCE
(
diff_src_memory
!=
nullptr
,
"Fail to find pooling src mem_p in device context"
);
auto
pool_diff_dst_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
diff_dst_memory
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_diff_dst_mem_p
));
PADDLE_ENFORCE
(
pool_diff_dst_memory_p
!=
nullptr
,
PADDLE_ENFORCE
(
diff_dst_memory
!=
nullptr
,
"Fail to find pooling dst mem_p in device context"
);
pool_diff_src_memory_p
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
in_x_grad_data
));
pool_diff_dst_memory_p
->
set_data_handle
(
const_cast
<
T
*>
(
out_grad_data
));
diff_src_memory
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
in_x_grad_data
));
diff_dst_memory
->
set_data_handle
(
const_cast
<
T
*>
(
out_grad_data
));
// reorder between user_diff_dst and pool diff_dst if needed
if
(
memory
::
primitive_desc
(
dst_memory
->
get_primitive_desc
())
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
dst_memory
.
get
()
->
get_primitive_desc
());
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
}
in_x_grad_format
=
(
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
;
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
*
(
pool_bwd_p
.
get
())};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
if
(
is_diff_dst_reordered
)
{
pipeline
.
push_back
(
reorder_diff_dst
);
}
pipeline
.
push_back
(
*
(
pool_bwd_p
.
get
()));
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
in_x_grad
->
set_layout
(
DataLayout
::
kMKLDNN
);
in_x_grad
->
set_format
(
in_x_grad_format
);
}
// Compute()
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_KERNEL
(
pool2d
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
paddle
::
operator
s
::
PoolMKLDNNOpKernel
<
float
>
);
op
s
::
PoolMKLDNNOpKernel
<
float
>
);
REGISTER_OP_KERNEL
(
pool2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
paddle
::
operator
s
::
PoolMKLDNNGradOpKernel
<
float
>
);
op
s
::
PoolMKLDNNGradOpKernel
<
float
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录