PaddlePaddle / Paddle
Commit db2b6b65 (unverified)
Authored May 14, 2020 by pawelpiotrowicz; committed via GitHub on May 14, 2020.
Parent commit: 4a105f80

Hide globals & redesign restore PR (#24279)

test=develop
Showing 10 changed files with 136 additions and 90 deletions (+136, -90).
Changed files:
  paddle/fluid/framework/data_layout_transform.cc    +4   -3
  paddle/fluid/framework/data_transform.cc           +2   -1
  paddle/fluid/framework/executor.cc                 +2   -1
  paddle/fluid/framework/operator.cc                 +2   -2
  paddle/fluid/inference/api/analysis_predictor.cc   +10  -9
  paddle/fluid/operators/controlflow/fetch_op.cc     +4   -4
  paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc    +2   -2
  paddle/fluid/platform/device_context.cc            +45  -43
  paddle/fluid/platform/device_context.h             +61  -21
  paddle/fluid/platform/mkldnn_reuse.h               +4   -4
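Every hunk below applies the same substitution: the free functions and thread_local globals that paddle::platform used to expose (get_cur_mkldnn_session_id, set_cur_paddle_data_layout, kMKLDNNSessionID_Default, and friends) are hidden inside a new MKLDNNDeviceContextThreadLocals class, and callers reach the per-thread state through MKLDNNDeviceContext::tls(). The following is a minimal sketch of that pattern, simplified from the device_context.h changes in this commit rather than a verbatim copy of them:

// Sketch: per-thread MKL-DNN state hidden behind a Body struct and a tls() accessor.
// Simplified; the real definitions are in the paddle/fluid/platform/device_context.h diff below.
#include <cstddef>
#include <string>

class MKLDNNDeviceContextThreadLocals {
 public:
  struct Body {
    size_t cur_mkldnn_session_id = 0;        // default session
    std::string cur_input_shape_str;         // "" means fixed-shape mode
    int cur_input_shape_cache_capacity = 1;  // 1 means fixed input shape

    void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
    size_t get_cur_mkldnn_session_id() { return cur_mkldnn_session_id; }
  };

  static constexpr size_t kMKLDNNSessionID_Default = 0;
  static constexpr size_t kMKLDNNSessionID_CacheClearing = static_cast<size_t>(-1);

  // One Body per thread, constructed lazily on first use.
  static Body& fetch() {
    thread_local Body b;
    return b;
  }
};

struct MKLDNNDeviceContext {
  // Callers no longer touch globals; they ask the device context for its thread locals.
  static MKLDNNDeviceContextThreadLocals::Body& tls() {
    return MKLDNNDeviceContextThreadLocals::fetch();
  }
};

With this in place, a call that used to read platform::get_cur_mkldnn_session_id() now reads platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id(); that one-for-one rewrite accounts for most of the +136/-90 lines below.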
paddle/fluid/framework/data_layout_transform.cc
@@ -124,9 +124,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                     "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
                     "non-MKLDNN");
 
-  innerTransDataLayoutFromMKLDNN(in_layout,
-                                 paddle::platform::get_cur_paddle_data_layout(),
-                                 in, out, place);
+  innerTransDataLayoutFromMKLDNN(
+      in_layout,
+      paddle::platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout(),
+      in, out, place);
 }
 
 void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
paddle/fluid/framework/data_transform.cc
@@ -59,7 +59,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
       // For NHWC data we need reshape of tensors as MKL-DNN
       // is expecting NHWC dims description order
       platform::MatchShapeToLayout(&out, lin, lout);
-      paddle::platform::set_cur_paddle_data_layout(lin);
+      paddle::platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
+          lin);
     }
     out.set_layout(DataLayout::kMKLDNN);
     out.set_format(out_format);
   } else {
paddle/fluid/framework/executor.cc
@@ -89,7 +89,8 @@ Executor::~Executor() {
     platform::MKLDNNDeviceContext* dev_ctx =
         (platform::MKLDNNDeviceContext*)pool.Get(place_);
     dev_ctx->ResetBlobMap();
-    platform::set_cur_paddle_data_layout(paddle::framework::DataLayout::kNCHW);
+    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
+        paddle::framework::DataLayout::kNCHW);
   }
 #endif
 }
paddle/fluid/framework/operator.cc
@@ -1155,8 +1155,8 @@ Scope* OperatorWithKernel::PrepareData(
       if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
           (var->IsType<LoDTensor>() == true) &&
           (expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
-          (paddle::platform::get_cur_paddle_data_layout() ==
-           DataLayout::kNHWC)) {
+          (paddle::platform::MKLDNNDeviceContext::tls()
+               .get_cur_paddle_data_layout() == DataLayout::kNHWC)) {
         // Mixed execution : MKL-DNN and GPU is not supported!
         if (!new_scope) {
           new_scope = &scope.NewScope();
paddle/fluid/inference/api/analysis_predictor.cc
@@ -244,13 +244,14 @@ bool AnalysisPredictor::PrepareExecutor() {
 void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
 #ifdef PADDLE_WITH_MKLDNN
   VLOG(2) << "AnalysisPredictor::Run get_cur_mkldnn_session_id="
-          << platform::get_cur_mkldnn_session_id();
+          << platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id();
   // In cache clearing mode.
   if (config_.mkldnn_cache_capacity_ > 0) {
     VLOG(2) << "In mkldnn cache clear mode.";
-    platform::set_cur_mkldnn_session_id(
-        platform::kMKLDNNSessionID_CacheClearing);
-    platform::set_cur_input_shape_cache_capacity(
+    platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
+        platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing);
+    platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(
         config_.mkldnn_cache_capacity_);
     // Set current_input_shape for caching dynamic shape.
     std::stringstream ss;
@@ -260,7 +261,7 @@ void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
     }
   }
   VLOG(2) << "Set input shape=" << ss.str();
-  platform::set_cur_input_shape_str(ss.str());
+  platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(ss.str());
 }
 #endif
 }
@@ -277,10 +278,10 @@ void AnalysisPredictor::MkldnnPostReset() {
     CHECK_LE(shape_blob_size,
              static_cast<size_t>(config_.mkldnn_cache_capacity_));
   }
-  paddle::platform::set_cur_mkldnn_session_id(
-      platform::kMKLDNNSessionID_Default);
-  platform::set_cur_input_shape_cache_capacity(0);
-  platform::set_cur_input_shape_str("");
+  paddle::platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
+      platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default);
+  platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(0);
+  platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str("");
 }
 #endif
 }
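The two MkldnnPreSet/MkldnnPostReset hunks above show the intended lifecycle: when config_.mkldnn_cache_capacity_ > 0, the predictor flips its own thread's state into cache-clearing mode before a run and restores the defaults afterwards. Below is a hedged sketch of that lifecycle against the new interface; RunWithCacheClearing and its arguments are illustrative, not part of Paddle, and only the tls() calls mirror the diff:

// Illustrative wrapper; assumes it is compiled inside the Paddle tree with
// PADDLE_WITH_MKLDNN enabled so the tls() interface from this commit exists.
#include <string>
#include "paddle/fluid/platform/device_context.h"

void RunWithCacheClearing(int cache_capacity, const std::string& input_shape_key) {
  namespace platform = paddle::platform;
  auto& tls = platform::MKLDNNDeviceContext::tls();

  // Pre-set: mark this thread's session as cache-clearing and record the
  // current input shape so cached primitives are keyed per shape.
  tls.set_cur_mkldnn_session_id(
      platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing);
  tls.set_cur_input_shape_cache_capacity(cache_capacity);
  tls.set_cur_input_shape_str(input_shape_key);

  // ... run the executor / predictor on this thread ...

  // Post-reset: restore the defaults so later runs on this thread start clean.
  tls.set_cur_mkldnn_session_id(
      platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default);
  tls.set_cur_input_shape_cache_capacity(0);
  tls.set_cur_input_shape_str("");
}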
paddle/fluid/operators/controlflow/fetch_op.cc
@@ -34,10 +34,10 @@ static void DataCopy(const framework::LoDTensor &src_item,
       // Convert to desired Paddle layout, apart from grads of filter
       // as params are not a subject to paddle's data_format
       framework::innerTransDataLayoutFromMKLDNN(
-          src_item.layout(),
-          fetch_var_name == framework::GradVarName("Filter")
-              ? framework::DataLayout::kNCHW
-              : paddle::platform::get_cur_paddle_data_layout(),
+          src_item.layout(),
+          fetch_var_name == framework::GradVarName("Filter")
+              ? framework::DataLayout::kNCHW
+              : paddle::platform::MKLDNNDeviceContext::tls()
+                    .get_cur_paddle_data_layout(),
           src_item, &out, platform::CPUPlace());
       TensorCopySync(out, platform::CPUPlace(), dst_item);
     } else {
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -446,8 +446,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     // of conv int8 mkl-dnn. Once conv fp32 and conv int8
     // are merged/unified, this will disappear
     std::string key_tid = "";
-    if (platform::get_cur_mkldnn_session_id() ==
-        platform::kMKLDNNSessionID_Default) {
+    if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
+        platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
       key_tid = "-t:" + platform::ThreadIDasStr();
     }
paddle/fluid/platform/device_context.cc
@@ -375,36 +375,37 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
   p_mutex_.reset(new std::mutex());
 }
 
-namespace {
-// Current mkldnn session id.
-thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default;
-// Current data input shape string.
-// - For fixed-shape, it's a null string in default.
-// - For dynamic-shape, it's user specific.
-thread_local std::string cur_input_shape_str = "";
-// the cache capacity of different input shapes for MKLDNN.
-// Default 1 means fixed input shape, not dynamic shape.
-thread_local int cur_input_shape_cache_capacity = 1;
-// Recently registered data_format. This is needed to
-// know for converting MKL-DNN Tensor to non MKL-DNN
-thread_local paddle::framework::DataLayout cur_paddle_data_layout =
-    paddle::framework::DataLayout::kNCHW;
-}  // namespace
-
-void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; }
-size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; }
-void set_cur_input_shape_str(std::string input_shape_str) {
+MKLDNNDeviceContextThreadLocals::Body::Body() {
+  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
+  cur_input_shape_str = "";
+  cur_input_shape_cache_capacity = 1;
+  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
+}
+
+void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
+    size_t sid) {
+  cur_mkldnn_session_id = sid;
+}
+size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
+  return cur_mkldnn_session_id;
+}
+void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
+    std::string input_shape_str) {
   cur_input_shape_str = input_shape_str;
 }
-void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
+void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
+    int input_shape_cache_capacity) {
   cur_input_shape_cache_capacity = input_shape_cache_capacity;
 }
-void set_cur_paddle_data_layout(framework::DataLayout dl) {
+void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
+    framework::DataLayout dl) {
   cur_paddle_data_layout = dl;
 }
-framework::DataLayout get_cur_paddle_data_layout(void) {
+framework::DataLayout
+MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
   return cur_paddle_data_layout;
 }
 
@@ -414,32 +415,32 @@ void MKLDNNDeviceContext::ResetBlobMap() const {
 }
 
 size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
-  std::lock_guard<std::mutex> lock(*p_mutex_);
+  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
   BlobMap* pMap = p_blobmap_.get();
-  auto map_it = pMap->find(cur_mkldnn_session_id);
+  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
   if (map_it == pMap->end()) {
     LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : "
-               << cur_mkldnn_session_id;
+               << tls().cur_mkldnn_session_id;
   }
   return map_it->second->size();
 }
 
 void MKLDNNDeviceContext::SetBlob(const std::string& name,
-                                  std::shared_ptr<void> data) const {
+                                  BlobPtr_t<void> data) const {
   BlobMap* pMap = p_blobmap_.get();
-  std::shared_ptr<ShapeBlob> sBlob = nullptr;
-  std::shared_ptr<KeyBlob> pBlob = nullptr;
+  BlobPtr_t<ShapeBlob> sBlob = nullptr;
+  BlobPtr_t<KeyBlob> pBlob = nullptr;
 
-  int sid = platform::get_cur_mkldnn_session_id();
+  int sid = tls().get_cur_mkldnn_session_id();
 
-  std::lock_guard<std::mutex> lock(*p_mutex_);
+  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
 
   // Find ShapeBlob for current mkldnn session id.
   auto map_it = pMap->find(sid);
 
   if (map_it == pMap->end()) {
     // 1st time to set blob in current thread
-    sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob());
+    sBlob = std::make_shared<ShapeBlob>();
     (*pMap)[sid] = sBlob;
     VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
   } else {
 
@@ -447,21 +448,22 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
   }
 
   // Find KeyBlob for current input shape
-  auto key_it = sBlob->find(cur_input_shape_str);
+  auto key_it = sBlob->find(tls().cur_input_shape_str);
 
   if (key_it == sBlob->end()) {
     // In cache clearing mode, cur_input_shape_cache_capacity defines
     // max pblob capacity
-    if ((static_cast<size_t>(sid) == kMKLDNNSessionID_CacheClearing) &&
+    if ((static_cast<size_t>(sid) ==
+         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
         sBlob->size() &&
         (sBlob->size() >=
-         static_cast<size_t>(cur_input_shape_cache_capacity))) {
+         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
       VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
-    pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
-    (*sBlob)[cur_input_shape_str] = pBlob;
+    pBlob = std::make_shared<KeyBlob>();
+    (*sBlob)[tls().cur_input_shape_str] = pBlob;
   } else {
     pBlob = key_it->second;
   }
 
@@ -478,15 +480,15 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
   return;
 }
 
-std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
+MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
     const std::string& name) const {
   BlobMap* pMap = p_blobmap_.get();
-  std::shared_ptr<ShapeBlob> sBlob = nullptr;
-  std::shared_ptr<KeyBlob> pBlob = nullptr;
+  BlobPtr_t<ShapeBlob> sBlob = nullptr;
+  BlobPtr_t<KeyBlob> pBlob = nullptr;
 
-  int sid = platform::get_cur_mkldnn_session_id();
+  int sid = tls().get_cur_mkldnn_session_id();
 
-  std::lock_guard<std::mutex> lock(*p_mutex_);
+  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
 
   // Find ShapeBlob for current mkldnn session id firstly
   auto map_it = pMap->find(sid);
 
@@ -497,9 +499,9 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
   sBlob = map_it->second;
 
   // Find KeyBlob for current input shape secondly
-  auto sBlob_it = sBlob->find(cur_input_shape_str);
+  auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
   if (sBlob_it == sBlob->end()) {
-    VLOG(2) << "GetBlob: sid=" << cur_input_shape_str
+    VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
             << ", miss input_shape_str\n";
     return nullptr;
   }
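For orientation, SetBlob and GetBlob above walk the same three-level cache as before, now addressed through tls(): a BlobMap keyed by session id holds one ShapeBlob per input-shape string, and each ShapeBlob holds named blobs. The sketch below is a simplified view of that lookup; the alias definitions are copied from the old header, while the new header (next diff) expresses the same maps via BlobPtr_t and umap_* templates:

// Simplified view of the nested MKL-DNN blob cache that SetBlob/GetBlob navigate.
#include <memory>
#include <string>
#include <unordered_map>

using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;      // blob name -> blob
using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>; // input shape -> KeyBlob
using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;         // session id -> ShapeBlob

std::shared_ptr<void> Lookup(const BlobMap& cache, int sid,
                             const std::string& shape_str,
                             const std::string& name) {
  auto s = cache.find(sid);
  if (s == cache.end()) return nullptr;       // nothing cached for this session yet
  auto k = s->second->find(shape_str);
  if (k == s->second->end()) return nullptr;  // this input shape not cached yet
  auto b = k->second->find(name);
  return b == k->second->end() ? nullptr : b->second;
}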
paddle/fluid/platform/device_context.h
@@ -421,30 +421,66 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
-// Following three maps are used to cache MKLDNN primitives.
-// There relations are:
-// - BlobMap = Map<cur_thread_id, ShapeBlob>
-// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
-// - KeyBlob = Map<blob_name, blob>
-// Where:
-using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
-using ShapeBlob = std::unordered_map<std::string, std::shared_ptr<KeyBlob>>;
-using BlobMap = std::unordered_map<int, std::shared_ptr<ShapeBlob>>;
-
-// default mkldnn session id
-constexpr size_t kMKLDNNSessionID_Default = 0;
-// mkldnn session id for cache clearing mode
-constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
-
-void set_cur_mkldnn_session_id(size_t);
-size_t get_cur_mkldnn_session_id(void);
-void set_cur_input_shape_str(std::string input_shape_str);
-void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
-void set_cur_paddle_data_layout(framework::DataLayout);
-framework::DataLayout get_cur_paddle_data_layout(void);
+
+class MKLDNNDeviceContextThreadLocals {
+  // default mkldnn session id
+
+  typedef MKLDNNDeviceContextThreadLocals self;
+  struct Body {
+    size_t cur_mkldnn_session_id;
+    // Current data input shape string.
+    // - For fixed-shape, it's a null string in default.
+    // - For dynamic-shape, it's user specific.
+    std::string cur_input_shape_str;
+    // the cache capacity of different input shapes for MKLDNN.
+    // Default 1 means fixed input shape, not dynamic shape.
+    int cur_input_shape_cache_capacity;
+    // Recently registered data_format. This is needed to
+    // know for converting MKL-DNN Tensor to non MKL-DNN
+    paddle::framework::DataLayout cur_paddle_data_layout;
+
+    Body();
+    void set_cur_mkldnn_session_id(size_t sid);
+    size_t get_cur_mkldnn_session_id(void);
+    void set_cur_input_shape_str(std::string input_shape_str);
+    void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
+    void set_cur_paddle_data_layout(framework::DataLayout dl);
+    framework::DataLayout get_cur_paddle_data_layout(void);
+  };
+  MKLDNNDeviceContextThreadLocals() = default;
+  MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
+      delete;
+
+ public:
+  // default mkldnn session id
+  static constexpr size_t kMKLDNNSessionID_Default = 0;
+  // mkldnn session id for cache clearing mode
+  static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
+  static Body& fetch() {
+    thread_local Body b;
+    return b;
+  }
+};
 
 class MKLDNNDeviceContext : public CPUDeviceContext {
  public:
+  template <class T>
+  using BlobPtr_t = std::shared_ptr<T>;
+  template <class P1, class P2>
+  using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
+  template <class T>
+  using umap_key_string_t = umap_value_smart_t<std::string, T>;
+
+  // Following three maps are used to cache MKLDNN primitives.
+  // There relations are:
+  // - BlobMap = Map<cur_thread_id, ShapeBlob>
+  // - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
+  // - KeyBlob = Map<blob_name, blob>
+  using KeyBlob = umap_key_string_t<void>;
+  using ShapeBlob = umap_key_string_t<KeyBlob>;
+  using BlobMap = umap_value_smart_t<int, ShapeBlob>;
+
   explicit MKLDNNDeviceContext(CPUPlace place);
 
   /* \brief Get the active engine */
@@ -462,6 +498,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   // Find a saved blob. Return nullptr if not found
   std::shared_ptr<void> GetBlob(const std::string& name) const;
 
+  static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
+    return MKLDNNDeviceContextThreadLocals::fetch();
+  }
+
  private:
   mkldnn::engine engine_;
   std::shared_ptr<BlobMap> p_blobmap_;
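A side note on the header change above: tls() returns a reference to a function-local thread_local Body, so every thread that touches the MKL-DNN device context gets its own independently initialised state (session id kMKLDNNSessionID_Default, empty shape string, capacity 1, layout kNCHW, per Body's constructor). The stand-alone snippet below uses illustrative names rather than Paddle's and simply demonstrates the isolation that idiom provides:

// Stand-alone demonstration of the function-local thread_local idiom behind
// MKLDNNDeviceContextThreadLocals::fetch(); types and names here are illustrative.
#include <cstdio>
#include <thread>

struct Body {
  int session_id = 0;  // plays the role of cur_mkldnn_session_id's default
};

static Body& fetch() {
  thread_local Body b;  // constructed once per thread, on first call
  return b;
}

int main() {
  fetch().session_id = 42;  // mutate the main thread's copy
  std::thread([] {
    // A new thread sees a freshly constructed Body, not the value set above.
    std::printf("worker sees %d\n", fetch().session_id);  // prints 0
  }).join();
  std::printf("main sees %d\n", fetch().session_id);  // prints 42
  return 0;
}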
paddle/fluid/platform/mkldnn_reuse.h
@@ -42,8 +42,8 @@ class MKLDNNHandlerT {
         key_common_(base_key),
         fwd_pd_(nullptr),
         bwd_pd_(nullptr) {
-    if (platform::get_cur_mkldnn_session_id() !=
-        platform::kMKLDNNSessionID_Default) {
+    if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
+        platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
       key_ = key_common_;
     } else {
       key_ = key_common_ + "-t:" + ThreadIDasStr();
@@ -177,8 +177,8 @@ class MKLDNNHandler {
   MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
                 const std::string& base_key)
       : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
-    if (platform::get_cur_mkldnn_session_id() !=
-        platform::kMKLDNNSessionID_Default) {
+    if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
+        platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
       key_ = key_common_;
     } else {
       key_ = key_common_ + "-t:" + ThreadIDasStr();