Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d0ac9253
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d0ac9253
编写于
4月 01, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve ParallelExecutor performance
上级
dd75fbde
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
63 addition
and
14 deletion
+63
-14
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+1
-1
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+5
-0
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+4
-0
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+39
-10
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+10
-2
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+2
-0
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+2
-1
未找到文件。
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
浏览文件 @
d0ac9253
...
@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
...
@@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
}
}
}
std
::
string
NCCLAllReduceOpHandle
::
Name
()
const
{
return
"
NCCL AllR
educe"
;
}
std
::
string
NCCLAllReduceOpHandle
::
Name
()
const
{
return
"
nccl_all_r
educe"
;
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
浏览文件 @
d0ac9253
...
@@ -14,6 +14,9 @@
...
@@ -14,6 +14,9 @@
#pragma once
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -34,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
...
@@ -34,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
bool
IsDelayedOp
()
override
{
return
true
;
};
protected:
protected:
void
RunImpl
()
override
;
void
RunImpl
()
override
;
};
};
...
...
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
d0ac9253
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
// limitations under the License.
// limitations under the License.
#pragma once
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
...
@@ -53,6 +55,8 @@ class OpHandleBase {
...
@@ -53,6 +55,8 @@ class OpHandleBase {
void
AddOutput
(
VarHandleBase
*
out
);
void
AddOutput
(
VarHandleBase
*
out
);
virtual
bool
IsDelayedOp
()
{
return
false
;
}
protected:
protected:
virtual
void
RunImpl
()
=
0
;
virtual
void
RunImpl
()
=
0
;
};
};
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
d0ac9253
...
@@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
...
@@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
local_scopes_
(
local_scopes
),
local_scopes_
(
local_scopes
),
places_
(
places
),
places_
(
places
),
fetch_ctxs_
(
places
),
fetch_ctxs_
(
places
),
use_event_
(
use_event
)
{}
use_event_
(
use_event
),
running_ops_
(
0
)
{}
void
ThreadedSSAGraphExecutor
::
RunDelayedOps
(
const
std
::
unordered_set
<
OpHandleBase
*>
&
delayed_ops
)
{
for
(
auto
op
:
delayed_ops
)
{
op
->
Run
(
use_event_
);
}
}
FeedFetchList
ThreadedSSAGraphExecutor
::
Run
(
FeedFetchList
ThreadedSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
pending_ops
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
std
::
unordered_set
<
VarHandleBase
*>
pending_vars
;
BlockingQueue
<
VarHandleBase
*>
ready_vars
;
BlockingQueue
<
VarHandleBase
*>
ready_vars
;
std
::
unordered_set
<
OpHandleBase
*>
ready_ops
;
std
::
unordered_set
<
OpHandleBase
*>
ready_ops
;
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
std
::
unordered_set
<
OpHandleBase
*>
after_delayed_ops
;
std
::
unordered_set
<
VarHandleBase
*>
delayed_vars
;
auto
InsertPendingVar
=
[
&
pending_vars
,
&
ready_vars
](
VarHandleBase
&
var
)
{
auto
InsertPendingVar
=
[
&
pending_vars
,
&
ready_vars
](
VarHandleBase
&
var
)
{
pending_vars
.
insert
(
&
var
);
pending_vars
.
insert
(
&
var
);
if
(
var
.
generated_op_
==
nullptr
)
{
if
(
var
.
generated_op_
==
nullptr
)
{
...
@@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto
run_all_ready_ops
=
[
&
]
{
auto
run_all_ready_ops
=
[
&
]
{
for
(
auto
*
op
:
ready_ops
)
{
for
(
auto
*
op
:
ready_ops
)
{
RunOp
(
ready_vars
,
op
);
if
(
op
->
IsDelayedOp
())
{
delayed_ops
.
insert
(
op
);
delayed_vars
.
insert
(
op
->
outputs_
.
begin
(),
op
->
outputs_
.
end
());
ready_vars
.
Extend
(
op
->
outputs_
);
continue
;
}
running_ops_
++
;
RunOp
(
&
ready_vars
,
op
);
}
}
ready_ops
.
clear
();
ready_ops
.
clear
();
};
};
...
@@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// 2. Find ready variable
// 2. Find ready variable
bool
timeout
;
bool
timeout
;
auto
cur_ready_vars
=
ready_vars
.
PopAll
(
1
000
,
&
timeout
);
auto
cur_ready_vars
=
ready_vars
.
PopAll
(
1
,
&
timeout
);
if
(
timeout
)
{
if
(
timeout
)
{
if
(
exception_
)
{
if
(
exception_
)
{
...
@@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto
&
deps
=
pending_ops
[
op
];
auto
&
deps
=
pending_ops
[
op
];
--
deps
;
--
deps
;
if
(
deps
==
0
)
{
if
(
deps
==
0
)
{
if
(
delayed_vars
.
find
(
ready_var
)
!=
delayed_vars
.
end
())
{
after_delayed_ops
.
insert
(
op
);
}
else
{
ready_ops
.
insert
(
op
);
ready_ops
.
insert
(
op
);
}
}
}
}
}
}
}
if
(
ready_ops
.
empty
()
&&
!
delayed_ops
.
empty
()
&&
running_ops_
==
0
)
{
RunDelayedOps
(
delayed_ops
);
delayed_ops
.
clear
();
for
(
auto
*
op
:
after_delayed_ops
)
{
ready_ops
.
insert
(
op
);
}
after_delayed_ops
.
clear
();
}
// Keep loop until all vars are ready.
// Keep loop until all vars are ready.
}
}
++
computation_count_
;
++
computation_count_
;
auto
sync_computation
=
[
&
]
{
auto
sync_computation
=
[
&
]
{
...
@@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
...
@@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
}
void
ThreadedSSAGraphExecutor
::
RunOp
(
void
ThreadedSSAGraphExecutor
::
RunOp
(
BlockingQueue
<
VarHandleBase
*>
&
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
&
ready_var_q
,
op
,
this
]
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
try
{
try
{
VLOG
(
10
)
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
VLOG
(
10
)
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
op
->
Run
(
use_event_
);
op
->
Run
(
use_event_
);
ready_var_q
.
Extend
(
op
->
outputs_
);
running_ops_
--
;
ready_var_q
->
Extend
(
op
->
outputs_
);
}
catch
(
platform
::
EnforceNotMet
ex
)
{
}
catch
(
platform
::
EnforceNotMet
ex
)
{
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
ex
));
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
ex
));
}
catch
(...)
{
}
catch
(...)
{
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
d0ac9253
...
@@ -14,7 +14,12 @@
...
@@ -14,7 +14,12 @@
#pragma once
#pragma once
#include <chrono>
#include <deque>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <functional>
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
...
@@ -79,9 +84,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -79,9 +84,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~
ThreadedSSAGraphExecutor
()
{}
~
ThreadedSSAGraphExecutor
()
{}
private:
private:
void
RunOp
(
BlockingQueue
<
VarHandleBase
*>
&
ready_var_q
,
void
RunOp
(
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
);
details
::
OpHandleBase
*
op
);
void
RunDelayedOps
(
const
std
::
unordered_set
<
OpHandleBase
*>
&
delayed_ops
);
private:
private:
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
...
@@ -89,6 +96,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -89,6 +96,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
platform
::
DeviceContextPool
fetch_ctxs_
;
platform
::
DeviceContextPool
fetch_ctxs_
;
const
bool
use_event_
;
const
bool
use_event_
;
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
std
::
atomic
<
int
>
running_ops_
;
size_t
computation_count_
{
0
};
size_t
computation_count_
{
0
};
size_t
max_async_computation
{
100
};
size_t
max_async_computation
{
100
};
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
d0ac9253
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/profiler.h"
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(
...
@@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs(
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
const
std
::
string
&
fetched_var_name
)
{
platform
::
RecordBlock
b
(
0
);
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
fetch_data
;
fetch_data
;
...
...
python/paddle/fluid/parallel_executor.py
浏览文件 @
d0ac9253
...
@@ -16,6 +16,7 @@ import core
...
@@ -16,6 +16,7 @@ import core
import
multiprocessing
import
multiprocessing
import
framework
import
framework
import
executor
import
executor
import
sys
__all__
=
[
'ParallelExecutor'
]
__all__
=
[
'ParallelExecutor'
]
...
@@ -35,7 +36,7 @@ class ParallelExecutor(object):
...
@@ -35,7 +36,7 @@ class ParallelExecutor(object):
places
.
append
(
p
)
places
.
append
(
p
)
if
num_threads
is
None
:
if
num_threads
is
None
:
num_threads
=
min
(
len
(
places
)
*
2
,
multiprocessing
.
cpu_count
()
)
num_threads
=
len
(
places
)
startup
=
framework
.
default_startup_program
()
startup
=
framework
.
default_startup_program
()
main
=
framework
.
default_main_program
()
main
=
framework
.
default_main_program
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录