Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dfdd73cb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dfdd73cb
编写于
9月 17, 2019
作者:
A
Adam
提交者:
Tao Luo
9月 17, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add MKLDNNhandlerT templatized class (#19801)
test=develop
上级
cabb9501
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
169 addition
and
168 deletion
+169
-168
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+38
-81
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+2
-1
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+129
-86
未找到文件。
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
dfdd73cb
...
@@ -33,17 +33,18 @@ using mkldnn::stream;
...
@@ -33,17 +33,18 @@ using mkldnn::stream;
using
platform
::
to_void_cast
;
using
platform
::
to_void_cast
;
template
<
typename
T
>
template
<
typename
T
>
class
SoftmaxMKLDNNHandler
:
public
platform
::
MKLDNNHandler
{
class
SoftmaxMKLDNNHandler
:
public
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
{
public:
public:
SoftmaxMKLDNNHandler
(
const
std
::
vector
<
int
>&
dims
,
SoftmaxMKLDNNHandler
(
const
std
::
vector
<
int
>&
dims
,
const
MKLDNNMemoryFormat
fmt
,
const
MKLDNNMemoryFormat
fmt
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
platform
::
CreateKey
(
dims
,
uniq_name
)),
mkldnn
::
softmax_backward
>
(
place_
(
cpu_place
),
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
fwd_pd_
(
nullptr
),
platform
::
CreateKey
(
dims
,
uniq_name
))
{
bwd_pd_
(
nullptr
)
{
this
->
AcquireSoftmaxPrimitiveDescriptor
(
dims
,
fmt
);
this
->
AcquireSoftmaxPrimitiveDescriptor
(
dims
,
fmt
);
}
}
...
@@ -52,11 +53,10 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
...
@@ -52,11 +53,10 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
const
MKLDNNMemoryFormat
diff_fmt
,
const
MKLDNNMemoryFormat
diff_fmt
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
platform
::
CreateKey
(
dims
,
uniq_name
)),
mkldnn
::
softmax_backward
>
(
place_
(
cpu_place
),
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
fwd_pd_
(
nullptr
),
platform
::
CreateKey
(
dims
,
uniq_name
))
{
bwd_pd_
(
nullptr
)
{
// If we are in Grad operator then update a key with BWD suffix to
// If we are in Grad operator then update a key with BWD suffix to
// distinguish from FWD memory primitives
// distinguish from FWD memory primitives
// Key_common will allow to access FWD_PD from cache
// Key_common will allow to access FWD_PD from cache
...
@@ -64,58 +64,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
...
@@ -64,58 +64,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
this
->
AcquireSoftmaxBackwardPrimitiveDescriptor
(
dims
,
fmt
,
diff_fmt
);
this
->
AcquireSoftmaxBackwardPrimitiveDescriptor
(
dims
,
fmt
,
diff_fmt
);
}
}
// TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this function
// should be moved as well eg. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
const
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
src_primitive_desc
(),
to_void_cast
<
T
>
(
input_data
),
"@src_mem_p"
);
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
framework
::
Tensor
*
output
)
{
T
*
ptr
=
output
->
mutable_data
<
T
>
(
place_
,
fwd_pd_
->
dst_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
const
Tensor
*
output
)
{
const
T
*
output_data
=
output
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
dst_primitive_desc
(),
to_void_cast
<
T
>
(
output_data
),
"@bwd-dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffDstMemory
(
const
Tensor
*
diffdst
)
{
const
T
*
ptr
=
diffdst
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_dst_primitive_desc
(),
to_void_cast
<
T
>
(
ptr
),
"@diff_dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemory
(
framework
::
Tensor
*
diffsrc
)
{
T
*
ptr
=
diffsrc
->
mutable_data
<
T
>
(
place_
,
bwd_pd_
->
diff_src_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_src_primitive_desc
(),
ptr
,
"@diff_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
softmax_forward
>
AcquireSoftmax
(
std
::
shared_ptr
<
mkldnn
::
softmax_forward
>
AcquireSoftmax
(
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
/*Generate key*/
auto
prim_key
=
key_
+
"@softmax_p"
;
auto
prim_key
=
this
->
key_
+
"@softmax_p"
;
auto
softmax_p
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_forward
>
(
auto
softmax_p
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_forward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
this
->
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
softmax_p
==
nullptr
)
{
if
(
softmax_p
==
nullptr
)
{
softmax_p
=
std
::
make_shared
<
mkldnn
::
softmax_forward
>
(
softmax_p
=
std
::
make_shared
<
mkldnn
::
softmax_forward
>
(
*
fwd_pd_
,
*
(
static_cast
<
mkldnn
::
memory
*>
(
src_memory_p
.
get
())),
*
this
->
fwd_pd_
,
*
(
static_cast
<
mkldnn
::
memory
*>
(
src_memory_p
.
get
())),
*
(
static_cast
<
mkldnn
::
memory
*>
(
dst_memory_p
.
get
())));
*
(
static_cast
<
mkldnn
::
memory
*>
(
dst_memory_p
.
get
())));
dev_ctx_
.
SetBlob
(
prim_key
,
softmax_p
);
this
->
dev_ctx_
.
SetBlob
(
prim_key
,
softmax_p
);
}
}
return
softmax_p
;
return
softmax_p
;
...
@@ -125,13 +86,14 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
...
@@ -125,13 +86,14 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_src_memory_p
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_src_memory_p
)
{
auto
prim_key
=
key_
+
"@softmax_bwd_p"
;
auto
prim_key
=
this
->
key_
+
"@softmax_bwd_p"
;
auto
softmax_bwd_p
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_backward
>
(
auto
softmax_bwd_p
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_backward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
this
->
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
softmax_bwd_p
==
nullptr
)
{
if
(
softmax_bwd_p
==
nullptr
)
{
softmax_bwd_p
=
std
::
make_shared
<
mkldnn
::
softmax_backward
>
(
softmax_bwd_p
=
std
::
make_shared
<
mkldnn
::
softmax_backward
>
(
*
bwd_pd_
,
*
dst_memory_p
,
*
diff_dst_memory_p
,
*
diff_src_memory_p
);
*
this
->
bwd_pd_
,
*
dst_memory_p
,
*
diff_dst_memory_p
,
dev_ctx_
.
SetBlob
(
prim_key
,
softmax_bwd_p
);
*
diff_src_memory_p
);
this
->
dev_ctx_
.
SetBlob
(
prim_key
,
softmax_bwd_p
);
}
}
return
softmax_bwd_p
;
return
softmax_bwd_p
;
...
@@ -143,17 +105,17 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
...
@@ -143,17 +105,17 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
// Softmax PD has to be passed to Grad op that
// Softmax PD has to be passed to Grad op that
// may be executed by different thread, hence
// may be executed by different thread, hence
// for that one we use key that does not contain TID
// for that one we use key that does not contain TID
const
std
::
string
key_softmax_pd
=
key_common_
+
"@softmax_pd"
;
const
std
::
string
key_softmax_pd
=
this
->
key_common_
+
"@softmax_pd"
;
fwd_pd_
=
std
::
static_pointer_cast
<
softmax_forward
::
primitive_desc
>
(
this
->
fwd_pd_
=
std
::
static_pointer_cast
<
softmax_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_softmax_pd
));
this
->
dev_ctx_
.
GetBlob
(
key_softmax_pd
));
if
(
fwd_pd_
==
nullptr
)
{
if
(
this
->
fwd_pd_
==
nullptr
)
{
static
std
::
mutex
acquire_barrier
;
static
std
::
mutex
acquire_barrier
;
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
acquire_barrier
);
acquire_barrier
);
fwd_pd_
=
std
::
static_pointer_cast
<
softmax_forward
::
primitive_desc
>
(
this
->
fwd_pd_
=
std
::
static_pointer_cast
<
softmax_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_softmax_pd
));
this
->
dev_ctx_
.
GetBlob
(
key_softmax_pd
));
if
(
fwd_pd_
==
nullptr
)
{
if
(
this
->
fwd_pd_
==
nullptr
)
{
// TODO(jczaja): Make it working along chosen axis and for
// TODO(jczaja): Make it working along chosen axis and for
// forward_training
// forward_training
// Normalization is made after innermost dimension eg. C out of NC
// Normalization is made after innermost dimension eg. C out of NC
...
@@ -161,9 +123,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
...
@@ -161,9 +123,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
softmax_desc
=
auto
softmax_desc
=
softmax_forward
::
desc
(
prop_kind
::
forward_scoring
,
md
,
1
/*dim: C*/
);
softmax_forward
::
desc
(
prop_kind
::
forward_scoring
,
md
,
1
/*dim: C*/
);
fwd_pd_
.
reset
(
this
->
fwd_pd_
.
reset
(
new
softmax_forward
::
primitive_desc
(
softmax_desc
,
engine_
));
new
softmax_forward
::
primitive_desc
(
softmax_desc
,
this
->
engine_
));
dev_ctx_
.
SetBlob
(
key_softmax_pd
,
fwd_pd_
);
this
->
dev_ctx_
.
SetBlob
(
key_softmax_pd
,
this
->
fwd_pd_
);
}
}
}
}
}
}
...
@@ -172,12 +134,12 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
...
@@ -172,12 +134,12 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
const
std
::
vector
<
int
>&
dims
,
const
mkldnn
::
memory
::
format
fmt
,
const
std
::
vector
<
int
>&
dims
,
const
mkldnn
::
memory
::
format
fmt
,
const
mkldnn
::
memory
::
format
diff_fmt
)
{
const
mkldnn
::
memory
::
format
diff_fmt
)
{
// Fwd_PD_ has to exist in order to create BWD_PD_
// Fwd_PD_ has to exist in order to create BWD_PD_
PADDLE_ENFORCE_NOT_NULL
(
fwd_pd_
);
PADDLE_ENFORCE_NOT_NULL
(
this
->
fwd_pd_
);
const
std
::
string
key_bwd_pd
=
key_
+
"@softmax_bwd_pd"
;
const
std
::
string
key_bwd_pd
=
this
->
key_
+
"@softmax_bwd_pd"
;
bwd_pd_
=
this
->
bwd_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
softmax_backward
::
primitive_desc
>
(
std
::
static_pointer_cast
<
mkldnn
::
softmax_backward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_bwd_pd
));
this
->
dev_ctx_
.
GetBlob
(
key_bwd_pd
));
if
(
bwd_pd_
==
nullptr
)
{
if
(
this
->
bwd_pd_
==
nullptr
)
{
auto
data_softmax_md
=
auto
data_softmax_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
diff_softmax_md
=
mkldnn
::
memory
::
desc
(
auto
diff_softmax_md
=
mkldnn
::
memory
::
desc
(
...
@@ -185,16 +147,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
...
@@ -185,16 +147,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
// TODO(jczaja): Add support for other axes
// TODO(jczaja): Add support for other axes
auto
backward_desc
=
softmax_backward
::
desc
(
auto
backward_desc
=
softmax_backward
::
desc
(
diff_softmax_md
,
data_softmax_md
,
1
/* dim: C*/
);
diff_softmax_md
,
data_softmax_md
,
1
/* dim: C*/
);
bwd_pd_
.
reset
(
new
mkldnn
::
softmax_backward
::
primitive_desc
(
this
->
bwd_pd_
.
reset
(
new
mkldnn
::
softmax_backward
::
primitive_desc
(
backward_desc
,
engine_
,
*
fwd_pd_
));
backward_desc
,
this
->
engine_
,
*
this
->
fwd_pd_
));
dev_ctx_
.
SetBlob
(
key_bwd_pd
,
bwd_pd_
);
this
->
dev_ctx_
.
SetBlob
(
key_bwd_pd
,
this
->
bwd_pd_
);
}
}
}
}
private:
platform
::
Place
place_
;
std
::
shared_ptr
<
mkldnn
::
softmax_forward
::
primitive_desc
>
fwd_pd_
;
std
::
shared_ptr
<
mkldnn
::
softmax_backward
::
primitive_desc
>
bwd_pd_
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
dfdd73cb
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
...
@@ -206,7 +207,7 @@ inline std::string CreateKey(ArgTypes&&... args) {
...
@@ -206,7 +207,7 @@ inline std::string CreateKey(ArgTypes&&... args) {
std
::
string
key
;
std
::
string
key
;
key
.
reserve
(
256
);
key
.
reserve
(
256
);
using
expand_type
=
int
[];
using
expand_type
=
int
[];
expand_type
{
0
,
(
AppendKey
(
&
key
,
args
),
0
)...};
expand_type
{
0
,
(
AppendKey
(
&
key
,
std
::
forward
<
ArgTypes
>
(
args
)
),
0
)...};
return
key
;
return
key
;
}
}
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
dfdd73cb
...
@@ -29,6 +29,90 @@ namespace platform {
...
@@ -29,6 +29,90 @@ namespace platform {
using
user_function
=
std
::
function
<
std
::
shared_ptr
<
float
>
(
const
float
*
)
>
;
using
user_function
=
std
::
function
<
std
::
shared_ptr
<
float
>
(
const
float
*
)
>
;
using
memory
=
mkldnn
::
memory
;
using
memory
=
mkldnn
::
memory
;
template
<
typename
T
,
typename
TForward
,
typename
TBackward
>
class
MKLDNNHandlerT
{
public:
MKLDNNHandlerT
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
platform
::
Place
cpu_place
,
const
std
::
string
&
base_key
)
:
dev_ctx_
(
dev_ctx
),
engine_
(
engine
),
place_
(
cpu_place
),
key_common_
(
base_key
),
fwd_pd_
(
nullptr
),
bwd_pd_
(
nullptr
)
{
if
(
platform
::
get_cur_mkldnn_session_id
()
!=
platform
::
kMKLDNNSessionID_Default
)
{
key_
=
key_common_
;
}
else
{
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
}
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
src_primitive_desc
(),
to_void_cast
<
T
>
(
input_data
),
"@src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
framework
::
Tensor
*
output
)
{
T
*
ptr
=
output
->
mutable_data
<
T
>
(
place_
,
fwd_pd_
->
dst_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
const
framework
::
Tensor
*
output
)
{
const
T
*
output_data
=
output
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
dst_primitive_desc
(),
to_void_cast
<
T
>
(
output_data
),
"@bwd-dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffDstMemory
(
const
framework
::
Tensor
*
diffdst
)
{
const
T
*
ptr
=
diffdst
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_dst_primitive_desc
(),
to_void_cast
<
T
>
(
ptr
),
"@diff_dst_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemory
(
framework
::
Tensor
*
diffsrc
)
{
T
*
ptr
=
diffsrc
->
mutable_data
<
T
>
(
place_
,
bwd_pd_
->
diff_src_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_src_primitive_desc
(),
ptr
,
"@diff_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMemoryFromPrimitive
(
mkldnn
::
memory
::
primitive_desc
mdp
,
void
*
ptr
,
const
std
::
string
&
suffix
)
{
auto
local_key
=
key_
+
suffix
;
auto
mem_p
=
std
::
static_pointer_cast
<
mkldnn
::
memory
>
(
dev_ctx_
.
GetBlob
(
local_key
));
if
(
mem_p
==
nullptr
)
{
mem_p
=
std
::
make_shared
<
mkldnn
::
memory
>
(
mdp
,
ptr
);
dev_ctx_
.
SetBlob
(
local_key
,
mem_p
);
}
else
{
mem_p
->
set_data_handle
(
ptr
);
}
return
mem_p
;
}
protected:
const
MKLDNNDeviceContext
&
dev_ctx_
;
mkldnn
::
engine
engine_
;
platform
::
Place
place_
;
std
::
string
key_
;
std
::
string
key_common_
;
std
::
shared_ptr
<
typename
TForward
::
primitive_desc
>
fwd_pd_
;
std
::
shared_ptr
<
typename
TBackward
::
primitive_desc
>
bwd_pd_
;
};
// TODO(grygielski) this class will be deleted later.
class
MKLDNNHandler
{
class
MKLDNNHandler
{
public:
public:
MKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
MKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
...
@@ -255,7 +339,9 @@ class SumMKLDNNHandler : public MKLDNNHandler {
...
@@ -255,7 +339,9 @@ class SumMKLDNNHandler : public MKLDNNHandler {
};
};
template
<
typename
T
>
template
<
typename
T
>
class
ActivationMKLDNNHandler
:
public
MKLDNNHandler
{
class
ActivationMKLDNNHandler
:
public
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
{
public:
public:
ActivationMKLDNNHandler
(
const
std
::
vector
<
int
>&
dims
,
ActivationMKLDNNHandler
(
const
std
::
vector
<
int
>&
dims
,
mkldnn
::
algorithm
algorithm
,
float
alpha
,
float
beta
,
mkldnn
::
algorithm
algorithm
,
float
alpha
,
float
beta
,
...
@@ -264,12 +350,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
...
@@ -264,12 +350,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
platform
::
Place
cpu_place
,
platform
::
Place
cpu_place
,
const
std
::
string
&
unique_name
)
const
std
::
string
&
unique_name
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
platform
::
CreateKey
(
dims
,
algorithm
,
fmt
,
alpha
,
mkldnn
::
eltwise_backward
>
(
beta
,
unique_name
)),
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
place_
(
cpu_place
),
platform
::
CreateKey
(
dims
,
algorithm
,
fmt
,
alpha
,
beta
,
fwd_pd_
(
nullptr
),
unique_name
))
{
bwd_pd_
(
nullptr
)
{
AcquireActivationPrimitiveDescriptor
(
AcquireActivationPrimitiveDescriptor
(
is_test
?
mkldnn
::
prop_kind
::
forward_inference
is_test
?
mkldnn
::
prop_kind
::
forward_inference
:
mkldnn
::
prop_kind
::
forward_training
,
:
mkldnn
::
prop_kind
::
forward_training
,
...
@@ -284,76 +369,37 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
...
@@ -284,76 +369,37 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
platform
::
Place
cpu_place
,
platform
::
Place
cpu_place
,
const
std
::
string
&
unique_name
)
const
std
::
string
&
unique_name
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
platform
::
CreateKey
(
dims
,
algorithm
,
fmt
,
alpha
,
mkldnn
::
eltwise_backward
>
(
beta
,
unique_name
)),
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
place_
(
cpu_place
),
platform
::
CreateKey
(
dims
,
algorithm
,
fmt
,
alpha
,
beta
,
fwd_pd_
(
nullptr
),
unique_name
))
{
bwd_pd_
(
nullptr
)
{
AcquireActivationPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
AcquireActivationPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
algorithm
,
dims
,
fmt
,
alpha
,
beta
);
algorithm
,
dims
,
fmt
,
alpha
,
beta
);
AcquireActivationBackwardPrimitiveDescriptor
(
algorithm
,
dims
,
fmt
,
diff_fmt
,
AcquireActivationBackwardPrimitiveDescriptor
(
algorithm
,
dims
,
fmt
,
diff_fmt
,
alpha
,
beta
);
alpha
,
beta
);
}
}
// TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this
// function
// should be moved as well eg. ActivationMKLDNNHandler ->
// MKLDNNHandler<activation_>
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
src_primitive_desc
(),
to_void_cast
<
T
>
(
input_data
),
"@src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireBackwardSrcMemory
(
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireBackwardSrcMemory
(
const
framework
::
Tensor
*
input
)
{
const
framework
::
Tensor
*
input
)
{
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
input_data
=
input
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
src_primitive_desc
(),
return
this
->
AcquireMemoryFromPrimitive
(
this
->
bwd_pd_
->
src_primitive_desc
(),
to_void_cast
<
T
>
(
input_data
),
to_void_cast
<
T
>
(
input_data
),
"@bwd-src_mem_p"
);
"@bwd-src_mem_p"
);
}
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemory
(
framework
::
Tensor
*
output
)
{
T
*
ptr
=
output
->
mutable_data
<
T
>
(
place_
,
fwd_pd_
->
dst_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
fwd_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst_mem_p"
);
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffDstMemory
(
const
framework
::
Tensor
*
diffdst
)
{
const
T
*
ptr
=
diffdst
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_dst_primitive_desc
(),
to_void_cast
<
T
>
(
ptr
),
"@diff_dst_mem_p"
);
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffSrcMemory
(
framework
::
Tensor
*
diffsrc
)
{
T
*
ptr
=
diffsrc
->
mutable_data
<
T
>
(
place_
,
bwd_pd_
->
diff_src_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
bwd_pd_
->
diff_src_primitive_desc
(),
ptr
,
"@diff_src_mem_p"
);
}
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
>
AcquireActivation
(
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
>
AcquireActivation
(
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
/*Generate key*/
auto
prim_key
=
key_
+
"@eltwise_p"
;
auto
prim_key
=
this
->
key_
+
"@eltwise_p"
;
auto
eltwise_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
auto
eltwise_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
this
->
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
eltwise_p
==
nullptr
)
{
if
(
eltwise_p
==
nullptr
)
{
eltwise_p
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
eltwise_p
=
std
::
make_shared
<
mkldnn
::
eltwise_forward
>
(
*
fwd_pd_
,
*
(
src_memory_p
),
*
(
dst_memory_p
));
*
this
->
fwd_pd_
,
*
(
src_memory_p
),
*
(
dst_memory_p
));
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_p
);
this
->
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_p
);
}
}
return
eltwise_p
;
return
eltwise_p
;
...
@@ -364,15 +410,15 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
...
@@ -364,15 +410,15 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
diff_dst_memory_p
,
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
src_memory_p
)
{
/*Generate key*/
/*Generate key*/
auto
prim_key
=
key_
+
"@eltwise_bwd_p"
;
auto
prim_key
=
this
->
key_
+
"@eltwise_bwd_p"
;
auto
eltwise_bwd_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
>
(
auto
eltwise_bwd_p
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
this
->
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
eltwise_bwd_p
==
nullptr
)
{
if
(
eltwise_bwd_p
==
nullptr
)
{
eltwise_bwd_p
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
eltwise_bwd_p
=
std
::
make_shared
<
mkldnn
::
eltwise_backward
>
(
*
bwd_pd_
,
*
(
src_memory_p
),
*
(
diff_dst_memory_p
),
*
this
->
bwd_pd_
,
*
(
src_memory_p
),
*
(
diff_dst_memory_p
),
*
(
diff_src_memory_p
));
*
(
diff_src_memory_p
));
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_bwd_p
);
this
->
dev_ctx_
.
SetBlob
(
prim_key
,
eltwise_bwd_p
);
}
}
return
eltwise_bwd_p
;
return
eltwise_bwd_p
;
...
@@ -387,26 +433,27 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
...
@@ -387,26 +433,27 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
// Activation PD has to be passed to Grad op that
// Activation PD has to be passed to Grad op that
// may be executed by different thread, hence
// may be executed by different thread, hence
// for that one we use key that does not contain TID
// for that one we use key that does not contain TID
const
std
::
string
key_activation_pd
=
key_common_
+
"@activation_pd"
;
const
std
::
string
key_activation_pd
=
this
->
key_common_
+
"@activation_pd"
;
fwd_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
this
->
fwd_pd_
=
dev_ctx_
.
GetBlob
(
key_activation_pd
));
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
if
(
fwd_pd_
==
nullptr
)
{
this
->
dev_ctx_
.
GetBlob
(
key_activation_pd
));
if
(
this
->
fwd_pd_
==
nullptr
)
{
static
std
::
mutex
acquire_barrier
;
static
std
::
mutex
acquire_barrier
;
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
acquire_barrier
);
acquire_barrier
);
fwd_pd_
=
this
->
fwd_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
this
->
dev_ctx_
.
GetBlob
(
key_activation_pd
));
if
(
fwd_pd_
==
nullptr
)
{
if
(
this
->
fwd_pd_
==
nullptr
)
{
auto
md
=
platform
::
MKLDNNMemDesc
(
auto
md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
activation_desc
=
mkldnn
::
eltwise_forward
::
desc
(
auto
activation_desc
=
mkldnn
::
eltwise_forward
::
desc
(
prop_kind
,
algorithm
,
md
,
alpha
,
beta
);
prop_kind
,
algorithm
,
md
,
alpha
,
beta
);
fwd_pd_
.
reset
(
new
mkldnn
::
eltwise_forward
::
primitive_desc
(
this
->
fwd_pd_
.
reset
(
new
mkldnn
::
eltwise_forward
::
primitive_desc
(
activation_desc
,
engine_
));
activation_desc
,
this
->
engine_
));
dev_ctx_
.
SetBlob
(
key_activation_pd
,
fwd_pd_
);
this
->
dev_ctx_
.
SetBlob
(
key_activation_pd
,
this
->
fwd_pd_
);
}
}
}
}
}
}
...
@@ -415,17 +462,18 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
...
@@ -415,17 +462,18 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
mkldnn
::
algorithm
algorithm
,
const
std
::
vector
<
int
>&
dims
,
mkldnn
::
algorithm
algorithm
,
const
std
::
vector
<
int
>&
dims
,
const
MKLDNNMemoryFormat
fmt
,
const
MKLDNNMemoryFormat
diff_fmt
,
const
MKLDNNMemoryFormat
fmt
,
const
MKLDNNMemoryFormat
diff_fmt
,
float
alpha
,
float
beta
)
{
float
alpha
,
float
beta
)
{
const
std
::
string
key_activation_pd
=
key_common_
+
"@activation_pd"
;
const
std
::
string
key_activation_pd
=
this
->
key_common_
+
"@activation_pd"
;
const
std
::
string
key_activation_bwd_pd
=
key_
+
"@activation_bwd_pd"
;
const
std
::
string
key_activation_bwd_pd
=
this
->
key_
+
"@activation_bwd_pd"
;
bwd_pd_
=
this
->
bwd_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
(
std
::
static_pointer_cast
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_bwd_pd
));
this
->
dev_ctx_
.
GetBlob
(
key_activation_bwd_pd
));
if
(
bwd_pd_
==
nullptr
)
{
if
(
this
->
bwd_pd_
==
nullptr
)
{
fwd_pd_
=
this
->
fwd_pd_
=
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
std
::
static_pointer_cast
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_activation_pd
));
this
->
dev_ctx_
.
GetBlob
(
key_activation_pd
));
// PD from FWD op has to exist.
// PD from FWD op has to exist.
PADDLE_ENFORCE_NOT_NULL
(
fwd_pd_
,
"Eltwise MKL-DNN not found in cache!"
);
PADDLE_ENFORCE_NOT_NULL
(
this
->
fwd_pd_
,
"Eltwise MKL-DNN not found in cache!"
);
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
...
@@ -434,16 +482,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
...
@@ -434,16 +482,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
auto
backward_desc
=
mkldnn
::
eltwise_backward
::
desc
(
algorithm
,
diff_dst_md
,
src_md
,
alpha
,
beta
);
algorithm
,
diff_dst_md
,
src_md
,
alpha
,
beta
);
bwd_pd_
.
reset
(
new
mkldnn
::
eltwise_backward
::
primitive_desc
(
this
->
bwd_pd_
.
reset
(
new
mkldnn
::
eltwise_backward
::
primitive_desc
(
backward_desc
,
engine_
,
*
fwd_pd_
));
backward_desc
,
this
->
engine_
,
*
this
->
fwd_pd_
));
dev_ctx_
.
SetBlob
(
key_activation_bwd_pd
,
bwd_pd_
);
this
->
dev_ctx_
.
SetBlob
(
key_activation_bwd_pd
,
this
->
bwd_pd_
);
}
}
}
}
private:
platform
::
Place
place_
;
std
::
shared_ptr
<
mkldnn
::
eltwise_forward
::
primitive_desc
>
fwd_pd_
;
std
::
shared_ptr
<
mkldnn
::
eltwise_backward
::
primitive_desc
>
bwd_pd_
;
};
};
class
LRNMKLDNNHandler
:
public
MKLDNNHandler
{
class
LRNMKLDNNHandler
:
public
MKLDNNHandler
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录