Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8869d7f7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8869d7f7
编写于
7月 10, 2019
作者:
J
Jacek Czaja
提交者:
Tao Luo
7月 10, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Activations MKLDNN ops refactoring (#18191)
上级
b6d5c74f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
178 addition
and
132 deletion
+178
-132
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+63
-132
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+115
-0
未找到文件。
paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
浏览文件 @
8869d7f7
...
...
@@ -13,7 +13,7 @@
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_
helper
.h"
#include "paddle/fluid/platform/mkldnn_
reuse
.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -99,20 +99,21 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
auto
src_format
=
src_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
x
->
format
();
const
std
::
string
key
=
gethash
(
src_tz
,
algorithm
);
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
ctx
.
op
().
Output
(
"Out"
)
+
"@eltwise_fwd_src_layout"
;
const
std
::
string
key_with_layout
=
key
+
std
::
to_string
(
src_format
);
const
std
::
string
key_src_mem
=
key_with_layout
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_dst_mem
=
key_with_layout
+
"@eltwise_fwd_dst_mem"
;
const
std
::
string
key_fwd
=
key_with_layout
+
"@eltwise_fwd"
;
const
std
::
string
key_fwd_pd
=
key_with_layout
+
"@eltwise_fwd_pd"
;
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
// TODO(jczaja): When adding leaky-relu , swish , elu make sure to extend key
// with alpha, beta
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
src_tz
,
std
::
to_string
(
algorithm
)
+
ctx
.
op
().
Output
(
"Out"
));
// TODO(jczaja): Make it Thread safe
// save input data and layout to be referred in backward path
const
std
::
string
key_src_data
=
key
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
"@eltwise_fwd_src_layout"
;
// Just in case some int8 models are run interchangebly
// with float models then format maybe diffrent
key
+=
std
::
to_string
(
src_format
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
auto
p_src_data
=
std
::
make_shared
<
const
T
*>
(
x_data
);
auto
p_src_layout
=
std
::
make_shared
<
memory
::
format
>
(
src_format
);
if
(
!
is_test
)
{
...
...
@@ -120,65 +121,34 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
dev_ctx
.
SetBlob
(
key_src_layout
,
p_src_layout
);
}
auto
p_fwd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
dev_ctx
.
GetBlob
(
key_fwd
));
std
::
shared_ptr
<
memory
>
dst_memory
;
if
(
p_fwd
==
nullptr
)
{
// create mkldnn memory for input X
auto
src_md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
({
src_md
,
mkldnn_engine
},
to_void_cast
(
x_data
)));
// save src_memory to be referred in backward path
dev_ctx
.
SetBlob
(
key_src_mem
,
src_memory
);
// create primitive descriptor for activation forward and save it
auto
mkldnn_forward_prop_kind
=
is_test
?
mkldnn
::
prop_kind
::
forward_inference
:
mkldnn
::
prop_kind
::
forward_training
;
auto
forward_desc
=
mkldnn
::
eltwise_forward
::
desc
(
mkldnn_forward_prop_kind
,
algorithm
,
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
auto
forward_pd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
forward_desc
,
mkldnn_engine
);
// save prim desc into global device context to be referred in backward path
if
(
!
is_test
)
dev_ctx
.
SetBlob
(
key_fwd_pd
,
forward_pd
);
// create mkldnn memory for output y
dst_memory
=
std
::
make_shared
<
memory
>
(
forward_pd
->
dst_primitive_desc
(),
y_data
);
dev_ctx
.
SetBlob
(
key_dst_mem
,
dst_memory
);
// create activation primitive
p_fwd
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
forward_pd
,
*
src_memory
,
*
dst_memory
);
dev_ctx
.
SetBlob
(
key_fwd
,
p_fwd
);
}
else
{
// primitives already exist
auto
src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
"Fail to find eltwise src_memory in device context."
);
dst_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_dst_mem
));
PADDLE_ENFORCE
(
dst_memory
!=
nullptr
,
"Fail to find eltwise dst_memory in device context."
);
src_memory
->
set_data_handle
(
platform
::
to_void_cast
(
x_data
));
dst_memory
->
set_data_handle
(
y_data
);
platform
::
ActivationMKLDNNHandler
handler
(
dev_ctx
,
mkldnn_engine
,
key
);
auto
md
=
platform
::
MKLDNNMemDesc
(
src_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_format
);
auto
activation_pd
=
handler
.
AcquireActivationPrimitiveDescriptor
(
is_test
?
mkldnn
::
prop_kind
::
forward_inference
:
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
md
,
alpha
,
beta
);
auto
src_memory_p
=
handler
.
AcquireSrcMemory
(
md
,
to_void_cast
<
T
>
(
x_data
));
// jczaja: Workaround, src_memory_p is needed in BWD so it has
// to be accessible under key not dependant on TID
if
(
!
is_test
)
{
dev_ctx
.
SetBlob
(
key_src_mem
,
src_memory_p
);
}
auto
dst_memory_p
=
handler
.
AcquireDstMemoryFromPrimitive
(
to_void_cast
<
T
>
(
y_data
));
auto
activation_p
=
handler
.
AcquireActivation
(
dst_memory_p
,
src_memory_p
);
// push primitive to stream and wait until it's executed
std
::
vector
<
primitive
>
pipeline
;
pipeline
.
push_back
(
*
p_fwd
);
pipeline
.
push_back
(
*
activation_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
));
y
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
_p
));
}
template
<
typename
T
>
...
...
@@ -199,90 +169,51 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
auto
diff_y_format
=
diff_dst_tz
.
size
()
==
2
?
mkldnn
::
memory
::
format
::
nc
:
diff_y
->
format
();
const
std
::
string
key
=
gethash
(
diff_dst_tz
,
algorithm
);
const
std
::
string
key_src_data
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
ctx
.
op
().
Input
(
"Out"
)
+
"@eltwise_fwd_src_layout"
;
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_y_format
);
std
::
string
key
=
platform
::
MKLDNNHandler
::
GetHash
(
diff_dst_tz
,
std
::
to_string
(
algorithm
)
+
ctx
.
op
().
Input
(
"Out"
));
const
std
::
string
key_src_data
=
key
+
"@eltwise_fwd_src_data"
;
const
std
::
string
key_src_layout
=
key
+
"@eltwise_fwd_src_layout"
;
// Get Data from FWD op
const
auto
p_src_layout
=
std
::
static_pointer_cast
<
memory
::
format
>
(
dev_ctx
.
GetBlob
(
key_src_layout
));
const
std
::
string
key_src_mem
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_src_mem"
;
const
std
::
string
key_fwd_pd
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"@eltwise_fwd_pd"
;
const
std
::
string
key_with_layouts
=
key
+
std
::
to_string
(
*
p_src_layout
)
+
"-"
+
std
::
to_string
(
diff_y_format
);
const
std
::
string
key_diff_src_mem
=
key_with_layouts
+
"@eltwise_diff_src_mem"
;
const
std
::
string
key_diff_dst_mem
=
key_with_layouts
+
"@eltwise_diff_dst_mem"
;
const
std
::
string
key_grad
=
key_with_layouts
+
"@eltwise_grad"
;
const
auto
p_src_data
=
std
::
static_pointer_cast
<
T
*>
(
dev_ctx
.
GetBlob
(
key_src_data
));
key
+=
std
::
to_string
(
*
p_src_layout
);
const
std
::
string
key_src_mem
=
key
+
"@eltwise_fwd_src_mem"
;
auto
src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_src_mem
));
PADDLE_ENFORCE
(
src_memory
!=
nullptr
,
"Fail to find src_memory in device context"
);
src_memory
->
set_data_handle
(
*
p_src_data
);
std
::
shared_ptr
<
memory
>
diff_src_memory
;
auto
p_grad
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
>
(
dev_ctx
.
GetBlob
(
key_grad
));
if
(
p_grad
==
nullptr
)
{
// create mkldnn memory for input diff_y
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
diff_dst_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_y_format
);
auto
diff_dst_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
({
diff_dst_md
,
mkldnn_engine
},
to_void_cast
(
diff_y_data
)));
dev_ctx
.
SetBlob
(
key_diff_dst_mem
,
diff_dst_memory
);
// retrieve eltwise primitive desc from device context
auto
forward_pd
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_fwd_pd
));
PADDLE_ENFORCE
(
forward_pd
!=
nullptr
,
"Fail to find eltwise_fwd_pd in device context"
);
// ceate primitive descriptor for activation backward
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
diff_dst_memory
->
get_primitive_desc
().
desc
(),
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
auto
backward_pd
=
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
mkldnn_engine
,
*
forward_pd
);
// create mkldnn memory for output diff_src
diff_src_memory
=
std
::
make_shared
<
memory
>
(
backward_pd
.
diff_src_primitive_desc
(),
diff_x_data
);
dev_ctx
.
SetBlob
(
key_diff_src_mem
,
diff_src_memory
);
// create activation backward primitive
p_grad
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
backward_pd
,
*
src_memory
,
*
diff_dst_memory
,
*
diff_src_memory
);
dev_ctx
.
SetBlob
(
key_grad
,
p_grad
);
}
else
{
// primitives already exist
diff_src_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_src_mem
));
auto
diff_dst_memory
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx
.
GetBlob
(
key_diff_dst_mem
));
diff_src_memory
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_x_data
));
diff_dst_memory
->
set_data_handle
(
platform
::
to_void_reinterpret_cast
(
diff_y_data
));
}
platform
::
ActivationMKLDNNHandler
handler
(
dev_ctx
,
mkldnn_engine
,
key
);
auto
diff_dst_memory_p
=
handler
.
AcquireDiffDstMemory
(
diff_dst_md
,
to_void_cast
<
T
>
(
diff_y_data
));
auto
activation_backward_pd
=
handler
.
AcquireActivationBackwardPrimitiveDescriptor
(
algorithm
,
diff_dst_md
,
src_memory
->
get_primitive_desc
().
desc
(),
alpha
,
beta
);
auto
diff_src_memory_p
=
handler
.
AcquireDiffSrcMemoryFromPrimitive
(
diff_x_data
);
auto
activation_backward_p
=
handler
.
AcquireActivationBackward
(
diff_src_memory_p
,
diff_dst_memory_p
,
src_memory
);
// push primitive to stream and wait until it's executed
std
::
vector
<
primitive
>
pipeline
;
pipeline
.
push_back
(
*
p_grad
);
pipeline
.
push_back
(
*
activation_backward_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory
));
diff_x
->
set_format
(
GetMKLDNNFormat
(
*
diff_src_memory
_p
));
}
template
<
typename
T
,
mkldnn
::
algorithm
algorithm
>
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
8869d7f7
...
...
@@ -309,6 +309,121 @@ class SumMKLDNNHandler : public MKLDNNHandler {
std
::
shared_ptr
<
mkldnn
::
sum
::
primitive_desc
>
sum_pd_
;
};
class
ActivationMKLDNNHandler
:
public
MKLDNNHandler
{
public:
ActivationMKLDNNHandler
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
)
{}
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
AcquireActivationPrimitiveDescriptor
(
mkldnn
::
prop_kind
prop_kind
,
mkldnn
::
algorithm
algorithm
,
const
mkldnn
::
memory
::
desc
&
md
,
float
alpha
,
float
beta
)
{
// Activation PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const
std
::
string
key_activation_pd
=
key_common_
+
"@activation_pd"
;
activation_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
if
(
activation_pd_
==
nullptr
)
{
static
std
::
mutex
acquire_barrier
;
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
acquire_barrier
);
activation_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
if
(
activation_pd_
==
nullptr
)
{
auto
activation_desc
=
mkldnn
::
eltwise_forward
::
desc
(
prop_kind
,
algorithm
,
md
,
alpha
,
beta
);
activation_pd_
.
reset
(
new
mkldnn
::
eltwise_forward
::
primitive_desc
(
activation_desc
,
engine_
));
dev_ctx_
.
SetBlob
(
key_activation_pd
,
activation_pd_
);
}
}
return
activation_pd_
;
}
std
::
shared_ptr
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
AcquireActivationBackwardPrimitiveDescriptor
(
mkldnn
::
algorithm
algorithm
,
const
mkldnn
::
memory
::
desc
&
diff_dst_md
,
const
mkldnn
::
memory
::
desc
&
src_md
,
float
alpha
,
float
beta
)
{
const
std
::
string
key_activation_pd
=
key_common_
+
"@activation_pd"
;
const
std
::
string
key_activation_bwd_pd
=
key_
+
"@activation_bwd_pd"
;
activation_bwd_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_bwd_pd
));
if
(
activation_bwd_pd_
==
nullptr
)
{
activation_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
// PD from FWD op has to exist.
PADDLE_ENFORCE
(
activation_pd_
!=
nullptr
,
"Eltwise MKL-DNN not found in cache!"
);
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
diff_dst_md
,
src_md
,
alpha
,
beta
);
activation_bwd_pd_
.
reset
(
new
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
engine_
,
*
activation_pd_
));
dev_ctx_
.
SetBlob
(
key_activation_bwd_pd
,
activation_bwd_pd_
);
}
return
activation_bwd_pd_
;
}
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
>
AcquireActivation
(
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
auto
prim_key
=
key_
+
"@eltwise_p"
;
auto
eltwise_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
eltwise_p
==
nullptr
)
{
eltwise_p
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
activation_pd_
,
*
(
src_memory_p
),
*
(
dst_memory_p
));
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_p
);
}
return
eltwise_p
;
}
// TODO(jczaja): Merge all AcquireDstMemoryFromPrimitive into one
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
activation_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
activation_bwd_pd_
->
diff_src_primitive_desc
(),
ptr
,
"@diff_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
eltwise_backward
>
AcquireActivationBackward
(
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_src_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
auto
prim_key
=
key_
+
"@eltwise_bwd_p"
;
auto
eltwise_bwd_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
eltwise_bwd_p
==
nullptr
)
{
eltwise_bwd_p
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
*
activation_bwd_pd_
,
*
(
src_memory_p
),
*
(
diff_dst_memory_p
),
*
(
diff_src_memory_p
));
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_bwd_p
);
}
return
eltwise_bwd_p
;
}
private:
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
activation_pd_
;
std
::
shared_ptr
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
activation_bwd_pd_
;
};
class
TransposeMKLDNNHandler
:
public
MKLDNNHandler
{
public:
TransposeMKLDNNHandler
(
std
::
vector
<
int
>&
dims
,
// NOLINT
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录