Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
db2b6b65
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
db2b6b65
编写于
5月 14, 2020
作者:
P
pawelpiotrowicz
提交者:
GitHub
5月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Hide globals & redesign restore PR (#24279)
test=develop
上级
4a105f80
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
136 addition
and
90 deletion
+136
-90
paddle/fluid/framework/data_layout_transform.cc
paddle/fluid/framework/data_layout_transform.cc
+4
-3
paddle/fluid/framework/data_transform.cc
paddle/fluid/framework/data_transform.cc
+2
-1
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+2
-1
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+2
-2
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+10
-9
paddle/fluid/operators/controlflow/fetch_op.cc
paddle/fluid/operators/controlflow/fetch_op.cc
+4
-4
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+2
-2
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+45
-43
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+61
-21
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+4
-4
未找到文件。
paddle/fluid/framework/data_layout_transform.cc
浏览文件 @
db2b6b65
...
@@ -124,9 +124,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
...
@@ -124,9 +124,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"
);
"non-MKLDNN"
);
innerTransDataLayoutFromMKLDNN
(
in_layout
,
innerTransDataLayoutFromMKLDNN
(
paddle
::
platform
::
get_cur_paddle_data_layout
(),
in_layout
,
in
,
out
,
place
);
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_paddle_data_layout
(),
in
,
out
,
place
);
}
}
void
innerTransDataLayoutFromMKLDNN
(
DataLayout
in_layout
,
DataLayout
out_layout
,
void
innerTransDataLayoutFromMKLDNN
(
DataLayout
in_layout
,
DataLayout
out_layout
,
...
...
paddle/fluid/framework/data_transform.cc
浏览文件 @
db2b6b65
...
@@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
...
@@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
// For NHWC data we need reshape of tensors as MKL-DNN
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
// is expecting NHWC dims description order
platform
::
MatchShapeToLayout
(
&
out
,
lin
,
lout
);
platform
::
MatchShapeToLayout
(
&
out
,
lin
,
lout
);
paddle
::
platform
::
set_cur_paddle_data_layout
(
lin
);
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_paddle_data_layout
(
lin
);
out
.
set_layout
(
DataLayout
::
kMKLDNN
);
out
.
set_layout
(
DataLayout
::
kMKLDNN
);
out
.
set_format
(
out_format
);
out
.
set_format
(
out_format
);
}
else
{
}
else
{
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
db2b6b65
...
@@ -89,7 +89,8 @@ Executor::~Executor() {
...
@@ -89,7 +89,8 @@ Executor::~Executor() {
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
place_
);
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
place_
);
dev_ctx
->
ResetBlobMap
();
dev_ctx
->
ResetBlobMap
();
platform
::
set_cur_paddle_data_layout
(
paddle
::
framework
::
DataLayout
::
kNCHW
);
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_paddle_data_layout
(
paddle
::
framework
::
DataLayout
::
kNCHW
);
}
}
#endif
#endif
}
}
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
db2b6b65
...
@@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData(
if
((
tensor_in
->
layout
()
==
DataLayout
::
kMKLDNN
)
&&
if
((
tensor_in
->
layout
()
==
DataLayout
::
kMKLDNN
)
&&
(
var
->
IsType
<
LoDTensor
>
()
==
true
)
&&
(
var
->
IsType
<
LoDTensor
>
()
==
true
)
&&
(
expected_kernel_key
.
data_layout_
!=
DataLayout
::
kMKLDNN
)
&&
(
expected_kernel_key
.
data_layout_
!=
DataLayout
::
kMKLDNN
)
&&
(
paddle
::
platform
::
get_cur_paddle_data_layout
()
==
(
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
()
DataLayout
::
kNHWC
))
{
.
get_cur_paddle_data_layout
()
==
DataLayout
::
kNHWC
))
{
// Mixed execution : MKL-DNN and GPU is not supported!
// Mixed execution : MKL-DNN and GPU is not supported!
if
(
!
new_scope
)
{
if
(
!
new_scope
)
{
new_scope
=
&
scope
.
NewScope
();
new_scope
=
&
scope
.
NewScope
();
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
db2b6b65
...
@@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() {
...
@@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() {
void
AnalysisPredictor
::
MkldnnPreSet
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
)
{
void
AnalysisPredictor
::
MkldnnPreSet
(
const
std
::
vector
<
PaddleTensor
>
&
inputs
)
{
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
VLOG
(
2
)
<<
"AnalysisPredictor::Run get_cur_mkldnn_session_id="
VLOG
(
2
)
<<
"AnalysisPredictor::Run get_cur_mkldnn_session_id="
<<
platform
::
get_cur_mkldnn_session_id
();
<<
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
();
// In cache clearing mode.
// In cache clearing mode.
if
(
config_
.
mkldnn_cache_capacity_
>
0
)
{
if
(
config_
.
mkldnn_cache_capacity_
>
0
)
{
VLOG
(
2
)
<<
"In mkldnn cache clear mode."
;
VLOG
(
2
)
<<
"In mkldnn cache clear mode."
;
platform
::
set_cur_mkldnn_session_id
(
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_mkldnn_session_id
(
platform
::
kMKLDNNSessionID_CacheClearing
);
platform
::
MKLDNNDeviceContextThreadLocals
::
platform
::
set_cur_input_shape_cache_capacity
(
kMKLDNNSessionID_CacheClearing
);
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_input_shape_cache_capacity
(
config_
.
mkldnn_cache_capacity_
);
config_
.
mkldnn_cache_capacity_
);
// Set current_input_shape for caching dynamic shape.
// Set current_input_shape for caching dynamic shape.
std
::
stringstream
ss
;
std
::
stringstream
ss
;
...
@@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
...
@@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
}
}
}
}
VLOG
(
2
)
<<
"Set input shape="
<<
ss
.
str
();
VLOG
(
2
)
<<
"Set input shape="
<<
ss
.
str
();
platform
::
set_cur_input_shape_str
(
ss
.
str
());
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_input_shape_str
(
ss
.
str
());
}
}
#endif
#endif
}
}
...
@@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() {
...
@@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() {
CHECK_LE
(
shape_blob_size
,
CHECK_LE
(
shape_blob_size
,
static_cast
<
size_t
>
(
config_
.
mkldnn_cache_capacity_
));
static_cast
<
size_t
>
(
config_
.
mkldnn_cache_capacity_
));
}
}
paddle
::
platform
::
set_cur_mkldnn_session_id
(
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_mkldnn_session_id
(
platform
::
kMKLDNNSessionID_Default
);
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
);
platform
::
set_cur_input_shape_cache_capacity
(
0
);
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_input_shape_cache_capacity
(
0
);
platform
::
set_cur_input_shape_str
(
""
);
platform
::
MKLDNNDeviceContext
::
tls
().
set_cur_input_shape_str
(
""
);
}
}
#endif
#endif
}
}
...
...
paddle/fluid/operators/controlflow/fetch_op.cc
浏览文件 @
db2b6b65
...
@@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item,
...
@@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item,
// Convert to desired Paddle layout, apart from grads of filter
// Convert to desired Paddle layout, apart from grads of filter
// as params are not a subject to paddle's data_format
// as params are not a subject to paddle's data_format
framework
::
innerTransDataLayoutFromMKLDNN
(
framework
::
innerTransDataLayoutFromMKLDNN
(
src_item
.
layout
(),
src_item
.
layout
(),
fetch_var_name
==
framework
::
GradVarName
(
"Filter"
)
fetch_var_name
==
framework
::
GradVarName
(
"Filter"
)
?
framework
::
DataLayout
::
kNCHW
?
framework
::
DataLayout
::
kNCHW
:
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
()
:
paddle
::
platform
::
get_cur_paddle_data_layout
(),
.
get_cur_paddle_data_layout
(),
src_item
,
&
out
,
platform
::
CPUPlace
());
src_item
,
&
out
,
platform
::
CPUPlace
());
TensorCopySync
(
out
,
platform
::
CPUPlace
(),
dst_item
);
TensorCopySync
(
out
,
platform
::
CPUPlace
(),
dst_item
);
}
else
{
}
else
{
...
...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
db2b6b65
...
@@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear
// are merged/unified, this will disappear
std
::
string
key_tid
=
""
;
std
::
string
key_tid
=
""
;
if
(
platform
::
get_cur_mkldnn_session_id
()
==
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
==
platform
::
kMKLDNNSessionID_Default
)
{
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
key_tid
=
"-t:"
+
platform
::
ThreadIDasStr
();
key_tid
=
"-t:"
+
platform
::
ThreadIDasStr
();
}
}
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
db2b6b65
...
@@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
...
@@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
p_mutex_
.
reset
(
new
std
::
mutex
());
p_mutex_
.
reset
(
new
std
::
mutex
());
}
}
namespace
{
MKLDNNDeviceContextThreadLocals
::
Body
::
Body
()
{
// Current mkldnn session id.
cur_mkldnn_session_id
=
kMKLDNNSessionID_Default
;
thread_local
size_t
cur_mkldnn_session_id
=
kMKLDNNSessionID_Default
;
cur_input_shape_str
=
""
;
// Current data input shape string.
cur_input_shape_cache_capacity
=
1
;
// - For fixed-shape, it's a null string in default.
cur_paddle_data_layout
=
paddle
::
framework
::
DataLayout
::
kNCHW
;
// - For dynamic-shape, it's user specific.
}
thread_local
std
::
string
cur_input_shape_str
=
""
;
// the cache capacity of different input shapes for MKLDNN.
void
MKLDNNDeviceContextThreadLocals
::
Body
::
set_cur_mkldnn_session_id
(
// Default 1 means fixed input shape, not dynamic shape.
size_t
sid
)
{
thread_local
int
cur_input_shape_cache_capacity
=
1
;
cur_mkldnn_session_id
=
sid
;
// Recently registered data_format. This is needed to
}
// know for converting MKL-DNN Tensor to non MKL-DNN
size_t
MKLDNNDeviceContextThreadLocals
::
Body
::
get_cur_mkldnn_session_id
(
void
)
{
thread_local
paddle
::
framework
::
DataLayout
cur_paddle_data_layout
=
return
cur_mkldnn_session_id
;
paddle
::
framework
::
DataLayout
::
kNCHW
;
}
}
// namespace
void
MKLDNNDeviceContextThreadLocals
::
Body
::
set_cur_input_shape_str
(
void
set_cur_mkldnn_session_id
(
size_t
sid
)
{
cur_mkldnn_session_id
=
sid
;
}
std
::
string
input_shape_str
)
{
size_t
get_cur_mkldnn_session_id
(
void
)
{
return
cur_mkldnn_session_id
;
}
void
set_cur_input_shape_str
(
std
::
string
input_shape_str
)
{
cur_input_shape_str
=
input_shape_str
;
cur_input_shape_str
=
input_shape_str
;
}
}
void
set_cur_input_shape_cache_capacity
(
int
input_shape_cache_capacity
)
{
void
MKLDNNDeviceContextThreadLocals
::
Body
::
set_cur_input_shape_cache_capacity
(
int
input_shape_cache_capacity
)
{
cur_input_shape_cache_capacity
=
input_shape_cache_capacity
;
cur_input_shape_cache_capacity
=
input_shape_cache_capacity
;
}
}
void
set_cur_paddle_data_layout
(
framework
::
DataLayout
dl
)
{
void
MKLDNNDeviceContextThreadLocals
::
Body
::
set_cur_paddle_data_layout
(
framework
::
DataLayout
dl
)
{
cur_paddle_data_layout
=
dl
;
cur_paddle_data_layout
=
dl
;
}
}
framework
::
DataLayout
get_cur_paddle_data_layout
(
void
)
{
framework
::
DataLayout
MKLDNNDeviceContextThreadLocals
::
Body
::
get_cur_paddle_data_layout
(
void
)
{
return
cur_paddle_data_layout
;
return
cur_paddle_data_layout
;
}
}
...
@@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() const {
...
@@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() const {
}
}
size_t
MKLDNNDeviceContext
::
GetShapeBlobSize
()
const
{
size_t
MKLDNNDeviceContext
::
GetShapeBlobSize
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
p_mutex_
);
std
::
lock_guard
<
decltype
(
*
p_mutex_
)
>
lock
(
*
p_mutex_
);
BlobMap
*
pMap
=
p_blobmap_
.
get
();
BlobMap
*
pMap
=
p_blobmap_
.
get
();
auto
map_it
=
pMap
->
find
(
cur_mkldnn_session_id
);
auto
map_it
=
pMap
->
find
(
tls
().
cur_mkldnn_session_id
);
if
(
map_it
==
pMap
->
end
())
{
if
(
map_it
==
pMap
->
end
())
{
LOG
(
FATAL
)
<<
"MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
LOG
(
FATAL
)
<<
"MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
<<
cur_mkldnn_session_id
;
<<
tls
().
cur_mkldnn_session_id
;
}
}
return
map_it
->
second
->
size
();
return
map_it
->
second
->
size
();
}
}
void
MKLDNNDeviceContext
::
SetBlob
(
const
std
::
string
&
name
,
void
MKLDNNDeviceContext
::
SetBlob
(
const
std
::
string
&
name
,
std
::
shared_ptr
<
void
>
data
)
const
{
BlobPtr_t
<
void
>
data
)
const
{
BlobMap
*
pMap
=
p_blobmap_
.
get
();
BlobMap
*
pMap
=
p_blobmap_
.
get
();
std
::
shared_ptr
<
ShapeBlob
>
sBlob
=
nullptr
;
BlobPtr_t
<
ShapeBlob
>
sBlob
=
nullptr
;
std
::
shared_ptr
<
KeyBlob
>
pBlob
=
nullptr
;
BlobPtr_t
<
KeyBlob
>
pBlob
=
nullptr
;
int
sid
=
platform
::
get_cur_mkldnn_session_id
();
int
sid
=
tls
().
get_cur_mkldnn_session_id
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
p_mutex_
);
std
::
lock_guard
<
decltype
(
*
p_mutex_
)
>
lock
(
*
p_mutex_
);
// Find ShapeBlob for current mkldnn session id.
// Find ShapeBlob for current mkldnn session id.
auto
map_it
=
pMap
->
find
(
sid
);
auto
map_it
=
pMap
->
find
(
sid
);
if
(
map_it
==
pMap
->
end
())
{
if
(
map_it
==
pMap
->
end
())
{
// 1st time to set blob in current thread
// 1st time to set blob in current thread
sBlob
=
std
::
shared_ptr
<
ShapeBlob
>
(
new
ShapeBlob
()
);
sBlob
=
std
::
make_shared
<
ShapeBlob
>
(
);
(
*
pMap
)[
sid
]
=
sBlob
;
(
*
pMap
)[
sid
]
=
sBlob
;
VLOG
(
2
)
<<
"SetBlob: sid="
<<
sid
<<
", add new sid
\n
"
;
VLOG
(
2
)
<<
"SetBlob: sid="
<<
sid
<<
", add new sid
\n
"
;
}
else
{
}
else
{
...
@@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
...
@@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
}
}
// Find KeyBlob for current input shape
// Find KeyBlob for current input shape
auto
key_it
=
sBlob
->
find
(
cur_input_shape_str
);
auto
key_it
=
sBlob
->
find
(
tls
().
cur_input_shape_str
);
if
(
key_it
==
sBlob
->
end
())
{
if
(
key_it
==
sBlob
->
end
())
{
// In cache clearing mode, cur_input_shape_cache_capacity defines
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
// max pblob capacity
if
((
static_cast
<
size_t
>
(
sid
)
==
kMKLDNNSessionID_CacheClearing
)
&&
if
((
static_cast
<
size_t
>
(
sid
)
==
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_CacheClearing
)
&&
sBlob
->
size
()
&&
sBlob
->
size
()
&&
(
sBlob
->
size
()
>=
(
sBlob
->
size
()
>=
static_cast
<
size_t
>
(
cur_input_shape_cache_capacity
)))
{
static_cast
<
size_t
>
(
tls
().
cur_input_shape_cache_capacity
)))
{
VLOG
(
2
)
<<
"sid="
<<
sid
VLOG
(
2
)
<<
"sid="
<<
sid
<<
", remove all blobs of shape: "
<<
sBlob
->
begin
()
->
first
;
<<
", remove all blobs of shape: "
<<
sBlob
->
begin
()
->
first
;
sBlob
->
erase
(
sBlob
->
begin
()
->
first
);
sBlob
->
erase
(
sBlob
->
begin
()
->
first
);
}
}
pBlob
=
std
::
shared_ptr
<
KeyBlob
>
(
new
KeyBlob
()
);
pBlob
=
std
::
make_shared
<
KeyBlob
>
(
);
(
*
sBlob
)[
cur_input_shape_str
]
=
pBlob
;
(
*
sBlob
)[
tls
().
cur_input_shape_str
]
=
pBlob
;
}
else
{
}
else
{
pBlob
=
key_it
->
second
;
pBlob
=
key_it
->
second
;
}
}
...
@@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
...
@@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
return
;
return
;
}
}
std
::
shared_ptr
<
void
>
MKLDNNDeviceContext
::
GetBlob
(
MKLDNNDeviceContext
::
BlobPtr_t
<
void
>
MKLDNNDeviceContext
::
GetBlob
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
BlobMap
*
pMap
=
p_blobmap_
.
get
();
BlobMap
*
pMap
=
p_blobmap_
.
get
();
std
::
shared_ptr
<
ShapeBlob
>
sBlob
=
nullptr
;
BlobPtr_t
<
ShapeBlob
>
sBlob
=
nullptr
;
std
::
shared_ptr
<
KeyBlob
>
pBlob
=
nullptr
;
BlobPtr_t
<
KeyBlob
>
pBlob
=
nullptr
;
int
sid
=
platform
::
get_cur_mkldnn_session_id
();
int
sid
=
tls
().
get_cur_mkldnn_session_id
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
p_mutex_
);
std
::
lock_guard
<
decltype
(
*
p_mutex_
)
>
lock
(
*
p_mutex_
);
// Find ShapeBlob for current mkldnn session id firstly
// Find ShapeBlob for current mkldnn session id firstly
auto
map_it
=
pMap
->
find
(
sid
);
auto
map_it
=
pMap
->
find
(
sid
);
...
@@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
...
@@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
sBlob
=
map_it
->
second
;
sBlob
=
map_it
->
second
;
// Find KeyBlob for current input shape secondly
// Find KeyBlob for current input shape secondly
auto
sBlob_it
=
sBlob
->
find
(
cur_input_shape_str
);
auto
sBlob_it
=
sBlob
->
find
(
tls
().
cur_input_shape_str
);
if
(
sBlob_it
==
sBlob
->
end
())
{
if
(
sBlob_it
==
sBlob
->
end
())
{
VLOG
(
2
)
<<
"GetBlob: sid="
<<
cur_input_shape_str
VLOG
(
2
)
<<
"GetBlob: sid="
<<
tls
().
cur_input_shape_str
<<
", miss input_shape_str
\n
"
;
<<
", miss input_shape_str
\n
"
;
return
nullptr
;
return
nullptr
;
}
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
db2b6b65
...
@@ -421,30 +421,66 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
...
@@ -421,30 +421,66 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif
#endif
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
class
MKLDNNDeviceContextThreadLocals
{
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// default mkldnn session id
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
typedef
MKLDNNDeviceContextThreadLocals
self
;
// Where:
struct
Body
{
using
KeyBlob
=
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
void
>>
;
size_t
cur_mkldnn_session_id
;
using
ShapeBlob
=
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
KeyBlob
>>
;
// Current data input shape string.
using
BlobMap
=
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
ShapeBlob
>>
;
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
// default mkldnn session id
std
::
string
cur_input_shape_str
;
constexpr
size_t
kMKLDNNSessionID_Default
=
0
;
// the cache capacity of different input shapes for MKLDNN.
// mkldnn session id for cache clearing mode
// Default 1 means fixed input shape, not dynamic shape.
constexpr
size_t
kMKLDNNSessionID_CacheClearing
=
-
1
;
int
cur_input_shape_cache_capacity
;
// Recently registered data_format. This is needed to
void
set_cur_mkldnn_session_id
(
size_t
);
// know for converting MKL-DNN Tensor to non MKL-DNN
size_t
get_cur_mkldnn_session_id
(
void
);
paddle
::
framework
::
DataLayout
cur_paddle_data_layout
;
void
set_cur_input_shape_str
(
std
::
string
input_shape_str
);
void
set_cur_input_shape_cache_capacity
(
int
input_shape_cache_capacity
);
Body
();
void
set_cur_paddle_data_layout
(
framework
::
DataLayout
);
void
set_cur_mkldnn_session_id
(
size_t
sid
);
framework
::
DataLayout
get_cur_paddle_data_layout
(
void
);
size_t
get_cur_mkldnn_session_id
(
void
);
void
set_cur_input_shape_str
(
std
::
string
input_shape_str
);
void
set_cur_input_shape_cache_capacity
(
int
input_shape_cache_capacity
);
void
set_cur_paddle_data_layout
(
framework
::
DataLayout
dl
);
framework
::
DataLayout
get_cur_paddle_data_layout
(
void
);
};
MKLDNNDeviceContextThreadLocals
()
=
default
;
MKLDNNDeviceContextThreadLocals
(
const
MKLDNNDeviceContextThreadLocals
&
c
)
=
delete
;
public:
// default mkldnn session id
static
constexpr
size_t
kMKLDNNSessionID_Default
=
0
;
// mkldnn session id for cache clearing mode
static
constexpr
size_t
kMKLDNNSessionID_CacheClearing
=
-
1
;
static
Body
&
fetch
()
{
thread_local
Body
b
;
return
b
;
}
};
class
MKLDNNDeviceContext
:
public
CPUDeviceContext
{
class
MKLDNNDeviceContext
:
public
CPUDeviceContext
{
public:
public:
template
<
class
T
>
using
BlobPtr_t
=
std
::
shared_ptr
<
T
>
;
template
<
class
P1
,
class
P2
>
using
umap_value_smart_t
=
std
::
unordered_map
<
P1
,
BlobPtr_t
<
P2
>>
;
template
<
class
T
>
using
umap_key_string_t
=
umap_value_smart_t
<
std
::
string
,
T
>
;
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
using
KeyBlob
=
umap_key_string_t
<
void
>
;
using
ShapeBlob
=
umap_key_string_t
<
KeyBlob
>
;
using
BlobMap
=
umap_value_smart_t
<
int
,
ShapeBlob
>
;
explicit
MKLDNNDeviceContext
(
CPUPlace
place
);
explicit
MKLDNNDeviceContext
(
CPUPlace
place
);
/* \brief Get the active engine */
/* \brief Get the active engine */
...
@@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
...
@@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Find a saved blob. Return nullptr if not found
// Find a saved blob. Return nullptr if not found
std
::
shared_ptr
<
void
>
GetBlob
(
const
std
::
string
&
name
)
const
;
std
::
shared_ptr
<
void
>
GetBlob
(
const
std
::
string
&
name
)
const
;
static
auto
tls
()
->
decltype
(
MKLDNNDeviceContextThreadLocals
::
fetch
())
{
return
MKLDNNDeviceContextThreadLocals
::
fetch
();
}
private:
private:
mkldnn
::
engine
engine_
;
mkldnn
::
engine
engine_
;
std
::
shared_ptr
<
BlobMap
>
p_blobmap_
;
std
::
shared_ptr
<
BlobMap
>
p_blobmap_
;
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
db2b6b65
...
@@ -42,8 +42,8 @@ class MKLDNNHandlerT {
...
@@ -42,8 +42,8 @@ class MKLDNNHandlerT {
key_common_
(
base_key
),
key_common_
(
base_key
),
fwd_pd_
(
nullptr
),
fwd_pd_
(
nullptr
),
bwd_pd_
(
nullptr
)
{
bwd_pd_
(
nullptr
)
{
if
(
platform
::
get_cur_mkldnn_session_id
()
!=
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
!=
platform
::
kMKLDNNSessionID_Default
)
{
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
key_
=
key_common_
;
key_
=
key_common_
;
}
else
{
}
else
{
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
...
@@ -177,8 +177,8 @@ class MKLDNNHandler {
...
@@ -177,8 +177,8 @@ class MKLDNNHandler {
MKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
MKLDNNHandler
(
const
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
const
std
::
string
&
base_key
)
:
dev_ctx_
(
dev_ctx
),
engine_
(
engine
),
key_common_
(
base_key
)
{
:
dev_ctx_
(
dev_ctx
),
engine_
(
engine
),
key_common_
(
base_key
)
{
if
(
platform
::
get_cur_mkldnn_session_id
()
!=
if
(
platform
::
MKLDNNDeviceContext
::
tls
().
get_cur_mkldnn_session_id
()
!=
platform
::
kMKLDNNSessionID_Default
)
{
platform
::
MKLDNNDeviceContextThreadLocals
::
kMKLDNNSessionID_Default
)
{
key_
=
key_common_
;
key_
=
key_common_
;
}
else
{
}
else
{
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
key_
=
key_common_
+
"-t:"
+
ThreadIDasStr
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录