Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e336dc86
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e336dc86
编写于
5月 16, 2019
作者:
C
chengduo
提交者:
GitHub
5月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Speed] Refine the Executor when the num_thread=1 (#17405)
Refine the Executor when the num_thread=1
上级
30e178fa
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
261 addition
and
113 deletion
+261
-113
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
...uid/framework/details/fast_threaded_ssa_graph_executor.cc
+112
-47
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h
...luid/framework/details/fast_threaded_ssa_graph_executor.h
+18
-0
paddle/fluid/framework/details/ssa_graph_executor.cc
paddle/fluid/framework/details/ssa_graph_executor.cc
+4
-1
paddle/fluid/framework/details/ssa_graph_executor.h
paddle/fluid/framework/details/ssa_graph_executor.h
+1
-1
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+98
-50
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+12
-1
python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
...luid/tests/unittests/test_parallel_executor_fetch_feed.py
+16
-13
未找到文件。
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
浏览文件 @
e336dc86
...
...
@@ -43,7 +43,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
bootstrap_ops_
.
emplace_back
(
op
);
}
}
PADDLE_ENFORCE_GT
(
op_deps_
.
size
(),
0
,
"The graph doesn't have operators."
);
PrepareAtomicOpDeps
();
}
...
...
@@ -52,26 +52,85 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
std
::
unique_ptr
<
std
::
unordered_map
<
OpHandleBase
*
,
std
::
atomic
<
int
>>>
op_deps
=
atomic_op_deps_
.
get
();
PrepareAtomicOpDeps
();
size_t
num_ops
=
op_deps
->
size
();
paddle
::
framework
::
FeedFetchList
fetches
;
fetches
.
resize
(
fetch_tensors
.
size
());
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
vector
<
FetchOpHandl
e
*>
fetch_ops
;
std
::
vector
<
OpHandleBas
e
*>
fetch_ops
;
std
::
vector
<
OpHandleBase
*>
ready_fetch_ops
;
exception_
.
Clear
();
InsertFetchOps
(
fetch_tensors
,
&
fetches
,
&
fetched_vars
,
op_deps
.
get
(),
&
fetch_ops
,
&
ready_fetch_ops
);
if
(
strategy_
.
num_threads_
==
1
&&
traced_ops_
.
size
()
==
num_ops
)
{
// If the num_threads is 1, we can record the order of operator's
// execution in the first iteration, and in subsequent iterations,
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG
(
3
)
<<
"Run the traced ops."
;
RunTracedOps
(
traced_ops_
);
RunTracedOps
(
fetch_ops
);
if
(
exception_
.
IsCaught
())
{
ExecutionFinal
(
&
fetch_ops
);
}
}
else
{
traced_ops_
.
clear
();
remaining_
=
0
;
auto
complete_q
=
std
::
make_shared
<
BlockingQueue
<
size_t
>>
();
for
(
auto
op
:
bootstrap_ops_
)
{
RunOpAsync
(
op_deps
.
get
(),
op
,
complete_q
);
}
for
(
auto
op
:
ready_fetch_ops
)
{
RunOpAsync
(
op_deps
.
get
(),
op
,
complete_q
);
}
size_t
num_complete
=
0
;
while
(
num_complete
!=
op_deps
->
size
())
{
size_t
num_comp
=
complete_q
->
Pop
();
if
(
num_comp
==
-
1UL
)
{
int
remaining
=
0
;
while
(
true
)
{
remaining
=
remaining_
;
if
(
remaining
==
0
)
{
break
;
}
for
(
int
i
=
0
;
i
<
remaining
;
++
i
)
{
complete_q
->
Pop
();
}
}
if
(
exception_
.
IsCaught
())
{
ExecutionFinal
(
&
fetch_ops
);
}
}
num_complete
+=
num_comp
;
}
}
// Wait FetchOps.
ClearFetchOp
(
graph_
,
&
fetch_ops
);
return
fetches
;
}
void
FastThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
FeedFetchList
*
fetches
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
*
fetched_vars
,
std
::
unordered_map
<
OpHandleBase
*
,
std
::
atomic
<
int
>>
*
op_deps
,
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
,
std
::
vector
<
OpHandleBase
*>
*
ready_fetch_ops
)
{
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
var_map
:
graph_
->
Get
<
GraphVars
>
(
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
*
it
->
second
.
rbegin
());
(
*
fetched_vars
)
[
fetch_var_name
].
push_back
(
*
it
->
second
.
rbegin
());
}
}
}
for
(
size_t
i
=
0
;
i
<
fetch_tensors
.
size
();
++
i
)
{
auto
&
var_name
=
fetch_tensors
[
i
]
;
auto
fetched_var_it
=
fetched_vars
.
find
(
var_name
);
PADDLE_ENFORCE
(
fetched_var_it
!=
fetched_vars
.
end
(),
auto
&
var_name
=
fetch_tensors
.
at
(
i
)
;
auto
fetched_var_it
=
fetched_vars
->
find
(
var_name
);
PADDLE_ENFORCE
(
fetched_var_it
!=
fetched_vars
->
end
(),
"Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)"
,
var_name
);
...
...
@@ -80,8 +139,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
ir
::
Node
*
fetch_node
=
graph_
->
CreateEmptyNode
(
"fetch"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
op
=
new
FetchOpHandle
(
fetch_node
,
&
fetches
,
i
,
&
local_scopes_
);
fetch_ops
.
emplace_back
(
op
);
auto
*
op
=
new
FetchOpHandle
(
fetch_node
,
fetches
,
i
,
&
local_scopes_
);
fetch_ops
->
emplace_back
(
op
);
for
(
auto
&
p
:
places_
)
{
op
->
SetDeviceContext
(
p
,
fetch_ctxs_
.
Get
(
p
));
...
...
@@ -94,55 +153,22 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
int
dep
=
static_cast
<
int
>
(
op
->
NotReadyInputSize
());
(
*
op_deps
)[
op
]
=
dep
;
if
(
dep
==
0
)
{
ready_fetch_ops
.
emplace_back
(
op
);
}
}
size_t
num_complete
=
0
;
remaining_
=
0
;
auto
complete_q
=
std
::
make_shared
<
BlockingQueue
<
size_t
>>
();
for
(
auto
op
:
bootstrap_ops_
)
{
RunOpAsync
(
op_deps
.
get
(),
op
,
complete_q
);
}
for
(
auto
op
:
ready_fetch_ops
)
{
RunOpAsync
(
op_deps
.
get
(),
op
,
complete_q
);
}
while
(
num_complete
!=
op_deps
->
size
())
{
size_t
num_comp
=
complete_q
->
Pop
();
if
(
num_comp
==
-
1UL
)
{
int
remaining
=
0
;
while
(
true
)
{
remaining
=
remaining_
;
if
(
remaining
==
0
)
{
break
;
}
for
(
int
i
=
0
;
i
<
remaining
;
++
i
)
{
complete_q
->
Pop
();
}
}
if
(
exception_
.
IsCaught
())
{
ClearFetchOp
(
graph_
,
&
fetch_ops
);
exception_
.
ReThrow
();
}
ready_fetch_ops
->
emplace_back
(
op
);
}
num_complete
+=
num_comp
;
}
// Wait FetchOps.
ClearFetchOp
(
graph_
,
&
fetch_ops
);
return
fetches
;
}
bool
FastThreadedSSAGraphExecutor
::
RunOp
(
OpHandleBase
*
op
,
const
std
::
shared_ptr
<
BlockingQueue
<
size_t
>>
&
complete_q
,
size_t
*
complete
)
{
try
{
RunOpSync
(
op
);
if
(
LIKELY
(
!
exception_
.
IsCaught
()))
{
if
(
LIKELY
(
!
strategy_
.
dry_run_
))
{
op
->
Run
(
strategy_
.
use_cuda_
);
RecordOps
(
op
);
}
++
(
*
complete
);
return
true
;
}
catch
(...)
{
exception_
.
Catch
(
std
::
current_exception
());
}
else
{
--
remaining_
;
complete_q
->
Push
(
-
1UL
);
return
false
;
...
...
@@ -194,6 +220,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
complete_q
->
Push
(
complete
);
});
}
void
FastThreadedSSAGraphExecutor
::
PrepareAtomicOpDeps
()
{
atomic_op_deps_
=
prepare_pool_
.
enqueue
([
&
]
{
auto
*
op_deps
=
new
std
::
unordered_map
<
OpHandleBase
*
,
std
::
atomic
<
int
>>
;
...
...
@@ -206,6 +233,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
}
const
ir
::
Graph
&
FastThreadedSSAGraphExecutor
::
Graph
()
const
{
return
*
graph_
;
}
void
FastThreadedSSAGraphExecutor
::
RecordOps
(
OpHandleBase
*
op
)
{
if
(
strategy_
.
num_threads_
==
1
&&
!
dynamic_cast
<
FetchOpHandle
*>
(
op
))
{
traced_ops_
.
emplace_back
(
op
);
}
}
void
FastThreadedSSAGraphExecutor
::
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
)
{
VLOG
(
3
)
<<
"caught exception "
<<
exception_
.
Type
()
<<
", rethrow it"
;
ClearFetchOp
(
graph_
,
fetch_ops
);
exception_
.
ReThrow
();
}
void
FastThreadedSSAGraphExecutor
::
RunTracedOps
(
const
std
::
vector
<
OpHandleBase
*>
&
traced_ops
)
{
for
(
auto
&
op
:
traced_ops
)
{
if
(
exception_
.
IsCaught
())
{
return
;
}
RunOpSync
(
op
);
}
}
void
FastThreadedSSAGraphExecutor
::
RunOpSync
(
OpHandleBase
*
op
)
{
try
{
if
(
VLOG_IS_ON
(
10
))
{
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
}
if
(
LIKELY
(
!
strategy_
.
dry_run_
))
{
op
->
Run
(
strategy_
.
use_cuda_
);
}
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Done "
;
}
catch
(...)
{
exception_
.
Catch
(
std
::
current_exception
());
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h
浏览文件 @
e336dc86
...
...
@@ -60,6 +60,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
::
ThreadPool
pool_
;
::
ThreadPool
prepare_pool_
;
std
::
vector
<
OpHandleBase
*>
traced_ops_
;
bool
RunOp
(
OpHandleBase
*
op
,
const
std
::
shared_ptr
<
BlockingQueue
<
size_t
>>
&
complete_q
,
size_t
*
complete
);
...
...
@@ -69,6 +71,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const
std
::
shared_ptr
<
BlockingQueue
<
size_t
>>
&
complete_q
);
void
PrepareAtomicOpDeps
();
inline
void
RecordOps
(
OpHandleBase
*
op
);
inline
void
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
);
inline
void
RunOpSync
(
OpHandleBase
*
op
);
void
RunTracedOps
(
const
std
::
vector
<
OpHandleBase
*>
&
traced_ops
);
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
FeedFetchList
*
fetches
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
*
fetched_vars
,
std
::
unordered_map
<
OpHandleBase
*
,
std
::
atomic
<
int
>>
*
op_deps
,
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
,
std
::
vector
<
OpHandleBase
*>
*
ready_fetch_ops
);
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_executor.cc
浏览文件 @
e336dc86
...
...
@@ -19,10 +19,13 @@ namespace framework {
namespace
details
{
SSAGraphExecutor
::~
SSAGraphExecutor
()
{}
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
FetchOpHandl
e
*>*
fetch_ops
)
{
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
OpHandleBas
e
*>*
fetch_ops
)
{
if
(
fetch_ops
->
empty
())
return
;
for
(
auto
&
op
:
*
fetch_ops
)
{
PADDLE_ENFORCE_NOT_NULL
(
dynamic_cast
<
FetchOpHandle
*>
(
op
),
"The input ops of ClearFetchOp function should be FetchOpHandle."
);
for
(
auto
&
out_var
:
op
->
Node
()
->
outputs
)
{
graph
->
RemoveNode
(
out_var
);
}
...
...
paddle/fluid/framework/details/ssa_graph_executor.h
浏览文件 @
e336dc86
...
...
@@ -38,7 +38,7 @@ class SSAGraphExecutor {
virtual
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>&
fetch_tensors
)
=
0
;
};
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
FetchOpHandl
e
*>*
fetch_ops
);
void
ClearFetchOp
(
ir
::
Graph
*
graph
,
std
::
vector
<
OpHandleBas
e
*>*
fetch_ops
);
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
e336dc86
...
...
@@ -53,74 +53,84 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
new
platform
::
RecordEvent
(
"ThreadedSSAGraphExecutorPrepare"
));
std
::
unique_ptr
<
OpDependentData
>
op_deps
=
op_deps_futures_
.
get
();
CopyOpDeps
();
VLOG
(
10
)
<<
"ThreadedSSAGraphExecutor::Run"
;
std
::
shared_ptr
<
BlockingQueue
<
VarHandleBase
*>>
ready_vars
(
new
BlockingQueue
<
VarHandleBase
*>
);
auto
&
pending_ops
=
op_deps
->
pending_ops_
;
auto
&
pending_vars
=
op_deps
->
pending_vars_
;
auto
&
ready_ops
=
op_deps
->
ready_ops_
;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
size_t
num_ops
=
op_deps
->
num_ops_
;
// Step 2. Insert FetchOps
std
::
vector
<
FetchOpHandl
e
*>
fetch_ops
;
std
::
vector
<
OpHandleBas
e
*>
fetch_ops
;
std
::
unordered_set
<
VarHandleBase
*>
fetch_dependencies
;
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
InsertFetchOps
(
fetch_tensors
,
&
fetch_ops
,
&
fetch_dependencies
,
&
ready_ops
,
&
pending_ops
,
&
pending_vars
,
&
fetch_data
);
auto
run_all_ops
=
[
&
](
std
::
unordered_set
<
OpHandleBase
*>
&
set
)
{
for
(
auto
*
op
:
set
)
{
RunOp
(
ready_vars
,
op
);
}
set
.
clear
();
};
// Clean run context
run_op_futures_
.
clear
();
exception_holder_
.
Clear
();
event
.
reset
(
nullptr
);
// Step 3. Execution
while
(
!
pending_vars
.
empty
())
{
// 1. Run All Ready ops
// Keep loop until all vars are ready.
run_all_ops
(
ready_ops
);
// 2. Find ready variable
bool
timeout
;
auto
cur_ready_vars
=
ready_vars
->
PopAll
(
1
,
&
timeout
);
if
(
timeout
)
{
if
(
exception_holder_
.
IsCaught
())
{
VLOG
(
3
)
<<
"caught exception "
<<
exception_holder_
.
Type
()
<<
", rethrow it"
;
if
(
strategy_
.
num_threads_
==
1
&&
traced_ops_
.
size
()
==
num_ops
)
{
// If the num_threads is 1, we can record the order of operator's
// execution in the first iteration, and in subsequent iterations,
// run the recorded operators directly. This strategy could make the
// execution faster.
VLOG
(
3
)
<<
"Run the traced ops."
;
RunTracedOps
(
traced_ops_
);
RunTracedOps
(
fetch_ops
);
if
(
exception_holder_
.
IsCaught
())
{
ExecutionFinal
(
&
fetch_ops
);
}
}
else
{
traced_ops_
.
clear
();
auto
run_all_ops
=
[
&
](
std
::
unordered_set
<
OpHandleBase
*>
&
set
)
{
for
(
auto
*
op
:
set
)
{
RunOp
(
ready_vars
,
op
);
}
set
.
clear
();
};
// Clean run context
run_op_futures_
.
clear
();
while
(
!
pending_vars
.
empty
())
{
// 1. Run All Ready ops
// Keep loop until all vars are ready.
run_all_ops
(
ready_ops
);
// 2. Find ready variable
bool
timeout
;
auto
cur_ready_vars
=
ready_vars
->
PopAll
(
1
,
&
timeout
);
if
(
timeout
)
{
for
(
auto
&
run_op_future
:
run_op_futures_
)
{
run_op_future
.
wait
();
}
ClearFetchOp
(
graph_
,
&
fetch_ops
);
exception_holder_
.
ReThrow
();
}
else
{
continue
;
if
(
exception_holder_
.
IsCaught
())
{
ExecutionFinal
(
&
fetch_ops
);
}
else
{
continue
;
}
}
}
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
for
(
auto
ready_var
:
cur_ready_vars
)
{
pending_vars
.
erase
(
ready_var
);
for
(
auto
*
op
:
ready_var
->
PendingOps
())
{
auto
&
deps
=
pending_ops
[
op
];
--
deps
;
if
(
deps
==
0
)
{
ready_ops
.
insert
(
op
);
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
for
(
auto
ready_var
:
cur_ready_vars
)
{
pending_vars
.
erase
(
ready_var
);
for
(
auto
*
op
:
ready_var
->
PendingOps
())
{
auto
&
deps
=
pending_ops
[
op
];
--
deps
;
if
(
deps
==
0
)
{
ready_ops
.
insert
(
op
);
}
}
}
}
PADDLE_ENFORCE
(
ready_ops
.
empty
());
}
PADDLE_ENFORCE
(
ready_ops
.
empty
());
// Wait FetchOps.
ClearFetchOp
(
graph_
,
&
fetch_ops
);
...
...
@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
FetchOpHandl
e
*>
*
fetch_ops
,
std
::
vector
<
OpHandleBas
e
*>
*
fetch_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
fetch_dependencies
,
std
::
unordered_set
<
OpHandleBase
*>
*
ready_ops
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
...
...
@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
InsertPendingOp
(
&
pending_ops
,
op
);
}
}
op_deps_
->
num_ops_
=
ready_ops
.
size
()
+
pending_ops
.
size
();
PADDLE_ENFORCE_GT
(
op_deps_
->
num_ops_
,
0
,
"The graph doesn't have operators."
);
for
(
auto
ready_var
:
ready_vars
)
{
pending_vars
.
erase
(
ready_var
);
for
(
auto
*
op
:
ready_var
->
PendingOps
())
{
...
...
@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
op_deps_
->
pending_vars_
.
end
());
op_deps
->
ready_ops_
.
insert
(
op_deps_
->
ready_ops_
.
begin
(),
op_deps_
->
ready_ops_
.
end
());
op_deps
->
num_ops_
=
op_deps_
->
num_ops_
;
return
std
::
unique_ptr
<
OpDependentData
>
(
op_deps
);
});
}
...
...
@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp(
const
std
::
shared_ptr
<
BlockingQueue
<
VarHandleBase
*>>
&
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
RunOpSync
(
op
);
try
{
if
(
VLOG_IS_ON
(
10
))
{
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
}
if
(
LIKELY
(
!
strategy_
.
dry_run_
))
{
op
->
Run
(
strategy_
.
use_cuda_
);
}
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Done "
;
ready_var_q
->
Extend
(
op
->
Outputs
());
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Signal posted"
;
}
catch
(...)
{
exception_holder_
.
Catch
(
std
::
current_exception
());
}
};
if
(
pool_
)
{
run_op_futures_
.
emplace_back
(
pool_
->
enqueue
(
op_run
));
}
else
{
op_run
();
}
RecordOps
(
op
);
}
void
ThreadedSSAGraphExecutor
::
RunTracedOps
(
const
std
::
vector
<
OpHandleBase
*>
&
traced_ops
)
{
for
(
auto
&
op
:
traced_ops
)
{
if
(
exception_holder_
.
IsCaught
())
{
return
;
}
RunOpSync
(
op
);
}
}
void
ThreadedSSAGraphExecutor
::
RunOpSync
(
OpHandleBase
*
op
)
{
try
{
if
(
VLOG_IS_ON
(
10
))
{
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
}
if
(
LIKELY
(
!
strategy_
.
dry_run_
))
{
op
->
Run
(
strategy_
.
use_cuda_
);
}
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Done "
;
}
catch
(...)
{
exception_holder_
.
Catch
(
std
::
current_exception
());
}
}
void
ThreadedSSAGraphExecutor
::
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
)
{
VLOG
(
3
)
<<
"caught exception "
<<
exception_holder_
.
Type
()
<<
", rethrow it"
;
ClearFetchOp
(
graph_
,
fetch_ops
);
exception_holder_
.
ReThrow
();
}
void
ThreadedSSAGraphExecutor
::
RecordOps
(
OpHandleBase
*
op
)
{
if
(
strategy_
.
num_threads_
==
1
&&
!
dynamic_cast
<
FetchOpHandle
*>
(
op
))
{
traced_ops_
.
emplace_back
(
op
);
}
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
e336dc86
...
...
@@ -44,6 +44,7 @@ struct OpDependentData {
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops_
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars_
;
std
::
unordered_set
<
OpHandleBase
*>
ready_ops_
;
size_t
num_ops_
{
0
};
};
class
ThreadedSSAGraphExecutor
:
public
SSAGraphExecutor
{
...
...
@@ -80,6 +81,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std
::
list
<
std
::
future
<
void
>>
run_op_futures_
;
::
ThreadPool
prepare_pool_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
vector
<
OpHandleBase
*>
traced_ops_
;
void
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
OpHandleBase
*
op_instance
)
const
;
...
...
@@ -89,7 +91,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
VarHandleBase
*
var
)
const
;
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
FetchOpHandl
e
*>
*
fetch_ops
,
std
::
vector
<
OpHandleBas
e
*>
*
fetch_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
fetch_dependencies
,
std
::
unordered_set
<
OpHandleBase
*>
*
ready_ops
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
...
...
@@ -97,7 +99,16 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList
*
fetch_data
);
void
PrepareOpDeps
();
void
CopyOpDeps
();
inline
void
RecordOps
(
OpHandleBase
*
op
);
inline
void
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
);
inline
void
RunOpSync
(
OpHandleBase
*
op
);
void
RunTracedOps
(
const
std
::
vector
<
OpHandleBase
*>
&
traced_ops
);
};
}
// namespace details
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_fetch_feed.py
浏览文件 @
e336dc86
...
...
@@ -45,7 +45,8 @@ class TestFetchAndFeed(unittest.TestCase):
def
parallel_exe
(
self
,
use_cuda
,
run_parallel_exe
,
use_experimental_executor
=
False
,
use_faster_executor
=
False
,
num_threads
=
4
,
seed
=
1
):
main_program
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
...
...
@@ -72,7 +73,8 @@ class TestFetchAndFeed(unittest.TestCase):
build_strategy
.
enable_inplace
=
False
build_strategy
.
memory_optimize
=
False
exec_strategy
=
fluid
.
ExecutionStrategy
()
exec_strategy
.
use_experimental_executor
=
use_experimental_executor
exec_strategy
.
use_experimental_executor
=
use_faster_executor
exec_strategy
.
num_threads
=
num_threads
train_cp
=
compiler
.
CompiledProgram
(
main_program
).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
...
...
@@ -143,24 +145,25 @@ class TestFetchAndFeed(unittest.TestCase):
if
batch_id
==
2
:
break
def
test_fetch_with_threaded_executor
(
self
):
if
core
.
is_compiled_with_cuda
():
self
.
parallel_exe
(
use_cuda
=
True
,
run_parallel_exe
=
self
.
run_parallel_exe_with_fetch
)
self
.
parallel_exe
(
use_cuda
=
False
,
run_parallel_exe
=
self
.
run_parallel_exe_with_fetch
)
def
test_fetch_with_fast_threaded_executor
(
self
):
def
check_executor
(
self
,
use_faster_executor
=
False
,
num_threads
=
4
):
if
core
.
is_compiled_with_cuda
():
self
.
parallel_exe
(
use_cuda
=
True
,
run_parallel_exe
=
self
.
run_parallel_exe_with_fetch
,
use_experimental_executor
=
True
)
use_faster_executor
=
use_faster_executor
,
num_threads
=
num_threads
)
self
.
parallel_exe
(
use_cuda
=
False
,
run_parallel_exe
=
self
.
run_parallel_exe_with_fetch
,
use_experimental_executor
=
True
)
use_faster_executor
=
use_faster_executor
,
num_threads
=
num_threads
)
def
test_fetch
(
self
):
for
use_faster_executor
in
{
True
,
False
}:
self
.
check_executor
(
use_faster_executor
=
use_faster_executor
,
num_threads
=
4
)
self
.
check_executor
(
use_faster_executor
=
use_faster_executor
,
num_threads
=
1
)
def
test_feed
(
self
):
if
core
.
is_compiled_with_cuda
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录