Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0aa01929
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0aa01929
编写于
5月 17, 2018
作者:
K
Krzysztof Binias
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add backward
上级
0cc25a40
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
95 addition
and
63 deletion
+95
-63
paddle/fluid/operators/activation_mkldnn_op.cc
paddle/fluid/operators/activation_mkldnn_op.cc
+90
-63
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+5
-0
未找到文件。
paddle/fluid/operators/activation_mkldnn_op.cc
浏览文件 @
0aa01929
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "mkldnn.hpp"
#include "mkldnn.hpp"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -25,9 +26,14 @@ using paddle::platform::MKLDNNDeviceContext;
...
@@ -25,9 +26,14 @@ using paddle::platform::MKLDNNDeviceContext;
namespace
{
namespace
{
std
::
string
gethash
(
const
mkldnn
::
memory
::
dims
&
operand_dims
,
std
::
string
gethash
(
const
mkldnn
::
memory
::
dims
&
operand_dims
,
const
mkldnn
::
algorithm
algorithm
)
{
const
mkldnn
::
algorithm
algorithm
)
{
return
std
::
string
(
std
::
to_string
(
operand_dims
[
0
])
+
"-"
+
auto
dim2str
=
[](
const
mkldnn
::
memory
::
dims
&
operand_dims
)
{
std
::
to_string
(
operand_dims
[
1
])
+
"-"
+
std
::
string
dstr
=
""
;
std
::
to_string
(
algorithm
));
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
};
return
dim2str
(
operand_dims
)
+
std
::
to_string
(
algorithm
);
}
}
template
<
typename
T
,
typename
ExecContext
>
template
<
typename
T
,
typename
ExecContext
>
...
@@ -44,7 +50,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -44,7 +50,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const
auto
*
src_data
=
src
->
template
data
<
T
>();
const
auto
*
src_data
=
src
->
template
data
<
T
>();
auto
*
dst
=
ctx
.
template
Output
<
Tensor
>(
"Out"
);
auto
*
dst
=
ctx
.
template
Output
<
Tensor
>(
"Out"
);
const
T
*
dst_data
=
dst
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
T
*
dst_data
=
dst
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
// get memory dim
// get memory dim
PADDLE_ENFORCE
(
src
->
dims
().
size
()
==
2
||
src
->
dims
().
size
()
==
4
,
PADDLE_ENFORCE
(
src
->
dims
().
size
()
==
2
||
src
->
dims
().
size
()
==
4
,
...
@@ -52,15 +58,14 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -52,15 +58,14 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
src
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
src
->
dims
());
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_src_mem"
;
const
std
::
string
key_src_mem
=
key
+
"@eltwise_
fwd_
src_mem"
;
const
std
::
string
key_dst_mem
=
key
+
"@eltwise_dst_mem"
;
const
std
::
string
key_dst_mem
=
key
+
"@eltwise_
fwd_
dst_mem"
;
const
std
::
string
key_fwd
=
key
+
"@eltwise_fwd"
;
const
std
::
string
key_fwd
=
key
+
"@eltwise_fwd"
;
std
::
shared_ptr
<
void
>
p_src_mem
=
dev_ctx
.
GetBlob
(
key_src_mem
);
auto
p_fwd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
std
::
shared_ptr
<
void
>
p_dst_mem
=
dev_ctx
.
GetBlob
(
key_dst_mem
);
dev_ctx
.
GetBlob
(
key_fwd
));
std
::
shared_ptr
<
void
>
p_fwd
=
dev_ctx
.
GetBlob
(
key_fwd
);
if
(
p_
src_mem
==
nullptr
||
p_dst_mem
==
nullptr
||
p_
fwd
==
nullptr
)
{
if
(
p_fwd
==
nullptr
)
{
// create memory description
// create memory description
auto
data_md
=
src_tz
.
size
()
==
2
auto
data_md
=
src_tz
.
size
()
==
2
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
...
@@ -69,35 +74,40 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -69,35 +74,40 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
mkldnn
::
memory
::
format
::
nchw
);
mkldnn
::
memory
::
format
::
nchw
);
// create memory primitives
// create memory primitives
p_src_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
auto
p_src_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
src_data
)));
static_cast
<
void
*>
(
const_cast
<
float
*>
(
src_data
))));
dev_ctx
.
SetBlob
(
key_src_mem
,
p_src_mem
);
dev_ctx
.
SetBlob
(
key_src_mem
,
p_src_mem
);
p_dst_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
auto
p_dst_mem
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
dst_data
)));
static_cast
<
void
*>
(
const_cast
<
float
*>
(
dst_data
))));
dev_ctx
.
SetBlob
(
key_dst_mem
,
p_dst_mem
);
dev_ctx
.
SetBlob
(
key_dst_mem
,
p_dst_mem
);
auto
fwd_desc
=
mkldnn
::
eltwise_forward
::
desc
(
auto
fwd_desc
=
mkldnn
::
eltwise_forward
::
desc
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
data_md
,
alpha
,
beta
);
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
data_md
,
alpha
,
beta
);
auto
p_fwd_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
auto
p_fwd_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
fwd_desc
,
mkldnn_engine
);
fwd_desc
,
mkldnn_engine
);
const
std
::
string
key_fwd_pd
=
key
+
"eltwise_fwd_pd"
;
dev_ctx
.
SetBlob
(
key_fwd_pd
,
p_fwd_pd
);
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
(
p_fwd_pd
.
get
()),
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_src_mem
.
get
())),
*
p_fwd_pd
,
*
(
p_src_mem
.
get
()),
*
(
p_dst_mem
.
get
()));
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_dst_mem
.
get
())));
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
}
else
{
}
else
{
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
p_src_mem
)
->
set_data_handle
(
// primitives already exist
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
src_data
)));
auto
p_src_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
p_dst_mem
)
->
set_data_handle
(
PADDLE_ENFORCE
(
p_src_mem
!=
nullptr
,
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
dst_data
)));
"Fail to find eltwise p_src_mem in device context."
);
auto
p_dst_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_dst_mem
));
PADDLE_ENFORCE
(
p_dst_mem
!=
nullptr
,
"Fail to find eltwise p_src_mem in device context."
);
p_src_mem
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
src_data
));
p_dst_mem
->
set_data_handle
(
dst_data
);
}
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
*
(
p_fwd
.
get
())};
*
(
static_cast
<
mkldnn
::
eltwise_forward
::
primitive
*>
(
p_fwd
.
get
()))};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
}
...
@@ -121,10 +131,15 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -121,10 +131,15 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
out
->
dims
());
std
::
vector
<
int
>
src_tz
=
framework
::
vectorize2int
(
out
->
dims
());
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_src_mem"
;
const
std
::
string
key_dst_mem
=
key
+
"@eltwise_dst_mem"
;
const
std
::
string
key_fwd
=
key
+
"@eltwise_fwd"
;
const
std
::
string
key_diff_src_mem
=
key
+
"@eltwise_diff_src_mem"
;
const
std
::
string
key_diff_dst_mem
=
key
+
"@eltwise_diff_dst_mem"
;
const
std
::
string
key_grad
=
key
+
"@eltwise_grad"
;
auto
p_grad
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive
>
(
dev_ctx
.
GetBlob
(
key_grad
));
if
(
p_grad
==
nullptr
)
{
// create memory description
// create memory description
auto
data_md
=
src_tz
.
size
()
==
2
auto
data_md
=
src_tz
.
size
()
==
2
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
?
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
...
@@ -132,36 +147,48 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
...
@@ -132,36 +147,48 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
:
platform
::
MKLDNNMemDesc
(
src_tz
,
mkldnn
::
memory
::
f32
,
mkldnn
::
memory
::
format
::
nchw
);
mkldnn
::
memory
::
format
::
nchw
);
// retrieve source memory from device context
const
std
::
shared_ptr
<
void
>
src_mem
=
dev_ctx
.
GetBlob
(
key_src_mem
);
auto
*
p_src_mem
=
static_cast
<
mkldnn
::
memory
*>
(
src_mem
.
get
());
// create memory primitives
// create memory primitives
auto
diff_src_memory
=
std
::
shared_ptr
<
void
>
p_diff_src_mem
=
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
static_cast
<
void
*>
(
const_cast
<
float
*>
(
diff_src
)));
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
diff_src
)));
auto
diff_dst_memory
=
dev_ctx
.
SetBlob
(
key_diff_src_mem
,
p_diff_src_mem
);
mkldnn
::
memory
({
data_md
,
mkldnn_engine
},
std
::
shared_ptr
<
void
>
p_diff_dst_mem
=
static_cast
<
void
*>
(
const_cast
<
float
*>
(
diff_dst
)));
std
::
make_shared
<
mkldnn
::
memory
>
(
mkldnn
::
memory
(
{
data_md
,
mkldnn_engine
},
platform
::
to_void_cast
(
diff_dst
)));
auto
backward_desc
=
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
p_diff_dst_mem
);
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
data_md
,
data_md
,
alpha
,
beta
);
auto
bwd_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
data_md
,
data_md
,
// retrieve eltwise primitive desc from device context
alpha
,
beta
);
const
std
::
shared_ptr
<
void
>
forward_pd
=
dev_ctx
.
GetBlob
(
key_fwd
);
PADDLE_ENFORCE
(
forward_pd
!=
nullptr
,
const
std
::
string
key_fwd_pd
=
key
+
"eltwise_fwd_pd"
;
"Fail to find eltwise_pd in device context"
);
auto
*
p_fwd_pd
=
static_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
*>
(
auto
*
p_forward_pd
=
dev_ctx
.
GetBlob
(
key_fwd_pd
).
get
());
static_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
*>
(
forward_pd
.
get
());
auto
eltwise_bwd_prim_desc
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
auto
eltwise_bwd_prim_desc
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
mkldnn_engine
,
*
p_forward_pd
);
bwd_desc
,
mkldnn_engine
,
*
p_fwd_pd
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
const
std
::
shared_ptr
<
void
>
p_src_mem
=
dev_ctx
.
GetBlob
(
key_src_mem
);
auto
eltwise_bwd
=
mkldnn
::
eltwise_backward
(
eltwise_bwd_prim_desc
,
*
p_src_mem
,
p_grad
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
diff_dst_memory
,
diff_src_memory
);
eltwise_bwd_prim_desc
,
*
static_cast
<
mkldnn
::
memory
*>
(
p_src_mem
.
get
()),
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_diff_dst_mem
.
get
())),
*
(
static_cast
<
mkldnn
::
memory
*>
(
p_diff_src_mem
.
get
())));
}
else
{
// primitives already exist
auto
p_diff_src_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_src_mem
));
auto
p_diff_dst_mem
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_dst_mem
));
p_diff_src_mem
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_src
));
p_diff_dst_mem
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_dst
));
}
// push primitive to stream and wait until it's executed
// push primitive to stream and wait until it's executed
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
eltwise_bwd
};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
=
{
*
(
p_grad
.
get
())
};
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
}
}
// anonymous namespace
}
// anonymous namespace
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
0aa01929
...
@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
...
@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
return
static_cast
<
void
*>
(
const_cast
<
Type
*>
(
t
));
return
static_cast
<
void
*>
(
const_cast
<
Type
*>
(
t
));
}
}
template
<
typename
Type
>
void
*
to_void_reinterpret_cast
(
const
Type
*
t
)
{
return
reinterpret_cast
<
void
*>
(
const_cast
<
Type
*>
(
t
));
}
template
<
class
Type
>
template
<
class
Type
>
using
tf_desc
=
typename
Type
::
desc
;
using
tf_desc
=
typename
Type
::
desc
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录