Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8c7d2e29
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8c7d2e29
编写于
5月 14, 2018
作者:
T
Tao Luo
提交者:
GitHub
5月 14, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10576 from jczaja/prv-reuse-mkldnn-softmax-primitives
Reusing of softmax mkldnn primitives
上级
1c4a399d
7bf00c3a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
54 addition
and
19 deletion
+54
-19
paddle/fluid/operators/softmax_mkldnn_op.cc
paddle/fluid/operators/softmax_mkldnn_op.cc
+54
-19
未找到文件。
paddle/fluid/operators/softmax_mkldnn_op.cc
浏览文件 @
8c7d2e29
...
...
@@ -53,25 +53,60 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
"Softmax input and output dimensions should match"
);
// Same memory descriptor to be used for input and output
memory
::
dims
softmax_tz
=
{
src_tz
[
0
],
src_tz
[
1
]};
// Currently only supports NC data format
// TODO(jczaja-intel): support more formats
auto
softmax_md
=
MKLDNNMemDesc
({
softmax_tz
},
memory
::
f32
,
memory
::
format
::
nc
);
// Normalization is made after innermost dimension eg. C out of NC
auto
softmax_desc
=
softmax_forward
::
desc
(
prop_kind
::
forward_scoring
,
softmax_md
,
1
/*dim: C*/
);
// create memory primitives
auto
softmax_src_memory
=
memory
({
softmax_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
auto
softmax_dst_memory
=
memory
({
softmax_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
output_data
)));
auto
softmax_prim_desc
=
softmax_forward
::
primitive_desc
(
softmax_desc
,
mkldnn_engine
);
auto
softmax
=
softmax_forward
(
softmax_prim_desc
,
softmax_src_memory
,
softmax_dst_memory
);
std
::
vector
<
primitive
>
pipeline
{
softmax
};
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Each MKLDNN operator may have diffrent hashing function
auto
gethash
=
[](
memory
::
dims
&
operand_dims
)
{
return
std
::
string
(
std
::
to_string
(
operand_dims
[
0
])
+
"-"
+
std
::
to_string
(
operand_dims
[
1
]));
};
const
std
::
string
key
=
gethash
(
softmax_tz
);
const
std
::
string
key_softmax_p
=
key
+
"@softmax_p"
;
const
std
::
string
key_softmax_src_mem_p
=
key
+
"@softmax_src_mem_p"
;
const
std
::
string
key_softmax_dst_mem_p
=
key
+
"@softmax_dst_mem_p"
;
std
::
shared_ptr
<
void
>
softmax_p
=
dev_ctx
.
GetBlob
(
key_softmax_p
);
if
(
softmax_p
==
nullptr
)
{
// Currently only NC data format is supported
auto
softmax_md
=
MKLDNNMemDesc
({
softmax_tz
},
memory
::
f32
,
memory
::
format
::
nc
);
// Normalization is made after innermost dimension eg. C out of NC
auto
softmax_desc
=
softmax_forward
::
desc
(
prop_kind
::
forward_scoring
,
softmax_md
,
1
/*dim: C*/
);
// create memory primitives
auto
softmax_src_memory_p
=
std
::
make_shared
<
memory
>
(
memory
::
primitive_desc
{
softmax_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
dev_ctx
.
SetBlob
(
key_softmax_src_mem_p
,
softmax_src_memory_p
);
auto
softmax_dst_memory_p
=
std
::
make_shared
<
memory
>
(
memory
::
primitive_desc
{
softmax_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
output_data
));
dev_ctx
.
SetBlob
(
key_softmax_dst_mem_p
,
softmax_dst_memory_p
);
auto
softmax_forward_pd
=
std
::
make_shared
<
softmax_forward
::
primitive_desc
>
(
softmax_desc
,
mkldnn_engine
);
softmax_p
=
std
::
make_shared
<
softmax_forward
>
(
*
(
softmax_forward_pd
.
get
()),
*
(
static_cast
<
memory
*>
(
softmax_src_memory_p
.
get
())),
*
(
static_cast
<
memory
*>
(
softmax_dst_memory_p
.
get
())));
dev_ctx
.
SetBlob
(
key_softmax_p
,
softmax_p
);
}
else
{
// Primitives already exist
auto
src_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_softmax_src_mem_p
));
PADDLE_ENFORCE
(
src_memory_p
!=
nullptr
,
"Fail to find softmax src mem_p in device context"
);
auto
dst_memory_p
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key_softmax_dst_mem_p
));
PADDLE_ENFORCE
(
dst_memory_p
!=
nullptr
,
"Fail to find softmax dst mem_p in device context"
);
src_memory_p
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
input_data
)));
dst_memory_p
->
set_data_handle
(
output_data
);
}
std
::
vector
<
primitive
>
pipeline
{
*
(
static_cast
<
softmax_forward
::
primitive
*>
(
softmax_p
.
get
()))};
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录