Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
85b6bb58
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
85b6bb58
编写于
5月 21, 2018
作者:
T
Tao Luo
提交者:
GitHub
5月 21, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10747 from jczaja/prv-mkldnn-pooling-reuse
Reuse of pooling mkldnn primitives
上级
e526dd58
5f133305
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
163 addition
and
76 deletion
+163
-76
paddle/fluid/operators/pool_mkldnn_op.cc
paddle/fluid/operators/pool_mkldnn_op.cc
+153
-76
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+10
-0
未找到文件。
paddle/fluid/operators/pool_mkldnn_op.cc
浏览文件 @
85b6bb58
...
...
@@ -18,6 +18,26 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
mkldnn
::
memory
;
// Note: paddle has also "memory" namespace
using
mkldnn
::
pooling_forward
;
using
mkldnn
::
pooling_backward
;
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static
std
::
string
gethash
(
memory
::
dims
&
input_dims
,
std
::
string
&
pooling_type
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
std
::
string
suffix
)
{
auto
dims2str
=
[](
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
};
return
dims2str
(
input_dims
)
+
dims2str
(
ksize
)
+
dims2str
(
strides
)
+
dims2str
(
paddings
)
+
pooling_type
+
suffix
;
}
template
<
typename
T
>
class
PoolMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -34,10 +54,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when saving info into device context
const
std
::
string
key
=
ctx
.
op
().
Output
(
"Out"
);
const
std
::
string
key_pool_pd
=
key
+
"@pool_pd"
;
const
std
::
string
key_pool_workspace_memory
=
key
+
"@pool_workspace_memory"
;
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
...
...
@@ -63,13 +79,28 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
src_tz
=
paddle
::
framework
::
vectorize2int
(
input
->
dims
());
std
::
vector
<
int
>
dst_tz
=
paddle
::
framework
::
vectorize2int
(
output
->
dims
());
const
std
::
string
key
=
gethash
(
src_tz
,
pooling_type
,
ksize
,
strides
,
paddings
,
ctx
.
op
().
Output
(
"Out"
));
const
std
::
string
key_pool_p
=
key
+
"@pool_p"
;
const
std
::
string
key_pool_pd
=
key
+
"@pool_pd"
;
const
std
::
string
key_pool_src_mem_p
=
key
+
"@pool_src_mem_p"
;
const
std
::
string
key_pool_dst_mem_p
=
key
+
"@pool_dst_mem_p"
;
const
std
::
string
key_pool_workspace_memory
=
key
+
"@pool_workspace_memory"
;
auto
pool_p
=
std
::
static_pointer_cast
<
pooling_forward
>
(
dev_ctx
.
GetBlob
(
key_pool_p
));
if
(
pool_p
==
nullptr
)
{
// TODO(pzelazko-intel): support more formats
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
);
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
mkldnn
::
memory
::
f32
,
auto
dst_md
=
platform
::
MKLDNNMemDesc
(
dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
);
std
::
shared_ptr
<
mkldnn
::
pooling_forward
::
primitive_desc
>
pool_pd
=
std
::
shared_ptr
<
pooling_forward
::
primitive_desc
>
pool_pd
=
CreatePrimitiveDesc
(
src_md
,
dst_md
,
strides
,
paddings
,
ksize
,
pooling_type
,
mkldnn_engine
);
...
...
@@ -82,18 +113,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_workspace_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_pool_workspace_memory
,
workspace_memory
);
auto
src_memory
=
mkldnn
::
memory
(
{
src_md
,
mkldnn_engine
},
auto
pool_src_memory_p
=
std
::
make_shared
<
memory
>
(
memory
::
primitive_desc
{
src_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
auto
dst_memory
=
mkldnn
::
memory
({
dst_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
output_data
)));
dev_ctx
.
SetBlob
(
key_pool_src_mem_p
,
pool_src_memory_p
);
auto
pool_prim
=
mkldnn
::
pooling_forward
(
*
pool_pd
,
src_memory
,
dst_memory
,
auto
pool_dst_memory_p
=
std
::
make_shared
<
memory
>
(
memory
::
primitive_desc
{
dst_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
output_data
));
dev_ctx
.
SetBlob
(
key_pool_dst_mem_p
,
pool_dst_memory_p
);
pool_p
=
std
::
make_shared
<
pooling_forward
>
(
*
pool_pd
,
*
(
pool_src_memory_p
.
get
()),
*
(
pool_dst_memory_p
.
get
()),
*
workspace_memory
);
dev_ctx
.
SetBlob
(
key_pool_p
,
pool_p
);
}
else
{
// Primitives already exist
auto
pool_src_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_src_mem_p
));
PADDLE_ENFORCE
(
pool_src_memory_p
!=
nullptr
,
"Fail to find pooling src mem_p in device context"
);
auto
pool_dst_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_dst_mem_p
));
PADDLE_ENFORCE
(
pool_dst_memory_p
!=
nullptr
,
"Fail to find pooling dst mem_p in device context"
);
pool_src_memory_p
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
pool_dst_memory_p
->
set_data_handle
(
output_data
);
}
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
pool_prim
};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
*
(
pool_p
.
get
())
};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
...
...
@@ -120,8 +170,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn
::
memory
::
primitive_desc
workspace_md
=
pooling_type
==
"max"
?
pool_pd
->
workspace_primitive_desc
()
:
mkldnn
::
memory
::
primitive_desc
(
{{},
mkldnn
::
memory
::
f32
,
mkldnn
::
memory
::
format
::
nchw
},
:
mkldnn
::
memory
::
primitive_desc
({{},
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
},
engine
);
auto
p_workspace_memory
=
new
mkldnn
::
memory
(
workspace_md
);
...
...
@@ -140,13 +191,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
Tensor
*
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
Tensor
*
in_x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const
std
::
string
key
=
ctx
.
op
().
Input
(
"Out"
);
const
std
::
string
key_pool_pd
=
key
+
"@pool_pd"
;
const
std
::
string
key_pool_workspace_memory
=
key
+
"@pool_workspace_memory"
;
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
...
...
@@ -171,11 +215,26 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std
::
vector
<
int
>
diff_dst_tz
=
paddle
::
framework
::
vectorize2int
(
out_grad
->
dims
());
auto
diff_src_md
=
platform
::
MKLDNNMemDesc
(
diff_src_tz
,
mkldnn
::
memory
::
f32
,
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const
std
::
string
key
=
gethash
(
diff_src_tz
,
pooling_type
,
ksize
,
strides
,
paddings
,
ctx
.
op
().
Input
(
"Out"
));
const
std
::
string
key_pool_bwd_p
=
key
+
"@pool_bwd_p"
;
const
std
::
string
key_pool_diff_src_mem_p
=
key
+
"@pool_diff_src_mem_p"
;
const
std
::
string
key_pool_diff_dst_mem_p
=
key
+
"@pool_diff_dst_mem_p"
;
const
std
::
string
key_pool_pd
=
key
+
"@pool_pd"
;
const
std
::
string
key_pool_workspace_memory
=
key
+
"@pool_workspace_memory"
;
auto
pool_bwd_p
=
std
::
static_pointer_cast
<
pooling_backward
>
(
dev_ctx
.
GetBlob
(
key_pool_bwd_p
));
if
(
pool_bwd_p
==
nullptr
)
{
auto
diff_src_md
=
platform
::
MKLDNNMemDesc
(
diff_src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
);
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
mkldnn
::
memory
::
f32
,
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
mkldnn
::
memory
::
format
::
nchw
);
// Retrieve pool_pd/pool_workspace_memory from device context
auto
pool_pd
=
std
::
static_pointer_cast
<
mkldnn
::
pooling_forward
::
primitive_desc
>
(
...
...
@@ -188,6 +247,15 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE
(
workspace_memory
!=
nullptr
,
"Fail to find workspace_memory in device context"
);
auto
pool_diff_src_memory_p
=
std
::
make_shared
<
memory
>
(
memory
(
{
diff_src_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
in_x_grad_data
)));
dev_ctx
.
SetBlob
(
key_pool_diff_src_mem_p
,
pool_diff_src_memory_p
);
auto
pool_diff_dst_memory_p
=
std
::
make_shared
<
memory
>
(
memory
({
diff_dst_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
out_grad_data
))));
dev_ctx
.
SetBlob
(
key_pool_diff_dst_mem_p
,
pool_diff_dst_memory_p
);
auto
pool_bwd_desc
=
mkldnn
::
pooling_backward
::
desc
(
pooling_type
==
"max"
?
mkldnn
::
algorithm
::
pooling_max
:
mkldnn
::
algorithm
::
pooling_avg
,
...
...
@@ -196,18 +264,27 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
pool_bwd_pd
=
mkldnn
::
pooling_backward
::
primitive_desc
(
pool_bwd_desc
,
mkldnn_engine
,
*
pool_pd
);
auto
diff_src_memory
=
mkldnn
::
memory
({
diff_src_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
in_x_grad_data
)));
auto
diff_dst_memory
=
mkldnn
::
memory
({
diff_dst_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
out_grad_data
)));
auto
bwd_prim
=
mkldnn
::
pooling_backward
(
pool_bwd_pd
,
diff_dst_memory
,
*
workspace_memory
,
diff_src_memory
);
pool_bwd_p
=
std
::
make_shared
<
pooling_backward
>
(
pool_bwd_pd
,
*
(
pool_diff_dst_memory_p
.
get
()),
*
workspace_memory
,
*
(
pool_diff_src_memory_p
));
dev_ctx
.
SetBlob
(
key_pool_bwd_p
,
pool_bwd_p
);
}
else
{
// Primitives already exist
auto
pool_diff_src_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_diff_src_mem_p
));
PADDLE_ENFORCE
(
pool_diff_src_memory_p
!=
nullptr
,
"Fail to find pooling src mem_p in device context"
);
auto
pool_diff_dst_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_pool_diff_dst_mem_p
));
PADDLE_ENFORCE
(
pool_diff_dst_memory_p
!=
nullptr
,
"Fail to find pooling dst mem_p in device context"
);
pool_diff_src_memory_p
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
in_x_grad_data
));
pool_diff_dst_memory_p
->
set_data_handle
(
const_cast
<
T
*>
(
out_grad_data
));
}
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
bwd_prim
};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
{
*
(
pool_bwd_p
.
get
())
};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
// Compute()
};
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
85b6bb58
...
...
@@ -71,5 +71,15 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
return
use_mkldnn
&&
platform
::
is_cpu_place
(
ctx
.
GetPlace
());
}
template
<
typename
Type
>
mkldnn
::
memory
::
data_type
MKLDNNGetDataType
()
{
return
mkldnn
::
memory
::
data_undef
;
}
template
<
>
inline
mkldnn
::
memory
::
data_type
MKLDNNGetDataType
<
float
>
()
{
return
mkldnn
::
memory
::
f32
;
}
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录