Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e53f517a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e53f517a
编写于
8月 05, 2019
作者:
P
pawelpiotrowicz
提交者:
Tao Luo
8月 05, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix for multithreading test_analyzer_image_classification --num_threads=X (#18265)
test=develop
上级
65d98752
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
99 addition
and
9 deletion
+99
-9
paddle/fluid/operators/ngraph/ngraph_engine.cc
paddle/fluid/operators/ngraph/ngraph_engine.cc
+8
-5
paddle/fluid/operators/ngraph/ngraph_engine.h
paddle/fluid/operators/ngraph/ngraph_engine.h
+91
-4
未找到文件。
paddle/fluid/operators/ngraph/ngraph_engine.cc
浏览文件 @
e53f517a
...
@@ -77,11 +77,6 @@ framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
...
@@ -77,11 +77,6 @@ framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const
framework
::
BlockDesc
*
NgraphEngine
::
p_bdesc
=
nullptr
;
const
framework
::
BlockDesc
*
NgraphEngine
::
p_bdesc
=
nullptr
;
bool
NgraphEngine
::
is_training
=
false
;
bool
NgraphEngine
::
is_training
=
false
;
std
::
unordered_map
<
std
::
string
,
EngineCache
>
NgraphEngine
::
engine_cache
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>>
NgraphEngine
::
t_in_cache_
=
{};
std
::
shared_ptr
<
ngraph
::
runtime
::
Backend
>
NgraphEngine
::
backend_
=
std
::
shared_ptr
<
ngraph
::
runtime
::
Backend
>
NgraphEngine
::
backend_
=
ngraph
::
runtime
::
Backend
::
create
(
"CPU"
);
ngraph
::
runtime
::
Backend
::
create
(
"CPU"
);
...
@@ -453,6 +448,9 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
...
@@ -453,6 +448,9 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
}
}
void
NgraphEngine
::
ClearNgCache
()
{
void
NgraphEngine
::
ClearNgCache
()
{
auto
&
engine_cache
=
main_engine_cache
::
fetch
();
auto
&
t_in_cache_
=
main_t_in_cache
::
fetch
();
auto
it
=
engine_cache
.
begin
();
auto
it
=
engine_cache
.
begin
();
while
(
it
!=
engine_cache
.
end
())
{
while
(
it
!=
engine_cache
.
end
())
{
auto
ng_engine
=
it
->
second
;
auto
ng_engine
=
it
->
second
;
...
@@ -494,6 +492,8 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
...
@@ -494,6 +492,8 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
std
::
to_string
(
interval
[
1
])
+
engine_key
;
std
::
to_string
(
interval
[
1
])
+
engine_key
;
func_cache_key_
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
func_cache_key_
));
func_cache_key_
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
func_cache_key_
));
auto
&
engine_cache
=
main_engine_cache
::
fetch
();
if
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
())
{
if
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
())
{
if
(
engine_cache
[
func_cache_key_
].
persistables
.
size
()
==
0
)
{
if
(
engine_cache
[
func_cache_key_
].
persistables
.
size
()
==
0
)
{
ClearNgCache
();
ClearNgCache
();
...
@@ -533,6 +533,9 @@ void NgraphEngine::Run(const framework::Scope& scope,
...
@@ -533,6 +533,9 @@ void NgraphEngine::Run(const framework::Scope& scope,
const
std
::
vector
<
std
::
string
>*
p_var_out
;
const
std
::
vector
<
std
::
string
>*
p_var_out
;
bool
is_test
;
bool
is_test
;
auto
&
engine_cache
=
main_engine_cache
::
fetch
();
auto
&
t_in_cache_
=
main_t_in_cache
::
fetch
();
PADDLE_ENFORCE
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
(),
PADDLE_ENFORCE
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
(),
"Cannot find cached data to run ngraph function"
);
"Cannot find cached data to run ngraph function"
);
ng_handle
=
engine_cache
[
func_cache_key_
].
ngraph_handle
;
ng_handle
=
engine_cache
[
func_cache_key_
].
ngraph_handle
;
...
...
paddle/fluid/operators/ngraph/ngraph_engine.h
浏览文件 @
e53f517a
...
@@ -14,11 +14,13 @@ limitations under the License. */
...
@@ -14,11 +14,13 @@ limitations under the License. */
#pragma once
#pragma once
#include <list>
#include <memory>
#include <memory>
#include <set>
#include <set>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
...
@@ -40,6 +42,82 @@ struct EngineCache {
...
@@ -40,6 +42,82 @@ struct EngineCache {
bool
is_test
=
true
;
bool
is_test
=
true
;
};
};
template
<
class
T
,
class
Engine
,
int
separator
=
0
>
class
NgraphThreadCache
{
public:
typedef
decltype
(
Engine
::
getMutex
())
mutex_type
;
typedef
std
::
lock_guard
<
mutex_type
>
guard_type
;
typedef
T
&
ref_type
;
enum
class
type_of_thread
{
unknown
,
forward
,
backward
};
template
<
class
S
>
struct
MetaInfo
{
std
::
thread
::
id
owner_tid
;
// owner of the cache, future use;
type_of_thread
worker_type
;
// future use
S
real_content
;
MetaInfo
()
:
owner_tid
{
std
::
this_thread
::
get_id
()},
worker_type
{
type_of_thread
::
unknown
}
{}
};
typedef
std
::
unique_ptr
<
MetaInfo
<
T
>>
content_type
;
typedef
std
::
list
<
content_type
>
storage_type
;
protected:
static
storage_type
l
;
static
mutex_type
getMutex
()
{
return
Engine
::
getMutex
();
}
static
void
remove_from_list
(
const
T
*
raw_ptr
)
{
guard_type
guard
(
getMutex
());
l
.
remove_if
([
raw_ptr
](
const
content_type
&
sh
)
{
return
&
(
sh
->
real_content
)
==
raw_ptr
;
});
}
template
<
class
TRaw
>
struct
TLSDescriptor
{
TRaw
*
raw_ptr
;
TLSDescriptor
()
:
raw_ptr
{
nullptr
}
{}
~
TLSDescriptor
()
{
// if thread die
NgraphThreadCache
::
remove_from_list
(
raw_ptr
);
/* TODO : Parallel executor swap */
// FastMultiThreadCache::keep_alive_for_backward_thread(raw_ptr);
}
};
public:
NgraphThreadCache
()
=
delete
;
NgraphThreadCache
(
const
NgraphThreadCache
&
copy
)
=
delete
;
static
T
&
fetch
()
{
thread_local
TLSDescriptor
<
T
>
tls
;
if
(
!
tls
.
raw_ptr
)
{
using
elem_type
=
typename
content_type
::
element_type
;
content_type
_p
(
new
elem_type
());
if
(
!
_p
)
PADDLE_THROW
(
"Cannot alloc memory for thread-cache "
);
guard_type
guard
(
getMutex
());
l
.
push_back
(
std
::
move
(
_p
));
tls
.
raw_ptr
=
&
l
.
back
()
->
real_content
;
}
return
*
(
tls
.
raw_ptr
);
}
auto
getSize
()
->
decltype
(
l
.
size
())
{
guard_type
guard
(
getMutex
());
return
l
.
size
();
}
template
<
class
F
>
void
for_each_cache
(
F
f
)
{
guard_type
guard
(
getMutex
());
std
::
for_each
(
l
.
begin
(),
l
.
end
(),
f
);
}
};
template
<
class
T
,
class
Engine
,
int
separator
>
typename
NgraphThreadCache
<
T
,
Engine
,
separator
>::
storage_type
NgraphThreadCache
<
T
,
Engine
,
separator
>::
l
;
// perform graph build through bridge and execute computation
// perform graph build through bridge and execute computation
class
NgraphEngine
{
class
NgraphEngine
{
public:
public:
...
@@ -57,11 +135,20 @@ class NgraphEngine {
...
@@ -57,11 +135,20 @@ class NgraphEngine {
const
framework
::
BlockDesc
&
prog
,
const
framework
::
BlockDesc
&
prog
,
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>*
ops
);
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>*
ops
);
static
std
::
recursive_mutex
&
getMutex
()
{
static
std
::
recursive_mutex
mx
;
return
mx
;
}
private:
private:
static
std
::
unordered_map
<
std
::
string
,
EngineCache
>
engine_cache
;
template
<
class
T
>
static
std
::
unordered_map
<
using
ThCache
=
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>>
NgraphThreadCache
<
std
::
unordered_map
<
std
::
string
,
T
>
,
NgraphEngine
>
;
t_in_cache_
;
using
main_engine_cache
=
ThCache
<
EngineCache
>
;
using
main_t_in_cache
=
ThCache
<
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>>
;
static
framework
::
Variable
*
pre_var_ptr
;
static
framework
::
Variable
*
pre_var_ptr
;
const
framework
::
Scope
&
scope_
;
const
framework
::
Scope
&
scope_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录