Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f6cca625
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f6cca625
编写于
12月 14, 2020
作者:
J
Jacek Czaja
提交者:
GitHub
12月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[oneDNN] Making ThreadID info in caching key optional (#29272)
上级
08f24a31
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
113 addition
and
110 deletion
+113
-110
paddle/fluid/framework/data_layout_transform.cc
paddle/fluid/framework/data_layout_transform.cc
+2
-2
paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
+3
-8
paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc
paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc
+2
-7
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+3
-2
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
+3
-2
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+21
-22
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
+1
-2
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
+5
-2
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+8
-6
paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
+1
-1
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
+2
-3
paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
+5
-3
paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
+1
-1
paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
+5
-3
paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
+3
-3
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
+3
-3
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+5
-3
paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
+3
-2
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+5
-0
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+17
-8
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+15
-27
未找到文件。
paddle/fluid/framework/data_layout_transform.cc
浏览文件 @
f6cca625
...
@@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
...
@@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
if
(
in_format
!=
out_format
)
{
if
(
in_format
!=
out_format
)
{
void
*
in_data
=
GetDataFromTensor
(
in
,
in_type
);
void
*
in_data
=
GetDataFromTensor
(
in
,
in_type
);
const
std
::
string
key
=
std
::
string
key
=
platform
::
CreateKey
(
in_tz
,
in_format
,
out_format
,
in_type
);
platform
::
CreateKey
(
*
dev_ctx
,
in_tz
,
in_format
,
out_format
,
in_type
);
platform
::
ReorderMKLDNNHandler
handler
(
in_tz
,
in
.
type
(),
in_type
,
*
dev_ctx
,
platform
::
ReorderMKLDNNHandler
handler
(
in_tz
,
in
.
type
(),
in_type
,
*
dev_ctx
,
cpu_engine
,
key
);
cpu_engine
,
key
);
...
...
paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
...
@@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
const
std
::
string
&
unique_name
)
const
std
::
string
&
unique_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
gru_forward
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
gru_forward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
CreateKey
(
unique_name
,
MKLDNNGetDataType
<
T
>
(),
Ti
)),
CreateKey
(
dev_ctx
,
unique_name
,
MKLDNNGetDataType
<
T
>
(),
Ti
)),
N
(
N
),
N
(
N
),
Ti
(
Ti
),
Ti
(
Ti
),
IC
(
IC
),
IC
(
IC
),
OC
(
OC
)
{
OC
(
OC
)
{
// Create memory key without Ti because weights, bias and h0 memories
// Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do
// do not depend on Ti size but primitive and input/output memory do
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
!=
memory_key_
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
dev_ctx
,
CreateKey
(
dev_ctx
,
unique_name
,
MKLDNNGetDataType
<
T
>
()));
memory_key_
=
CreateKey
(
unique_name
,
MKLDNNGetDataType
<
T
>
());
}
else
{
memory_key_
=
CreateKey
(
unique_name
,
MKLDNNGetDataType
<
T
>
(),
"-t:"
,
platform
::
ThreadIDasStr
());
}
// Is it int8 kernel
// Is it int8 kernel
const
bool
is_INT8
=
std
::
is_same
<
T
,
uint8_t
>::
value
;
const
bool
is_INT8
=
std
::
is_same
<
T
,
uint8_t
>::
value
;
...
...
paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -109,13 +109,8 @@ class MultiGRUHandler {
...
@@ -109,13 +109,8 @@ class MultiGRUHandler {
const
std
::
string
unique_name
=
ctx
.
OutputName
(
"Hidden"
);
const
std
::
string
unique_name
=
ctx
.
OutputName
(
"Hidden"
);
// Create memory key without Ti because weights, bias and h0 memories
// Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do
// do not depend on Ti size but primitive and input/output memory do
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
!=
memory_key_
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
dev_ctx
,
CreateKey
(
dev_ctx
,
unique_name
,
MKLDNNGetDataType
<
T
>
()));
memory_key_
=
CreateKey
(
unique_name
,
MKLDNNGetDataType
<
T
>
());
}
else
{
memory_key_
=
CreateKey
(
unique_name
,
MKLDNNGetDataType
<
T
>
(),
"-t:"
,
platform
::
ThreadIDasStr
());
}
key_
=
memory_key_
;
key_
=
memory_key_
;
key_
.
append
(
"T"
).
append
(
std
::
to_string
(
Ti_
));
key_
.
append
(
"T"
).
append
(
std
::
to_string
(
Ti_
));
...
...
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler
...
@@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
mkldnn
::
batch_normalization_backward
>
(
mkldnn
::
batch_normalization_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
framework
::
vectorize
(
x
->
dims
()),
unique_name
))
{
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
x
->
dims
()),
unique_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
())
{
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
const
bool
fuse_with_relu
=
ctx
.
Attr
<
bool
>
(
"fuse_with_relu"
);
const
bool
fuse_with_relu
=
ctx
.
Attr
<
bool
>
(
"fuse_with_relu"
);
...
@@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler
...
@@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
mkldnn
::
batch_normalization_backward
>
(
mkldnn
::
batch_normalization_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
uniq_name
))
{
platform
::
CreateKey
(
d
ev_ctx
,
d
ims
,
uniq_name
))
{
auto
diff_dst_md
=
auto
diff_dst_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
auto
src_md
=
auto
src_md
=
...
...
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -158,9 +158,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -158,9 +158,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// If one of the multiple inputs of concat has an input size of 0, the
// If one of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
// actual size of the multi_input will change
std
::
string
key
=
platform
::
CreateKey
(
std
::
string
key
=
platform
::
CreateKey
(
paddle
::
framework
::
vectorize
<
int
>
(
multi_input
[
0
]
->
dims
()),
dev_ctx
,
paddle
::
framework
::
vectorize
<
int
>
(
multi_input
[
0
]
->
dims
()),
multi_input
.
size
(),
ctx
.
OutputName
(
"Out"
),
dt
,
multi_input
.
size
(),
ctx
.
OutputName
(
"Out"
),
dt
,
platform
::
ThreadIDasStr
(),
dev_ctx
.
GetKeySuffix
());
platform
::
ThreadIDasStr
());
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
const
std
::
string
key_prim
=
key
+
"@concat_p"
;
const
std
::
string
key_prim
=
key
+
"@concat_p"
;
const
std
::
string
key_concat_pd
=
key
+
"@concat_pd"
;
const
std
::
string
key_concat_pd
=
key
+
"@concat_pd"
;
...
...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT
...
@@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT
const
std
::
string
&
unique_name
)
const
std
::
string
&
unique_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
convolution_forward
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
convolution_forward
>
(
dev_ctx
,
mkldnn_engine
,
cpu_place
,
dev_ctx
,
mkldnn_engine
,
cpu_place
,
platform
::
CreateKey
(
framework
::
vectorize
(
input
->
dims
()),
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
input
->
dims
()),
unique_name
))
{
unique_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
...
@@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn
::
memory
::
data_type
src_dt
=
mkldnn
::
memory
::
data_type
src_dt
=
paddle
::
framework
::
ToMKLDNNDataType
(
input
->
type
());
paddle
::
framework
::
ToMKLDNNDataType
(
input
->
type
());
std
::
string
key
=
platform
::
CreateKey
(
std
::
string
key
=
src_tz
,
src_dt
,
ctx
.
InputName
(
"Input"
)
+
ctx
.
InputName
(
"Filter"
));
platform
::
CreateKey
(
dev_ctx
,
src_tz
,
src_dt
,
ctx
.
InputName
(
"Input"
)
+
ctx
.
InputName
(
"Filter"
));
const
std
::
string
key_conv_pd
=
key
+
"@conv_pd"
;
const
std
::
string
key_conv_pd
=
key
+
"@conv_pd"
;
bool
need_s8_to_u8
=
false
;
bool
need_s8_to_u8
=
false
;
...
@@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// This is workaround for hacky implementation
// This is workaround for hacky implementation
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear
// are merged/unified, this will disappear
std
::
string
key_tid
=
""
;
auto
key_tid
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
==
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
auto
prim_key
=
key_tid
+
"@conv_p"
;
key_tid
=
"-t:"
+
platform
::
ThreadIDasStr
();
auto
dst_key
=
key_tid
+
"@dst_mem_p"
;
}
auto
src_key
=
key_tid
+
"@src_mem_p"
;
auto
weights_key
=
key_tid
+
"@weights_mem_p"
;
auto
prim_key
=
key
+
key_tid
+
"@conv_p"
;
auto
bias_key
=
key_tid
+
"@bias_mem_p"
;
auto
dst_key
=
key
+
key_tid
+
"@dst_mem_p"
;
auto
user_src_key
=
key_tid
+
"@user_src_mem_p"
;
auto
src_key
=
key
+
key_tid
+
"@src_mem_p"
;
auto
user_residual_key
=
key_tid
+
"@user_residual_data_mem_p"
;
auto
weights_key
=
key
+
key_tid
+
"@weights_mem_p"
;
auto
src_reorder_key
=
key_tid
+
"@src_mem_preorder_p"
;
auto
bias_key
=
key
+
key_tid
+
"@bias_mem_p"
;
auto
residual_reorder_key
=
key_tid
+
"@residual_data_mem_preorder_p"
;
auto
user_src_key
=
key
+
key_tid
+
"@user_src_mem_p"
;
auto
user_residual_key
=
key
+
key_tid
+
"@user_residual_data_mem_p"
;
auto
src_reorder_key
=
key
+
key_tid
+
"@src_mem_preorder_p"
;
auto
residual_reorder_key
=
key
+
key_tid
+
"@residual_data_mem_preorder_p"
;
conv_p
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_forward
>
(
conv_p
=
std
::
static_pointer_cast
<
mkldnn
::
convolution_forward
>
(
dev_ctx
.
GetBlob
(
prim_key
));
dev_ctx
.
GetBlob
(
prim_key
));
...
@@ -972,10 +969,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -972,10 +969,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "input" and "Filter" variable
// Get an unique name from "argument" name of "input" and "Filter" variable
// as well as attributes of primitive to be created
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
// This name will be used as key when saving info into device context
const
std
::
string
key
=
platform
::
CreateKey
(
std
::
string
key
=
platform
::
CreateKey
(
src_tz
,
ctx
.
InputName
(
"Input"
)
+
ctx
.
InputName
(
"Filter"
));
dev_ctx
,
src_tz
,
ctx
.
InputName
(
"Input"
)
+
ctx
.
InputName
(
"Filter"
));
const
std
::
string
key_conv_pd
=
key
+
"@fwd_pd"
;
const
std
::
string
key_conv_pd
=
key
+
"@fwd_pd"
;
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
std
::
vector
<
primitive
>
pipeline
;
std
::
vector
<
primitive
>
pipeline
;
// Create user memory descriptors
// Create user memory descriptors
...
@@ -1090,8 +1088,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -1090,8 +1088,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
mkldnn
::
memory
::
format_tag
out_format
=
mkldnn
::
memory
::
format_tag
out_format
=
weights_tz
.
size
()
==
6
?
mkldnn
::
memory
::
format_tag
::
goidhw
weights_tz
.
size
()
==
6
?
mkldnn
::
memory
::
format_tag
::
goidhw
:
mkldnn
::
memory
::
format_tag
::
goihw
;
:
mkldnn
::
memory
::
format_tag
::
goihw
;
const
std
::
string
key
=
std
::
string
key
=
platform
::
CreateKey
(
dev_ctx
,
weights_tz
,
filter_fmt
,
platform
::
CreateKey
(
weights_tz
,
filter_fmt
,
out_format
,
in_type
);
out_format
,
in_type
);
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
platform
::
ReorderMKLDNNHandler
handler
(
weights_tz
,
filter_grad
->
type
(),
platform
::
ReorderMKLDNNHandler
handler
(
weights_tz
,
filter_grad
->
type
(),
in_type
,
dev_ctx
,
mkldnn_engine
,
in_type
,
dev_ctx
,
mkldnn_engine
,
...
...
paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
dst_tz
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
output
->
dims
());
auto
dst_tz
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
output
->
dims
());
// Get unique name for storing MKLDNN primitives
// Get unique name for storing MKLDNN primitives
const
std
::
string
key
=
const
std
::
string
key
=
platform
::
CreateKey
(
src_tz
,
ctx
.
OutputName
(
"Output"
));
platform
::
CreateKey
(
dev_ctx
,
src_tz
,
ctx
.
OutputName
(
"Output"
));
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
...
...
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
...
@@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn
::
memory
::
data_type
src_dt
=
mkldnn
::
memory
::
data_type
src_dt
=
paddle
::
framework
::
ToMKLDNNDataType
(
input
->
type
());
paddle
::
framework
::
ToMKLDNNDataType
(
input
->
type
());
MKLDNNMemoryFormat
src_fmt
=
input
->
format
();
MKLDNNMemoryFormat
src_fmt
=
input
->
format
();
std
::
string
key
=
platform
::
CreateKey
(
platform
::
ThreadIDasStr
(),
src_dt
,
src_tz
,
ctx
.
OutputName
(
"Output"
));
std
::
string
key
=
platform
::
CreateKey
(
dev_ctx
,
src_dt
,
src_tz
,
ctx
.
OutputName
(
"Output"
));
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
const
std
::
string
key_prim
=
key
+
"@r"
;
const
std
::
string
key_prim
=
key
+
"@r"
;
const
std
::
string
key_src_mem
=
key
+
"@s"
;
const
std
::
string
key_src_mem
=
key
+
"@s"
;
const
std
::
string
key_dst_mem
=
key
+
"@d"
;
const
std
::
string
key_dst_mem
=
key
+
"@d"
;
...
...
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -370,8 +370,9 @@ class FCPrimitiveFactory {
...
@@ -370,8 +370,9 @@ class FCPrimitiveFactory {
void
CacheWeightsAndBias
(
const
MKLDNNDeviceContext
&
dev_ctx
,
void
CacheWeightsAndBias
(
const
MKLDNNDeviceContext
&
dev_ctx
,
const
ExecutionContext
&
ctx
)
{
const
ExecutionContext
&
ctx
)
{
const
std
::
string
key
=
std
::
string
key
=
platform
::
CreateKey
(
dev_ctx
);
platform
::
CreateKey
(
platform
::
ThreadIDasStr
(),
dev_ctx
.
GetKeySuffix
());
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
const
std
::
string
weights_key
=
key
+
ctx
.
InputName
(
"W"
);
const
std
::
string
weights_key
=
key
+
ctx
.
InputName
(
"W"
);
const
std
::
string
bias_key
=
key
+
ctx
.
InputName
(
"Bias"
);
const
std
::
string
bias_key
=
key
+
ctx
.
InputName
(
"Bias"
);
dev_ctx
.
SetBlob
(
weights_key
,
weights_
);
dev_ctx
.
SetBlob
(
weights_key
,
weights_
);
...
@@ -541,10 +542,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
...
@@ -541,10 +542,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
const
Tensor
*
w
,
const
Tensor
*
bias
,
LoDTensor
*
output
,
const
Tensor
*
w
,
const
Tensor
*
bias
,
LoDTensor
*
output
,
bool
fuse_relu
,
bool
force_fp32_output
)
{
bool
fuse_relu
,
bool
force_fp32_output
)
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
std
::
string
prim_key
=
platform
::
CreateKey
(
std
::
string
prim_key
=
platform
::
CreateKey
(
platform
::
ThreadIDasStr
(),
dev_ctx
.
GetKeySuffix
(),
input
->
format
(),
dev_ctx
,
input
->
format
(),
input
->
dims
()[
0
],
input
->
dims
()[
0
],
framework
::
vectorize
<
int
>
(
w
->
dims
()),
framework
::
vectorize
<
int
>
(
w
->
dims
()),
ctx
.
OutputName
(
"Out"
));
ctx
.
OutputName
(
"Out"
));
prim_key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
prim_key
);
constexpr
bool
is_int8
=
constexpr
bool
is_int8
=
std
::
is_same
<
T_in
,
int8_t
>::
value
||
std
::
is_same
<
T_in
,
uint8_t
>::
value
;
std
::
is_same
<
T_in
,
int8_t
>::
value
||
std
::
is_same
<
T_in
,
uint8_t
>::
value
;
bool
is_bfloat16
=
std
::
is_same
<
T_in
,
paddle
::
platform
::
bfloat16
>::
value
;
bool
is_bfloat16
=
std
::
is_same
<
T_in
,
paddle
::
platform
::
bfloat16
>::
value
;
...
...
paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler
...
@@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler
const
std
::
string
&
uniq_name
)
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
layer_normalization_forward
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
layer_normalization_forward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
uniq_name
))
{
platform
::
CreateKey
(
d
ev_ctx
,
d
ims
,
uniq_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
())
{
auto
md
=
dnnl
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
md
=
dnnl
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
if
(
!
is_test
)
{
if
(
!
is_test
)
{
...
...
paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -336,9 +336,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
...
@@ -336,9 +336,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const
auto
&
out_name
=
ctx
.
OutputName
(
"Out"
);
const
auto
&
out_name
=
ctx
.
OutputName
(
"Out"
);
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
const
auto
batch_size
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
()[
0
];
const
auto
batch_size
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
()[
0
];
std
::
string
key
=
platform
::
CreateKey
(
dev_ctx
,
batch_size
,
out_name
);
const
std
::
string
key
=
platform
::
CreateKey
(
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
platform
::
ThreadIDasStr
(),
dev_ctx
.
GetKeySuffix
(),
batch_size
,
out_name
);
auto
factory
=
auto
factory
=
std
::
static_pointer_cast
<
MatMulFactory
<
XT
,
YT
,
OT
>>
(
dev_ctx
.
GetBlob
(
key
));
std
::
static_pointer_cast
<
MatMulFactory
<
XT
,
YT
,
OT
>>
(
dev_ctx
.
GetBlob
(
key
));
...
...
paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -305,9 +305,11 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
...
@@ -305,9 +305,11 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const
MKLDNNDeviceContext
&
dev_ctx
,
const
ExecutionContext
&
ctx
,
const
MKLDNNDeviceContext
&
dev_ctx
,
const
ExecutionContext
&
ctx
,
const
Tensor
*
input_x
,
const
Tensor
*
input_y
,
const
Tensor
*
input_x
,
const
Tensor
*
input_y
,
const
mkldnn
::
engine
&
mkldnn_engine
)
{
const
mkldnn
::
engine
&
mkldnn_engine
)
{
const
std
::
string
key
=
platform
::
CreateKey
(
std
::
string
key
=
platform
::
CreateKey
(
input_x
->
type
(),
framework
::
vectorize
(
input_x
->
dims
()),
input_y
->
type
(),
dev_ctx
,
input_x
->
type
(),
framework
::
vectorize
(
input_x
->
dims
()),
framework
::
vectorize
(
input_y
->
dims
()),
ctx
.
OutputName
(
"Out"
));
input_y
->
type
(),
framework
::
vectorize
(
input_y
->
dims
()),
ctx
.
OutputName
(
"Out"
));
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
auto
prim_creator
=
std
::
static_pointer_cast
<
MulPrimitiveFactory
<
XT
,
YT
,
OT
>>
(
auto
prim_creator
=
std
::
static_pointer_cast
<
MulPrimitiveFactory
<
XT
,
YT
,
OT
>>
(
dev_ctx
.
GetBlob
(
key
));
dev_ctx
.
GetBlob
(
key
));
...
...
paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
// This name will be used as key when referring info from device context
const
std
::
string
key
=
platform
::
CreateKey
(
const
std
::
string
key
=
platform
::
CreateKey
(
diff_src_tz
,
pooling_type
,
ksize
,
strides
,
paddings
,
d
ev_ctx
,
d
iff_src_tz
,
pooling_type
,
ksize
,
strides
,
paddings
,
memory
::
data_type
::
f32
,
in_x
->
format
(),
ctx
.
InputName
(
"Out"
));
memory
::
data_type
::
f32
,
in_x
->
format
(),
ctx
.
InputName
(
"Out"
));
platform
::
PoolingMKLDNNHandler
<
T
>
handler
(
platform
::
PoolingMKLDNNHandler
<
T
>
handler
(
...
...
paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
...
@@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
bool
is_negative_input
=
ctx
.
Attr
<
bool
>
(
"is_negative_input"
);
bool
is_negative_input
=
ctx
.
Attr
<
bool
>
(
"is_negative_input"
);
bool
bfloat16
=
ctx
.
Attr
<
bool
>
(
"bfloat16"
);
bool
bfloat16
=
ctx
.
Attr
<
bool
>
(
"bfloat16"
);
std
::
string
key
=
platform
::
CreateKey
(
std
::
string
key
=
platform
::
ThreadIDasStr
()
,
src_tz
,
scale_data
,
scale_shift
,
platform
::
CreateKey
(
dev_ctx
,
src_tz
,
scale_data
,
scale_shift
,
is_negative_input
,
ctx
.
OutputName
(
"Output"
));
is_negative_input
,
ctx
.
OutputName
(
"Output"
));
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
const
std
::
string
key_prim
=
key
+
"@r"
;
const
std
::
string
key_prim
=
key
+
"@r"
;
const
std
::
string
key_src_mem
=
key
+
"@s"
;
const
std
::
string
key_src_mem
=
key
+
"@s"
;
const
std
::
string
key_dst_mem
=
key
+
"@d"
;
const
std
::
string
key_dst_mem
=
key
+
"@d"
;
...
...
paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
...
@@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
float
reorder_scale
=
scale_out
/
scale_in
;
float
reorder_scale
=
scale_out
/
scale_in
;
std
::
string
key
=
std
::
string
key
=
platform
::
CreateKey
(
dev_ctx
,
src_tz
,
scale_in
,
scale_out
,
platform
::
CreateKey
(
platform
::
ThreadIDasStr
(),
src_tz
,
scale_in
,
ctx
.
OutputName
(
"Output"
));
scale_out
,
ctx
.
OutputName
(
"Output"
)
);
key
=
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
key
);
const
std
::
string
key_prim
=
key
+
"@r"
;
const
std
::
string
key_prim
=
key
+
"@r"
;
const
std
::
string
key_src_mem
=
key
+
"@s"
;
const
std
::
string
key_src_mem
=
key
+
"@s"
;
const
std
::
string
key_dst_mem
=
key
+
"@d"
;
const
std
::
string
key_dst_mem
=
key
+
"@d"
;
...
...
paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler
...
@@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler
mkldnn
::
softmax_backward
>
(
mkldnn
::
softmax_backward
>
(
dev_ctx
,
mkldnn_engine
,
cpu_place
,
dev_ctx
,
mkldnn_engine
,
cpu_place
,
// Softmax may be inplace then uniq_name is no longer unique
// Softmax may be inplace then uniq_name is no longer unique
platform
::
CreateKey
(
framework
::
vectorize
(
input
->
dims
()),
axis
,
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
input
->
dims
())
,
uniq_name
))
{
axis
,
uniq_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
input
->
dims
(),
output
->
dims
(),
input
->
dims
(),
output
->
dims
(),
...
@@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler
...
@@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
softmax_forward
,
mkldnn
::
softmax_backward
>
(
mkldnn
::
softmax_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
axis
,
uniq_name
))
{
platform
::
CreateKey
(
d
ev_ctx
,
d
ims
,
axis
,
uniq_name
))
{
auto
data_softmax_md
=
auto
data_softmax_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
diff_softmax_md
=
auto
diff_softmax_md
=
...
...
paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
...
@@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
sum
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
sum
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
framework
::
vectorize
(
z
->
dims
()),
uniq_name
)),
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
z
->
dims
()),
uniq_name
)),
num_inputs_
(
0
)
{
num_inputs_
(
0
)
{
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
in_vars
.
size
();
i
++
)
{
srcs_suffix_
.
push_back
(
std
::
string
(
"-"
)
+
std
::
to_string
(
i
));
srcs_suffix_
.
push_back
(
std
::
string
(
"-"
)
+
std
::
to_string
(
i
));
...
@@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// For in-place execution which sum does not have we need to fake it
// For in-place execution which sum does not have we need to fake it
// so from oneDNN dst memory we reorder data into input
// so from oneDNN dst memory we reorder data into input
if
(
in_place
)
{
if
(
in_place
)
{
const
std
::
string
reorder_key
=
platform
::
CreateKey
(
const
std
::
string
reorder_key
=
framework
::
vectorize
(
output
->
dims
()),
ctx
.
OutputName
(
"Out"
)
+
"-I"
);
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
output
->
dims
()),
ctx
.
OutputName
(
"Out"
)
+
"-I"
);
auto
&
in_out
=
in_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
();
auto
&
in_out
=
in_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
();
auto
output_tz
=
framework
::
vectorize
<
int64_t
>
(
output
->
dims
());
auto
output_tz
=
framework
::
vectorize
<
int64_t
>
(
output
->
dims
());
...
...
paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc
浏览文件 @
f6cca625
...
@@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
nchw_tz
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
input
->
dims
());
auto
nchw_tz
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
input
->
dims
());
const
std
::
string
key
=
platform
::
CreateKey
(
nchw_tz
,
ctx
.
OutputName
(
"Out"
));
const
std
::
string
key
=
platform
::
CreateKey
(
dev_ctx
,
nchw_tz
,
ctx
.
OutputName
(
"Out"
));
platform
::
TransposeMKLDNNHandler
<
T
>
handler
(
nchw_tz
,
axis
,
dev_ctx
,
platform
::
TransposeMKLDNNHandler
<
T
>
handler
(
nchw_tz
,
axis
,
dev_ctx
,
mkldnn_engine
,
key
);
mkldnn_engine
,
key
);
...
@@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
nchw_tz
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
out_grad
->
dims
());
auto
nchw_tz
=
paddle
::
framework
::
vectorize
<
int64_t
>
(
out_grad
->
dims
());
const
std
::
string
key
=
platform
::
CreateKey
(
const
std
::
string
key
=
platform
::
CreateKey
(
nchw_tz
,
ctx
.
OutputName
(
framework
::
GradVarName
(
"X"
)));
dev_ctx
,
nchw_tz
,
ctx
.
OutputName
(
framework
::
GradVarName
(
"X"
)));
platform
::
TransposeMKLDNNHandler
<
T
>
handler
(
nchw_tz
,
reversed_axis
,
dev_ctx
,
platform
::
TransposeMKLDNNHandler
<
T
>
handler
(
nchw_tz
,
reversed_axis
,
dev_ctx
,
mkldnn_engine
,
key
);
mkldnn_engine
,
key
);
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
f6cca625
...
@@ -532,6 +532,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
...
@@ -532,6 +532,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
void
SetKeySuffix
(
const
std
::
string
&
suffix
)
{
key_suffix_
=
suffix
;
}
void
SetKeySuffix
(
const
std
::
string
&
suffix
)
{
key_suffix_
=
suffix
;
}
const
std
::
string
&
GetKeySuffix
(
void
)
const
{
return
key_suffix_
;
}
const
std
::
string
&
GetKeySuffix
(
void
)
const
{
return
key_suffix_
;
}
// Disable adding thread ID to the key
void
DisableThreadInfoInKey
(
void
)
{
key_attach_thread_id_
=
false
;
};
bool
IsThreadIdUsedInKey
(
void
)
const
{
return
key_attach_thread_id_
;
};
// Prevent next ResetBlobMap()
// Prevent next ResetBlobMap()
void
BlockNextCacheClearing
();
void
BlockNextCacheClearing
();
...
@@ -554,6 +558,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
...
@@ -554,6 +558,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std
::
shared_ptr
<
std
::
mutex
>
p_mutex_
;
std
::
shared_ptr
<
std
::
mutex
>
p_mutex_
;
bool
block_next_cache_clearing_
=
false
;
bool
block_next_cache_clearing_
=
false
;
std
::
string
key_suffix_
;
// Key identifying current Executor
std
::
string
key_suffix_
;
// Key identifying current Executor
bool
key_attach_thread_id_
=
true
;
};
};
#endif
#endif
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
f6cca625
...
@@ -431,11 +431,6 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
...
@@ -431,11 +431,6 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
}
}
}
}
inline
unsigned
int
HashPointer
(
uintptr_t
ptr
)
{
// Get four less meaningful digits in decimal numerals
return
ptr
%
1000
;
}
// If MKLDNN build and CPU place then register suffix in DeviceContext
// If MKLDNN build and CPU place then register suffix in DeviceContext
inline
void
AttachPointerHashToMKLDNNKey
(
void
*
ptr
,
inline
void
AttachPointerHashToMKLDNNKey
(
void
*
ptr
,
const
platform
::
Place
&
place
)
{
const
platform
::
Place
&
place
)
{
...
@@ -443,20 +438,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
...
@@ -443,20 +438,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
place
);
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
place
);
dev_ctx
->
SetKeySuffix
(
"E"
+
std
::
to_string
(
platform
::
HashPointer
(
dev_ctx
->
SetKeySuffix
(
"E"
+
reinterpret_cast
<
uintptr_t
>
(
ptr
))));
std
::
to_string
(
reinterpret_cast
<
uintptr_t
>
(
ptr
)));
// When NaiveExecutor/Executor is used no info on thread id is needed in a
// key
dev_ctx
->
DisableThreadInfoInKey
();
}
}
}
}
template
<
typename
...
ArgTypes
>
template
<
typename
...
ArgTypes
>
inline
std
::
string
CreateKey
(
ArgTypes
&&
...
args
)
{
inline
std
::
string
CreateKey
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
ArgTypes
&&
...
args
)
{
std
::
string
key
;
std
::
string
key
;
key
.
reserve
(
64
);
key
.
reserve
(
64
);
using
expand_type
=
int
[];
using
expand_type
=
int
[];
expand_type
{
0
,
(
AppendKey
(
&
key
,
std
::
forward
<
ArgTypes
>
(
args
)),
0
)...};
expand_type
{
0
,
(
AppendKey
(
&
key
,
std
::
forward
<
ArgTypes
>
(
args
)),
0
)...};
key
+=
dev_ctx
.
GetKeySuffix
();
return
key
;
return
key
;
}
}
inline
std
::
string
ExtendKeyWithThreadInfoIfNeeded
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
const
std
::
string
&
key
)
{
return
((
dev_ctx
.
IsThreadIdUsedInKey
()
==
true
)
&&
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
==
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
))
?
key
+
"-t:"
+
ThreadIDasStr
()
:
key
;
}
inline
std
::
vector
<
std
::
vector
<
int64_t
>>
ToMkldnnPadding
(
inline
std
::
vector
<
std
::
vector
<
int64_t
>>
ToMkldnnPadding
(
const
std
::
vector
<
int64_t
>&
paddings
)
{
const
std
::
vector
<
int64_t
>&
paddings
)
{
if
(
paddings
.
size
()
==
6
)
{
if
(
paddings
.
size
()
==
6
)
{
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
f6cca625
...
@@ -43,16 +43,9 @@ class MKLDNNHandlerT {
...
@@ -43,16 +43,9 @@ class MKLDNNHandlerT {
engine_
(
engine
),
engine_
(
engine
),
place_
(
cpu_place
),
place_
(
cpu_place
),
key_common_
(
base_key
),
key_common_
(
base_key
),
key_
(
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
base_key
)),
fwd_pd_
(
nullptr
),
fwd_pd_
(
nullptr
),
bwd_pd_
(
nullptr
)
{
bwd_pd_
(
nullptr
)
{}
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
!=
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
key_
=
key_common_
;
}
else
{
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
}
key_
+=
dev_ctx
.
GetKeySuffix
();
}
std
::
shared_ptr
<
TForward
>
AcquireForwardPrimitive
()
{
std
::
shared_ptr
<
TForward
>
AcquireForwardPrimitive
()
{
const
std
::
string
key_p
=
key_
+
"@fwd_p"
;
const
std
::
string
key_p
=
key_
+
"@fwd_p"
;
...
@@ -306,8 +299,8 @@ class MKLDNNHandlerT {
...
@@ -306,8 +299,8 @@ class MKLDNNHandlerT {
const
MKLDNNDeviceContext
&
dev_ctx_
;
const
MKLDNNDeviceContext
&
dev_ctx_
;
mkldnn
::
engine
engine_
;
mkldnn
::
engine
engine_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
std
::
string
key_
;
std
::
string
key_common_
;
std
::
string
key_common_
;
std
::
string
key_
;
std
::
shared_ptr
<
typename
TForward
::
primitive_desc
>
fwd_pd_
;
std
::
shared_ptr
<
typename
TForward
::
primitive_desc
>
fwd_pd_
;
std
::
shared_ptr
<
typename
TBackward
::
primitive_desc
>
bwd_pd_
;
std
::
shared_ptr
<
typename
TBackward
::
primitive_desc
>
bwd_pd_
;
};
};
...
@@ -317,15 +310,10 @@ class MKLDNNHandler {
...
@@ -317,15 +310,10 @@ class MKLDNNHandler {
public:
public:
MKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
MKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
const
std
::
string
&
base_key
)
:
dev_ctx_
(
dev_ctx
),
engine_
(
engine
),
key_common_
(
base_key
)
{
:
dev_ctx_
(
dev_ctx
),
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
!=
engine_
(
engine
),
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
key_common_
(
base_key
),
key_
=
key_common_
;
key_
(
platform
::
ExtendKeyWithThreadInfoIfNeeded
(
dev_ctx
,
base_key
))
{}
}
else
{
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
}
key_
+=
dev_ctx
.
GetKeySuffix
();
}
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireSrcMemory
(
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
const
mkldnn
::
memory
::
desc
&
md
,
void
*
ptr
)
{
...
@@ -508,8 +496,8 @@ class MKLDNNHandler {
...
@@ -508,8 +496,8 @@ class MKLDNNHandler {
protected:
protected:
const
MKLDNNDeviceContext
&
dev_ctx_
;
const
MKLDNNDeviceContext
&
dev_ctx_
;
mkldnn
::
engine
engine_
;
mkldnn
::
engine
engine_
;
std
::
string
key_
;
std
::
string
key_common_
;
std
::
string
key_common_
;
std
::
string
key_
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -524,7 +512,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
...
@@ -524,7 +512,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
binary
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
dnnl
::
binary
>
(
dev_ctx
,
engine
,
cpu_place
,
dev_ctx
,
engine
,
cpu_place
,
platform
::
CreateKey
(
platform
::
CreateKey
(
framework
::
vectorize
(
x
->
dims
()),
dev_ctx
,
framework
::
vectorize
(
x
->
dims
()),
uniq_name
+
(
algo
==
dnnl
::
algorithm
::
binary_mul
?
"M"
:
""
)))
{
uniq_name
+
(
algo
==
dnnl
::
algorithm
::
binary_mul
?
"M"
:
""
)))
{
// bradcasting combined with in-place may require
// bradcasting combined with in-place may require
auto
rankdiff
=
x
->
dims
().
size
()
-
y
->
dims
().
size
();
auto
rankdiff
=
x
->
dims
().
size
()
-
y
->
dims
().
size
();
...
@@ -627,7 +615,7 @@ class ActivationMKLDNNHandler
...
@@ -627,7 +615,7 @@ class ActivationMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
(
mkldnn
::
eltwise_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
"a"
,
algorithm
,
unique_name
))
{
platform
::
CreateKey
(
d
ev_ctx
,
d
ims
,
"a"
,
algorithm
,
unique_name
))
{
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
this
->
AcquireForwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
forward_training
,
...
@@ -645,7 +633,7 @@ class ActivationMKLDNNHandler
...
@@ -645,7 +633,7 @@ class ActivationMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
eltwise_forward
,
mkldnn
::
eltwise_backward
>
(
mkldnn
::
eltwise_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
"a"
,
algorithm
,
unique_name
))
{
platform
::
CreateKey
(
d
ev_ctx
,
d
ims
,
"a"
,
algorithm
,
unique_name
))
{
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
auto
diff_dst_md
=
platform
::
MKLDNNMemDesc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
auto
src_md
=
auto
src_md
=
...
@@ -676,7 +664,7 @@ class LRNMKLDNNHandler
...
@@ -676,7 +664,7 @@ class LRNMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
lrn_forward
,
mkldnn
::
lrn_backward
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
lrn_forward
,
mkldnn
::
lrn_backward
>
(
dev_ctx
,
mkldnn_engine
,
cpu_place
,
dev_ctx
,
mkldnn_engine
,
cpu_place
,
platform
::
CreateKey
(
framework
::
vectorize
(
input
->
dims
()),
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
input
->
dims
()),
unique_name
))
{
unique_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
())
{
const
int
n
=
ctx
.
Attr
<
int
>
(
"n"
);
const
int
n
=
ctx
.
Attr
<
int
>
(
"n"
);
...
@@ -712,7 +700,7 @@ class LRNMKLDNNHandler
...
@@ -712,7 +700,7 @@ class LRNMKLDNNHandler
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
lrn_forward
,
mkldnn
::
lrn_backward
>
(
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
lrn_forward
,
mkldnn
::
lrn_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
unique_name
))
{
platform
::
CreateKey
(
d
ev_ctx
,
d
ims
,
unique_name
))
{
auto
src_md
=
auto
src_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
auto
diff_md
=
auto
diff_md
=
...
@@ -752,7 +740,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
...
@@ -752,7 +740,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
pooling_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
pooling_forward
,
mkldnn
::
pooling_backward
>
(
mkldnn
::
pooling_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
framework
::
vectorize
(
input
->
dims
()),
platform
::
CreateKey
(
dev_ctx
,
framework
::
vectorize
(
input
->
dims
()),
framework
::
ToMKLDNNDataType
(
input
->
type
()),
framework
::
ToMKLDNNDataType
(
input
->
type
()),
unique_name
))
{
unique_name
))
{
if
(
!
this
->
isCached
())
{
if
(
!
this
->
isCached
())
{
...
@@ -861,7 +849,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
...
@@ -861,7 +849,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
pooling_forward
,
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
pooling_forward
,
mkldnn
::
pooling_backward
>
(
mkldnn
::
pooling_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
diff_src_dims
,
dt
,
unique_name
))
{
platform
::
CreateKey
(
d
ev_ctx
,
d
iff_src_dims
,
dt
,
unique_name
))
{
auto
diff_dst_md
=
mkldnn
::
memory
::
desc
(
auto
diff_dst_md
=
mkldnn
::
memory
::
desc
(
diff_dst_dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_dst_fmt
);
diff_dst_dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_dst_fmt
);
auto
diff_src_md
=
auto
diff_src_md
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录