Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0cc25a40
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0cc25a40
编写于
5月 16, 2018
作者:
K
Krzysztof Binias
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Realloc for forward
上级
a76d0dd4
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
64 addition
and
38 deletion
+64
-38
paddle/fluid/operators/activation_mkldnn_op.cc
paddle/fluid/operators/activation_mkldnn_op.cc
+64
-38
未找到文件。
paddle/fluid/operators/activation_mkldnn_op.cc
浏览文件 @
0cc25a40
...
@@ -23,6 +23,13 @@ using paddle::framework::Tensor;
...
@@ -23,6 +23,13 @@ using paddle::framework::Tensor;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
namespace
{
namespace
{
std
::
string
gethash
(
const
mkldnn
::
memory
::
dims
&
operand_dims
,
const
mkldnn
::
algorithm
algorithm
)
{
return
std
::
string
(
std
::
to_string
(
operand_dims
[
0
])
+
"-"
+
std
::
to_string
(
operand_dims
[
1
])
+
"-"
+
std
::
to_string
(
algorithm
));
}
template
<
typename
T
,
typename
ExecContext
>
template
<
typename
T
,
typename
ExecContext
>
void
eltwise_forward
(
const
ExecContext
&
ctx
,
mkldnn
::
algorithm
algorithm
,
void
eltwise_forward
(
const
ExecContext
&
ctx
,
mkldnn
::
algorithm
algorithm
,
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
const
T
alpha
=
0
,
const
T
beta
=
0
)
{
...
@@ -44,6 +51,16 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -44,6 +51,16 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
"Input dim must be with 2 or 4"
);
"Input dim must be with 2 or 4"
);
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
src
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
src
->
dims
());
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_src_mem"
;
const
std
::
string
key_dst_mem
=
key
+
"@eltwise_dst_mem"
;
const
std
::
string
key_fwd
=
key
+
"@eltwise_fwd"
;
std
::
shared_ptr
<
void
>
p_src_mem
=
dev_ctx
.
GetBlob
(
key_src_mem
);
std
::
shared_ptr
<
void
>
p_dst_mem
=
dev_ctx
.
GetBlob
(
key_dst_mem
);
std
::
shared_ptr
<
void
>
p_fwd
=
dev_ctx
.
GetBlob
(
key_fwd
);
if
(
p_src_mem
==
nullptr
||
p_dst_mem
==
nullptr
||
p_fwd
==
nullptr
)
{
// create memory description
// create memory description
auto
data_md
=
src_tz
.
size
()
==
2
auto
data_md
=
src_tz
.
size
()
==
2
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
...
@@ -52,29 +69,35 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -52,29 +69,35 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
mkldnn
::
memory
::
format
::
nchw
);
mkldnn
::
memory
::
format
::
nchw
);
// create memory primitives
// create memory primitives
auto
src_memory
=
std
::
make_shared
<
mkldnn
::
memory
>
(
p_src_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
float
*>
(
src_data
))));
static_cast
<
void
*>
(
const_cast
<
float
*>
(
src_data
))));
// save source memory to device context to be referred in backward path
dev_ctx
.
SetBlob
(
key_src_mem
,
p_src_mem
);
dev_ctx
.
SetBlob
(
"InputX@eltwise_pd"
,
src_memory
);
auto
dst_memory
=
p_dst_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
static_cast
<
void
*>
(
const_cast
<
float
*>
(
dst_data
)));
static_cast
<
void
*>
(
const_cast
<
float
*>
(
dst_data
))));
dev_ctx
.
SetBlob
(
key_dst_mem
,
p_dst_mem
);
auto
forwar
d_desc
=
mkldnn
::
eltwise_forward
::
desc
(
auto
fw
d_desc
=
mkldnn
::
eltwise_forward
::
desc
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
data_md
,
alpha
,
beta
);
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
data_md
,
alpha
,
beta
);
auto
p_fwd_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
// save prim desc into global device context to be referred in backward path
fwd_desc
,
mkldnn_engine
);
const
std
::
string
key
=
ctx
.
op
().
Output
(
"Out"
);
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
const
std
::
string
key_eltwise_pd
=
key
+
"@eltwise_pd"
;
*
(
p_fwd_pd
.
get
()),
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_src_mem
.
get
())),
auto
forward_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_dst_mem
.
get
())));
forward_desc
,
mkldnn_engine
);
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
dev_ctx
.
SetBlob
(
key_eltwise_pd
,
forward_pd
);
}
else
{
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
p_src_mem
)
->
set_data_handle
(
auto
eltwise
=
mkldnn
::
eltwise_forward
(
*
forward_pd
,
*
src_memory
,
dst_memory
);
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
src_data
)));
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
p_dst_mem
)
->
set_data_handle
(
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
dst_data
)));
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
eltwise
};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
*
(
static_cast
<
mkldnn
::
eltwise_forward
::
primitive
*>
(
p_fwd
.
get
()))};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
}
...
@@ -85,7 +108,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -85,7 +108,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
&
mkldnn_engine
=
dev_ctx
.
GetEngine
();
// get buffers
// get buffers
const
auto
*
x
=
ctx
.
template
Input
<
Tensor
>(
"Out"
);
const
auto
*
out
=
ctx
.
template
Input
<
Tensor
>(
"Out"
);
auto
*
dout
=
ctx
.
template
Input
<
Tensor
>(
framework
::
GradVarName
(
"Out"
));
auto
*
dout
=
ctx
.
template
Input
<
Tensor
>(
framework
::
GradVarName
(
"Out"
));
const
auto
*
diff_dst
=
dout
->
template
data
<
T
>();
const
auto
*
diff_dst
=
dout
->
template
data
<
T
>();
...
@@ -95,7 +118,12 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -95,7 +118,12 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const
T
*
diff_src
=
dx
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
const
T
*
diff_src
=
dx
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
// get memory dim
// get memory dim
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
x
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
out
->
dims
());
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_src_mem"
;
const
std
::
string
key_dst_mem
=
key
+
"@eltwise_dst_mem"
;
const
std
::
string
key_fwd
=
key
+
"@eltwise_fwd"
;
// create memory description
// create memory description
auto
data_md
=
src_tz
.
size
()
==
2
auto
data_md
=
src_tz
.
size
()
==
2
...
@@ -105,8 +133,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -105,8 +133,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
mkldnn
::
memory
::
format
::
nchw
);
mkldnn
::
memory
::
format
::
nchw
);
// retrieve source memory from device context
// retrieve source memory from device context
const
std
::
shared_ptr
<
void
>
src_mem
ory
=
dev_ctx
.
GetBlob
(
"InputX@eltwise_pd"
);
const
std
::
shared_ptr
<
void
>
src_mem
=
dev_ctx
.
GetBlob
(
key_src_mem
);
auto
*
p_src_mem
ory
=
static_cast
<
mkldnn
::
memory
*>
(
src_memory
.
get
());
auto
*
p_src_mem
=
static_cast
<
mkldnn
::
memory
*>
(
src_mem
.
get
());
// create memory primitives
// create memory primitives
auto
diff_src_memory
=
auto
diff_src_memory
=
...
@@ -120,9 +148,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -120,9 +148,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
data_md
,
data_md
,
alpha
,
beta
);
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
data_md
,
data_md
,
alpha
,
beta
);
// retrieve eltwise primitive desc from device context
// retrieve eltwise primitive desc from device context
const
std
::
string
key
=
ctx
.
op
().
Input
(
"Out"
);
const
std
::
shared_ptr
<
void
>
forward_pd
=
dev_ctx
.
GetBlob
(
key_fwd
);
const
std
::
string
key_eltwise_pd
=
key
+
"@eltwise_pd"
;
const
std
::
shared_ptr
<
void
>
forward_pd
=
dev_ctx
.
GetBlob
(
key_eltwise_pd
);
PADDLE_ENFORCE
(
forward_pd
!=
nullptr
,
PADDLE_ENFORCE
(
forward_pd
!=
nullptr
,
"Fail to find eltwise_pd in device context"
);
"Fail to find eltwise_pd in device context"
);
auto
*
p_forward_pd
=
auto
*
p_forward_pd
=
...
@@ -131,8 +157,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -131,8 +157,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
auto
eltwise_bwd_prim_desc
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
auto
eltwise_bwd_prim_desc
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
mkldnn_engine
,
*
p_forward_pd
);
backward_desc
,
mkldnn_engine
,
*
p_forward_pd
);
auto
eltwise_bwd
=
mkldnn
::
eltwise_backward
(
auto
eltwise_bwd
=
mkldnn
::
eltwise_backward
(
eltwise_bwd_prim_desc
,
*
p_src_mem
,
eltwise_bwd_prim_desc
,
*
p_src_memory
,
diff_dst_memory
,
diff_src_memory
);
diff_dst_memory
,
diff_src_memory
);
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
eltwise_bwd
};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
eltwise_bwd
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录