Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3f3112ce
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3f3112ce
编写于
7月 03, 2019
作者:
T
Tao Luo
提交者:
GitHub
7月 03, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add shape_blob for cache mkldnn primitive (#18454)
test=develop
上级
d234aa02
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
62 addition
and
15 deletion
+62
-15
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+52
-14
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+10
-1
未找到文件。
paddle/fluid/platform/device_context.cc
浏览文件 @
3f3112ce
...
...
@@ -403,42 +403,62 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
namespace
{
// Current mkldnn session id.
thread_local
size_t
cur_mkldnn_session_id
=
kMKLDNNSessionID_Default
;
}
// Current data input shape string.
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
thread_local
std
::
string
cur_input_shape_str
=
""
;
}
// namespace
void
set_cur_mkldnn_session_id
(
size_t
sid
)
{
cur_mkldnn_session_id
=
sid
;
}
size_t
get_cur_mkldnn_session_id
(
void
)
{
return
cur_mkldnn_session_id
;
}
void
set_cur_input_shape_str
(
std
::
string
input_shape_str
)
{
cur_input_shape_str
=
input_shape_str
;
}
std
::
string
get_cur_input_shape_str
(
void
)
{
return
cur_input_shape_str
;
}
void
MKLDNNDeviceContext
::
ResetBlobMap
()
const
{
p_blobmap_
->
clear
();
}
void
MKLDNNDeviceContext
::
SetBlob
(
const
std
::
string
&
name
,
std
::
shared_ptr
<
void
>
data
)
const
{
BlobMap
*
pMap
=
p_blobmap_
.
get
();
std
::
shared_ptr
<
ShapeBlob
>
sBlob
=
nullptr
;
std
::
shared_ptr
<
KeyBlob
>
pBlob
=
nullptr
;
int
tid
=
platform
::
get_cur_mkldnn_session_id
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
p_mutex_
);
// Find
Key
Blob for current thread
// Find
Shape
Blob for current thread
auto
map_it
=
pMap
->
find
(
tid
);
if
(
map_it
==
pMap
->
end
())
{
// 1st time to set blob in current thread
pBlob
=
std
::
shared_ptr
<
KeyBlob
>
(
new
KeyBlob
());
(
*
pMap
)[
tid
]
=
pBlob
;
sBlob
=
std
::
shared_ptr
<
ShapeBlob
>
(
new
ShapeBlob
());
(
*
pMap
)[
tid
]
=
sBlob
;
VLOG
(
2
)
<<
"SetBlob: tid="
<<
tid
<<
", add new tid
\n
"
;
}
else
{
p
Blob
=
map_it
->
second
;
s
Blob
=
map_it
->
second
;
}
// Find Key in found (or newly created) KeyBlob
auto
key_it
=
pBlob
->
find
(
name
);
// Find KeyBlob for current input shape
std
::
string
cur_input_shape_str
=
platform
::
get_cur_input_shape_str
();
auto
key_it
=
sBlob
->
find
(
cur_input_shape_str
);
if
(
key_it
==
pBlob
->
end
())
{
(
*
pBlob
)[
name
]
=
data
;
// create new blob
if
(
key_it
==
sBlob
->
end
())
{
pBlob
=
std
::
shared_ptr
<
KeyBlob
>
(
new
KeyBlob
());
(
*
sBlob
)[
cur_input_shape_str
]
=
pBlob
;
}
else
{
key_it
->
second
=
data
;
// set data to existing blob
pBlob
=
key_it
->
second
;
}
// Find Blob via name
auto
blob_it
=
pBlob
->
find
(
name
);
if
(
blob_it
==
pBlob
->
end
())
{
(
*
pBlob
)[
name
]
=
data
;
}
else
{
blob_it
->
second
=
data
;
// set data to existing blob
}
VLOG
(
2
)
<<
"SetBlob: tid="
<<
tid
<<
", add blob="
<<
name
<<
"
\n
"
;
// lock will be automatically released when out of scope
return
;
}
...
...
@@ -446,22 +466,40 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
std
::
shared_ptr
<
void
>
MKLDNNDeviceContext
::
GetBlob
(
const
std
::
string
&
name
)
const
{
BlobMap
*
pMap
=
p_blobmap_
.
get
();
std
::
shared_ptr
<
ShapeBlob
>
sBlob
=
nullptr
;
std
::
shared_ptr
<
KeyBlob
>
pBlob
=
nullptr
;
int
tid
=
platform
::
get_cur_mkldnn_session_id
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
p_mutex_
);
// Find
Key
Blob for current thread firstly
// Find
Shape
Blob for current thread firstly
auto
map_it
=
pMap
->
find
(
tid
);
if
(
map_it
==
pMap
->
end
())
return
nullptr
;
pBlob
=
map_it
->
second
;
if
(
map_it
==
pMap
->
end
())
{
VLOG
(
2
)
<<
"GetBlob: tid="
<<
tid
<<
", miss tid
\n
"
;
return
nullptr
;
}
std
::
string
cur_input_shape_str
=
platform
::
get_cur_input_shape_str
();
sBlob
=
map_it
->
second
;
// Find KeyBlob for current input shape secondly
auto
sBlob_it
=
sBlob
->
find
(
cur_input_shape_str
);
if
(
sBlob_it
==
sBlob
->
end
())
{
VLOG
(
2
)
<<
"GetBlob: tid="
<<
cur_input_shape_str
<<
", miss input_shape_str
\n
"
;
return
nullptr
;
}
pBlob
=
sBlob_it
->
second
;
// Find Blob via name
auto
key_it
=
pBlob
->
find
(
name
);
if
(
key_it
==
pBlob
->
end
())
return
nullptr
;
if
(
key_it
==
pBlob
->
end
())
{
VLOG
(
2
)
<<
"GetBlob tid="
<<
tid
<<
", miss blob="
<<
name
<<
"
\n
"
;
return
nullptr
;
}
VLOG
(
2
)
<<
"GetBlob tid="
<<
tid
<<
", get blob="
<<
name
<<
"
\n
"
;
// lock will be automatically released when out of scope
return
key_it
->
second
;
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
3f3112ce
...
...
@@ -378,8 +378,15 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif
#ifdef PADDLE_WITH_MKLDNN
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
// Where:
using
KeyBlob
=
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
void
>>
;
using
BlobMap
=
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
KeyBlob
>>
;
using
ShapeBlob
=
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
KeyBlob
>>
;
using
BlobMap
=
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
ShapeBlob
>>
;
// default mkldnn session id
constexpr
size_t
kMKLDNNSessionID_Default
=
0
;
...
...
@@ -388,6 +395,8 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
void
set_cur_mkldnn_session_id
(
size_t
);
size_t
get_cur_mkldnn_session_id
(
void
);
void
set_cur_input_shape_str
(
std
::
string
input_shape_str
);
std
::
string
get_cur_input_shape_str
(
void
);
class
MKLDNNDeviceContext
:
public
CPUDeviceContext
{
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录