Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8869d7f7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8869d7f7
编写于
7月 10, 2019
作者:
J
Jacek Czaja
提交者:
Tao Luo
7月 10, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Activations MKLDNN ops refactoring (#18191)
上级
b6d5c74f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
178 addition
and
132 deletion
+178
-132
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+63
-132
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+115
-0
未找到文件。
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
浏览文件 @
8869d7f7
...
...
@@ -13,7 +13,7 @@
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_
helper
.h"
#include "paddle/fluid/platform/mkldnn_
reuse
.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -99,20 +99,21 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
auto
src_format
=
src_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
x
->
format
();
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_layout"
;
const
std
::
string
key_with_layout
=
key
+
std
::
to_string
(
src_format
);
const
std
::
string
key_src_mem
=
key_with_layout
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_dst_mem
=
key_with_layout
+
"@eltwise_fwd_dst_mem"
;
const
std
::
string
key_fwd
=
key_with_layout
+
"@eltwise_fwd"
;
const
std
::
string
key_fwd_pd
=
key_with_layout
+
"@eltwise_fwd_pd"
;
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
// TODO(jczaja): When adding leaky-relu , swish , elu make sure to extend key
// with alpha, beta
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
src_tz
,
std
::
to_string
(
algorithm
)
+
ctx
.
op
().
Output
(
"Out"
));
// TODO(jczaja): Make it Thread safe
// save input data and layout to be referred in backward path
const
std
::
string
key_src_data
=
key
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
"@eltwise_fwd_src_layout"
;
// Just in case some int8 models are run interchangebly
// with float models then format maybe diffrent
key
+=
std
::
to_string
(
src_format
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
auto
p_src_data
=
std
::
make_shared
<
const
T
*>
(
x_data
);
auto
p_src_layout
=
std
::
make_shared
<
memory
::
format
>
(
src_format
);
if
(
!
is_test
)
{
...
...
@@ -120,65 +121,34 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
dev_ctx
.
SetBlob
(
key_src_layout
,
p_src_layout
);
}
auto
p_fwd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
dev_ctx
.
GetBlob
(
key_fwd
));
std
::
shared_ptr
<
memory
>
dst_memory
;
if
(
p_fwd
==
nullptr
)
{
// create mkldnn memory for input X
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
({
src_md
,
mkldnn_engine
},
to_void_cast
(
x_data
)));
// save src_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_src_mem
,
src_memory
);
// create primitive descriptor for activation forward and save it
auto
mkldnn_forward_prop_kind
=
is_test
?
mkldnn
::
prop_kind
::
forward_inference
:
mkldnn
::
prop_kind
::
forward_training
;
auto
forward_desc
=
mkldnn
::
eltwise_forward
::
desc
(
mkldnn_forward_prop_kind
,
algorithm
,
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
auto
forward_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
forward_desc
,
mkldnn_engine
);
// save prim desc into global device context to be referred in backward path
if
(
!
is_test
)
dev_ctx
.
SetBlob
(
key_fwd_pd
,
forward_pd
);
// create mkldnn memory for output y
dst_memory
=
std
::
make_shared
<
memory
>
(
forward_pd
->
dst_primitive_desc
(),
y_data
);
dev_ctx
.
SetBlob
(
key_dst_mem
,
dst_memory
);
// create activation primitive
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
forward_pd
,
*
src_memory
,
*
dst_memory
);
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
}
else
{
// primitives already exist
auto
src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
"Fail to find eltwise src_memory in device context."
);
dst_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_dst_mem
));
PADDLE_ENFORCE
(
dst_memory
!=
nullptr
,
"Fail to find eltwise dst_memory in device context."
);
src_memory
->
set_data_handle
(
platform
::
to_void_cast
(
x_data
));
dst_memory
->
set_data_handle
(
y_data
);
platform
::
ActivationMKLDNNHandler
handler
(
dev_ctx
,
mkldnn_engine
,
key
);
auto
md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
auto
activation_pd
=
handler
.
AcquireActivationPrimitiveDescriptor
(
is_test
?
mkldnn
::
prop_kind
::
forward_inference
:
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
md
,
alpha
,
beta
);
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
md
,
to_void_cast
<
T
>
(
x_data
));
// jczaja: Workaround, src_memory_p is needed in BWD so it has
// to be accessible under key not dependant on TID
if
(
!
is_test
)
{
dev_ctx
.
SetBlob
(
key_src_mem
,
src_memory_p
);
}
auto
dst_memory_p
=
handler
.
AcquireDstMemoryFromPrimitive
(
to_void_cast
<
T
>
(
y_data
));
auto
activation_p
=
handler
.
AcquireActivation
(
dst_memory_p
,
src_memory_p
);
// push primitive to stream and wait until it's executed
std
::
vector
<
primitive
>
pipeline
;
pipeline
.
push_back
(
*
p_fwd
);
pipeline
.
push_back
(
*
activation_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
));
y
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
_p
));
}
template
<
typename
T
>
...
...
@@ -199,90 +169,51 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
auto
diff_y_format
=
diff_dst_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
diff_y
->
format
();
const
std
::
string
key
=
gethash
(
diff_dst_tz
,
algorithm
);
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_layout"
;
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_y_format
);
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
diff_dst_tz
,
std
::
to_string
(
algorithm
)
+
ctx
.
op
().
Input
(
"Out"
));
const
std
::
string
key_src_data
=
key
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
"@eltwise_fwd_src_layout"
;
// Get Data from FWD op
const
auto
p_src_layout
=
std
::
static_pointer_cast
<
memory
::
format
>
(
dev_ctx
.
GetBlob
(
key_src_layout
));
const
std
::
string
key_src_mem
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_fwd_pd
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_pd"
;
const
std
::
string
key_with_layouts
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"-"
+
std
::
to_string
(
diff_y_format
);
const
std
::
string
key_diff_src_mem
=
key_with_layouts
+
"@eltwise_diff_src_mem"
;
const
std
::
string
key_diff_dst_mem
=
key_with_layouts
+
"@eltwise_diff_dst_mem"
;
const
std
::
string
key_grad
=
key_with_layouts
+
"@eltwise_grad"
;
const
auto
p_src_data
=
std
::
static_pointer_cast
<
T
*>
(
dev_ctx
.
GetBlob
(
key_src_data
));
key
+=
std
::
to_string
(
*
p_src_layout
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
auto
src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
"Fail to find src_memory in device context"
);
src_memory
->
set_data_handle
(
*
p_src_data
);
std
::
shared_ptr
<
memory
>
diff_src_memory
;
auto
p_grad
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
>
(
dev_ctx
.
GetBlob
(
key_grad
));
if
(
p_grad
==
nullptr
)
{
// create mkldnn memory for input diff_y
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_y_format
);
auto
diff_dst_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
({
diff_dst_md
,
mkldnn_engine
},
to_void_cast
(
diff_y_data
)));
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
diff_dst_memory
);
// retrieve eltwise primitive desc from device context
auto
forward_pd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_fwd_pd
));
PADDLE_ENFORCE
(
forward_pd
!=
nullptr
,
"Fail to find eltwise_fwd_pd in device context"
);
// ceate primitive descriptor for activation backward
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
diff_dst_memory
->
get_primitive_desc
().
desc
(),
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
auto
backward_pd
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
mkldnn_engine
,
*
forward_pd
);
// create mkldnn memory for output diff_src
diff_src_memory
=
std
::
make_shared
<
memory
>
(
backward_pd
.
diff_src_primitive_desc
(),
diff_x_data
);
dev_ctx
.
SetBlob
(
key_diff_src_mem
,
diff_src_memory
);
// create activation backward primitive
p_grad
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
backward_pd
,
*
src_memory
,
*
diff_dst_memory
,
*
diff_src_memory
);
dev_ctx
.
SetBlob
(
key_grad
,
p_grad
);
}
else
{
// primitives already exist
diff_src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_src_mem
));
auto
diff_dst_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_dst_mem
));
diff_src_memory
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_x_data
));
diff_dst_memory
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_y_data
));
}
platform
::
ActivationMKLDNNHandler
handler
(
dev_ctx
,
mkldnn_engine
,
key
);
auto
diff_dst_memory_p
=
handler
.
AcquireDiffDstMemory
(
diff_dst_md
,
to_void_cast
<
T
>
(
diff_y_data
));
auto
activation_backward_pd
=
handler
.
AcquireActivationBackwardPrimitiveDescriptor
(
algorithm
,
diff_dst_md
,
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
auto
diff_src_memory_p
=
handler
.
AcquireDiffSrcMemoryFromPrimitive
(
diff_x_data
);
auto
activation_backward_p
=
handler
.
AcquireActivationBackward
(
diff_src_memory_p
,
diff_dst_memory_p
,
src_memory
);
// push primitive to stream and wait until it's executed
std
::
vector
<
primitive
>
pipeline
;
pipeline
.
push_back
(
*
p_grad
);
pipeline
.
push_back
(
*
activation_backward_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory
));
diff_x
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory
_p
));
}
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
8869d7f7
...
...
@@ -309,6 +309,121 @@ class SumMKLDNNHandler : public MKLDNNHandler {
std
::
shared_ptr
<
mkldnn
::
sum
::
primitive_desc
>
sum_pd_
;
};
class
ActivationMKLDNNHandler
:
public
MKLDNNHandler
{
public:
ActivationMKLDNNHandler
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
)
{}
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
AcquireActivationPrimitiveDescriptor
(
mkldnn
::
prop_kind
prop_kind
,
mkldnn
::
algorithm
algorithm
,
const
mkldnn
::
memory
::
desc
&
md
,
float
alpha
,
float
beta
)
{
// Activation PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const
std
::
string
key_activation_pd
=
key_common_
+
"@activation_pd"
;
activation_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
if
(
activation_pd_
==
nullptr
)
{
static
std
::
mutex
acquire_barrier
;
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
acquire_barrier
);
activation_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
if
(
activation_pd_
==
nullptr
)
{
auto
activation_desc
=
mkldnn
::
eltwise_forward
::
desc
(
prop_kind
,
algorithm
,
md
,
alpha
,
beta
);
activation_pd_
.
reset
(
new
mkldnn
::
eltwise_forward
::
primitive_desc
(
activation_desc
,
engine_
));
dev_ctx_
.
SetBlob
(
key_activation_pd
,
activation_pd_
);
}
}
return
activation_pd_
;
}
std
::
shared_ptr
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
AcquireActivationBackwardPrimitiveDescriptor
(
mkldnn
::
algorithm
algorithm
,
const
mkldnn
::
memory
::
desc
&
diff_dst_md
,
const
mkldnn
::
memory
::
desc
&
src_md
,
float
alpha
,
float
beta
)
{
const
std
::
string
key_activation_pd
=
key_common_
+
"@activation_pd"
;
const
std
::
string
key_activation_bwd_pd
=
key_
+
"@activation_bwd_pd"
;
activation_bwd_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_bwd_pd
));
if
(
activation_bwd_pd_
==
nullptr
)
{
activation_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
// PD from FWD op has to exist.
PADDLE_ENFORCE
(
activation_pd_
!=
nullptr
,
"Eltwise MKL-DNN not found in cache!"
);
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
diff_dst_md
,
src_md
,
alpha
,
beta
);
activation_bwd_pd_
.
reset
(
new
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
engine_
,
*
activation_pd_
));
dev_ctx_
.
SetBlob
(
key_activation_bwd_pd
,
activation_bwd_pd_
);
}
return
activation_bwd_pd_
;
}
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
>
AcquireActivation
(
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
auto
prim_key
=
key_
+
"@eltwise_p"
;
auto
eltwise_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
eltwise_p
==
nullptr
)
{
eltwise_p
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
activation_pd_
,
*
(
src_memory_p
),
*
(
dst_memory_p
));
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_p
);
}
return
eltwise_p
;
}
// TODO(jczaja): Merge all AcquireDstMemoryFromPrimitive into one
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
activation_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
activation_bwd_pd_
->
diff_src_primitive_desc
(),
ptr
,
"@diff_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
eltwise_backward
>
AcquireActivationBackward
(
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_src_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
auto
prim_key
=
key_
+
"@eltwise_bwd_p"
;
auto
eltwise_bwd_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
eltwise_bwd_p
==
nullptr
)
{
eltwise_bwd_p
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
*
activation_bwd_pd_
,
*
(
src_memory_p
),
*
(
diff_dst_memory_p
),
*
(
diff_src_memory_p
));
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_bwd_p
);
}
return
eltwise_bwd_p
;
}
private:
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
activation_pd_
;
std
::
shared_ptr
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
activation_bwd_pd_
;
};
class
TransposeMKLDNNHandler
:
public
MKLDNNHandler
{
public:
TransposeMKLDNNHandler
(
std
::
vector
<
int
>&
dims
,
// NOLINT
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录