Commit f3463ecb

refine pg execution

Author: Yancey1989
Committed on: Feb 14, 2019
Parent: 46a6cac9

Showing 13 changed files with 309 additions and 86 deletions (+309 / -86).
paddle/fluid/framework/details/build_strategy.cc                      +7   -3
paddle/fluid/framework/details/multi_devices_graph_pass.cc            +32  -22
paddle/fluid/framework/details/multi_devices_graph_pass.h             +11  -5
paddle/fluid/framework/details/multi_devices_helper.h                 +9   -2
paddle/fluid/framework/details/op_handle_base.h                       +3   -0
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc         +64  -1
paddle/fluid/framework/details/parallel_ssa_graph_executor.h          +11  -0
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc         +2   -2
paddle/fluid/framework/ir/graph.h                                     +19  -7
paddle/fluid/framework/ir/graph_helper.h                              +3   -1
paddle/fluid/framework/parallel_executor.cc                           +39  -42
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py    +2   -1
python/paddle/fluid/tests/unittests/test_parallel_executor_pg.py      +107 -0
paddle/fluid/framework/details/build_strategy.cc

@@ -35,8 +35,8 @@ static inline bool SeqOnlyAllReduceOps(const BuildStrategy &strategy) {
   // Should fix the allreduce op order if scheduling
   // them in multiple threads or processes to avoid hang.
-  return (!strategy.enable_sequential_execution_ &&
-          strategy.num_trainers_ > 1) ||
-         strategy.enable_parallel_graph_;
+  return (!strategy.enable_sequential_execution_ &&
+          strategy.num_trainers_ > 1) &&
+         !strategy.enable_parallel_graph_;
 }

 class ParallelExecutorPassBuilder : public ir::PassBuilder {

@@ -106,7 +106,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
   }

   // Verify that the graph is correct for multi-device executor.
-  AppendPass("multi_devices_check_pass");
+  auto multi_devices_pass = AppendPass("multi_devices_check_pass");
+  multi_devices_pass->Set<bool>(kEnablePG,
+                                new bool(strategy.enable_parallel_graph_));

   if (SeqOnlyAllReduceOps(strategy)) {
     AppendPass("all_reduce_deps_pass");

@@ -180,6 +182,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
                                                  &local_scopes);
       pass->Erase(kNRanks);
       pass->Set<size_t>(kNRanks, new size_t(nranks));
+      pass->Erase(kEnablePG);
+      pass->Set<bool>(kEnablePG, new bool(true));
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
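The flipped condition in SeqOnlyAllReduceOps is easier to read in isolation. A minimal, self-contained sketch, assuming a toy Strategy struct in place of Paddle's BuildStrategy (the "PG applies the pass itself" comment refers to the parallel_executor.cc hunk further down, where all_reduce_deps_pass is applied per sub-graph):

#include <iostream>

// Hypothetical stand-in for BuildStrategy; only the three fields the
// predicate actually reads.
struct Strategy {
  bool enable_sequential_execution_ = false;
  int num_trainers_ = 1;
  bool enable_parallel_graph_ = false;
};

// The predicate as of this commit: force a fixed allreduce order only for
// multi-trainer, non-sequential runs that are NOT using parallel graphs.
static bool SeqOnlyAllReduceOps(const Strategy &s) {
  return (!s.enable_sequential_execution_ && s.num_trainers_ > 1) &&
         !s.enable_parallel_graph_;
}

int main() {
  Strategy s;
  s.num_trainers_ = 2;
  std::cout << SeqOnlyAllReduceOps(s) << "\n";  // 1: pass is appended here

  s.enable_parallel_graph_ = true;              // PG applies the pass itself
  std::cout << SeqOnlyAllReduceOps(s) << "\n";  // 0: skipped in the builder
}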
paddle/fluid/framework/details/multi_devices_graph_pass.cc

@@ -36,11 +36,6 @@ namespace framework {
 namespace details {

 namespace {
-// TODO(panyx0718): Clean this up as well.
-// all operators. NOTE that even we use a vector here, the operators is
-// unordered.
-typedef std::vector<OpHandleBase *> GraphOps;
-const char kGraphOps[] = "ops";
-
 bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
   return boost::get<int>(

@@ -206,7 +201,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
           auto &g_name = backward_vars[i + 1];
           VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
-          InsertCollectiveOp(&result, p_name, g_name);
+          InsertCollectiveOp(&result, node, p_name, g_name);
         }
       } catch (boost::bad_get e) {
       }

@@ -226,7 +221,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
    * Only variables should be the leaves of graph.
    */
   AddOutputToLeafOps(&result);
-  result.Erase(kGraphOps);
+  // result.Erase(kGraphOps);
   return graph;
 }

@@ -391,20 +386,34 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
 }

 void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(
-    ir::Graph *result, const std::string &og) const {
+    ir::Graph *result, ir::Node *node, const std::string &og) const {
+  OpHandleBase *op_handle = nullptr;
+
+  auto append_allreduce_op = [&](
+      std::vector<Scope *> &scopes,
+      std::vector<platform::Place> &places) -> OpHandleBase * {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-      local_scopes_, places_, nccl_ctxs_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
+        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+        scopes, places, nccl_ctxs_));
 #else
-  result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-      result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-      local_scopes_, places_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
+        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+        scopes, places));
 #endif
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
+    return result->Get<GraphOps>(kGraphOps).back();
+  };
+
+  if (!strategy_.enable_parallel_graph_)
+    op_handle = append_allreduce_op(local_scopes_, places_);

   for (size_t i = 0; i < places_.size(); ++i) {
-    auto &p = places_[i];
+    auto p = places_[i];
+    std::vector<Scope *> ss{local_scopes_[i]};
+    std::vector<platform::Place> ps{p};
+    if (strategy_.enable_parallel_graph_)
+      op_handle = append_allreduce_op(ss, ps);
+
     SetCommunicationContext(op_handle, p);
     auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());

@@ -501,13 +510,13 @@ bool MultiDevSSAGraphBuilderBase::IsSparseGradient(
 }

 void AllReduceSSAGraphBuilder::InsertCollectiveOp(
-    ir::Graph *result, const std::string &p_name,
+    ir::Graph *result, ir::Node *node, const std::string &p_name,
     const std::string &g_name) const {
   if (IsSparseGradient(g_name)) {
     CreateReduceOp(result, g_name, 0);
     CreateBroadcastOp(result, g_name, 0);
   } else {
-    CreateAllReduceOp(result, g_name);
+    CreateAllReduceOp(result, node, g_name);
   }
 }

@@ -580,7 +589,7 @@ void ReduceSSAGraphBuilder::ResetState() const {
 }

 void ReduceSSAGraphBuilder::InsertCollectiveOp(
-    ir::Graph *result, const std::string &p_name,
+    ir::Graph *result, ir::Node *node, const std::string &p_name,
     const std::string &g_name) const {
   size_t cur_device_id = GetAppropriateDeviceID({g_name});
   CreateReduceOp(result, g_name, cur_device_id);

@@ -900,7 +909,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
   return op_dev_id;
 }

 void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
+                                             ir::Node *node,
                                              const std::string &p_name,
                                              const std::string &g_name) const {
   size_t cur_device_id = 0;

@@ -915,7 +924,7 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
         CreateReduceOp(result, g_name, 0);
         CreateBroadcastOp(result, g_name, 0);
       } else {
-        CreateAllReduceOp(result, g_name);
+        CreateAllReduceOp(result, node, g_name);
       }
       break;
     default:

@@ -966,7 +975,8 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
       .RequirePassAttr(paddle::framework::details::kPlaces)          \
       .RequirePassAttr(paddle::framework::details::kLocalScopes)     \
       .RequirePassAttr(paddle::framework::details::kStrategy)        \
-      .RequirePassAttr(paddle::framework::details::kNRanks)
+      .RequirePassAttr(paddle::framework::details::kNRanks)          \
+      .RequirePassAttr(paddle::framework::details::kEnablePG)

 REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
                             paddle::framework::details::ReduceSSAGraphBuilder);
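The reshaped CreateAllReduceOp is the heart of this commit: with parallel-graph mode on, each place gets its own allreduce handle over a single scope/place pair, instead of one handle spanning every place. A compressed sketch with toy types (Handle is a hypothetical stand-in for AllReduceOpHandle, ints stand in for Scope*/Place):

#include <cstddef>
#include <iostream>
#include <vector>

// Toy stand-in for AllReduceOpHandle: records which scopes/places it spans.
struct Handle {
  std::vector<int> scopes, places;
};

int main() {
  const bool enable_parallel_graph = true;  // flip to false for the old shape
  std::vector<int> scopes{0, 1}, places{0, 1};
  std::vector<Handle> ops;  // plays the role of result->Get<GraphOps>(...)

  auto append_allreduce_op = [&](std::vector<int> ss,
                                 std::vector<int> ps) -> Handle * {
    ops.push_back(Handle{ss, ps});
    return &ops.back();
  };

  Handle *op_handle = nullptr;
  if (!enable_parallel_graph)
    op_handle = append_allreduce_op(scopes, places);  // one op, all places

  for (std::size_t i = 0; i < places.size(); ++i) {
    if (enable_parallel_graph)
      op_handle = append_allreduce_op({scopes[i]}, {places[i]});  // per place
    std::cout << "device " << i << ": handle spans "
              << op_handle->places.size() << " place(s)\n";
  }
}

The per-place handles are what later allow SeparateMultiDevicesGraph to assign every op to exactly one device's sub-graph.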
paddle/fluid/framework/details/multi_devices_graph_pass.h

@@ -36,6 +36,7 @@ constexpr char kPlaces[] = "places";
 constexpr char kLocalScopes[] = "local_scopes";
 constexpr char kStrategy[] = "strategy";
 constexpr char kNRanks[] = "nranks";
+constexpr char kEnablePG[] = "enable_pg";

 class MultiDevSSAGraphBuilderBase : public ir::Pass {
  protected:

@@ -46,7 +47,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
   virtual std::vector<ir::Node *> SortOperations(const ir::Graph &graph) const;

-  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
+                                  const std::string &p_name,
                                   const std::string &g_name) const = 0;

   virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0;

@@ -75,7 +77,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
   bool IsSparseGradient(const std::string &og) const;

-  void CreateAllReduceOp(ir::Graph *result, const std::string &og) const;
+  void CreateAllReduceOp(ir::Graph *result, ir::Node *node,
+                         const std::string &og) const;

   void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;

@@ -106,7 +109,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
 class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
  protected:
-  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
+                                  const std::string &p_name,
                                   const std::string &g_name) const;

   virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {

@@ -135,7 +139,8 @@ class ReduceSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
  protected:
   virtual void Init() const;

-  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
+                                  const std::string &p_name,
                                   const std::string &g_name) const;

   virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const;

@@ -164,7 +169,8 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
   virtual void InsertPostprocessOps(ir::Graph *result) const;

-  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+  virtual void InsertCollectiveOp(ir::Graph *result, ir::Node *node,
+                                  const std::string &p_name,
                                   const std::string &g_name) const;

   virtual void ResetState() const;
paddle/fluid/framework/details/multi_devices_helper.h

@@ -36,13 +36,20 @@ namespace details {
 // map from variable name to variables. The variables, who have the same name,
 // will have a differsent version. The offset in the
 // `std::vector<VarHandle*>` is the version of varaibles.
 typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle *>>>
     GraphVars;
 const char kGraphVars[] = "vars";

 // aux variables to represent dependency. Useful to resolve data hazard.
 typedef std::unordered_set<VarHandleBase *> GraphDepVars;
 const char kGraphDepVars[] = "dep_vars";

+// TODO(panyx0718): Clean this up as well.
+// all operators. NOTE that even we use a vector here, the operators is
+// unordered.
+typedef std::vector<OpHandleBase *> GraphOps;
+const char kGraphOps[] = "ops";
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
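Moving GraphOps/kGraphOps into this shared header lets both the multi-devices pass and the new graph splitter fetch the op list from a graph by name. That lookup rides on Graph's named-attribute store (Set/Get keyed by string). A minimal sketch of the pattern, assuming std::any in place of Paddle's boost::any and ignoring Paddle's ownership/Erase semantics:

#include <any>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical miniature of ir::Graph's attribute map.
struct MiniGraph {
  std::map<std::string, std::any> attrs_;
  template <typename T>
  void Set(const std::string &k, T v) { attrs_[k] = std::move(v); }
  template <typename T>
  T &Get(const std::string &k) { return *std::any_cast<T>(&attrs_.at(k)); }
};

int main() {
  const char kGraphOps[] = "ops";  // same keyed-by-name convention
  MiniGraph g;
  g.Set(kGraphOps, std::vector<std::string>{});
  g.Get<std::vector<std::string>>(kGraphOps).push_back("allreduce");
  std::cout << g.Get<std::vector<std::string>>(kGraphOps).size() << "\n";  // 1
}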
paddle/fluid/framework/details/op_handle_base.h

@@ -70,6 +70,9 @@ class OpHandleBase {
     auto it = dev_ctxes_.find(place);
     return it != dev_ctxes_.end() ? it->second : nullptr;
   }

+  const std::map<platform::Place, platform::DeviceContext *> &DeviceContext() {
+    return dev_ctxes_;
+  }

   void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
     dev_ctxes_[place] = ctx_;
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc

@@ -13,11 +13,74 @@
 // limitations under the License.

 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
 namespace details {

+std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
+    const std::vector<platform::Place> &places,
+    std::unique_ptr<ir::Graph> graph) {
+  std::vector<std::unique_ptr<ir::Graph>> graphs;
+  graphs.reserve(places.size());
+  for (size_t i = 0; i < places.size(); ++i) {
+    ProgramDesc empty;
+    graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
+    auto &g = graphs.back();
+    g->Set(kGraphVars, new GraphVars(1UL));
+    g->Set(kGraphDepVars, new GraphDepVars);
+    g->Set(kGraphOps, new GraphOps);
+  }
+
+  for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
+    auto &dev_ctx = op->DeviceContext();
+    auto &p = dev_ctx.begin()->first;
+#ifdef PADDLE_WITH_CUDA
+    int dev_id = boost::get<platform::CUDAPlace>(p).device;
+    auto &dev_ops = graphs[dev_id]->Get<GraphOps>(kGraphOps);
+    auto &dev_dummys = graphs[dev_id]->Get<GraphDepVars>(kGraphDepVars);
+    dev_ops.emplace_back(op);
+    graphs[dev_id]->AddNode(graph->ReleaseNode(op->Node()).release());
+
+    for (auto &var : op->Inputs()) {
+      auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
+      if (dummy_ptr) {
+        dev_dummys.insert(var);
+        if (graph->Nodes().count(var->Node()))
+          graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release());
+      }
+    }
+    for (auto &var : op->Outputs()) {
+      auto dummy_ptr = dynamic_cast<DummyVarHandle *>(var);
+      if (dummy_ptr) {
+        dev_dummys.insert(var);
+        if (graph->Nodes().count(var->Node()))
+          graphs[dev_id]->AddNode(graph->ReleaseNode(var->Node()).release());
+      }
+    }
+#else
+    PADDLE_THROW("Parallel Graph Execution only support CUDAPlace.");
+#endif
+  }
+
+  for (size_t dev_id = 0; dev_id < places.size(); ++dev_id) {
+    auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
+    auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
+    for (auto &name_pair : origin_vars) {
+      dev_vars.emplace(name_pair.first, name_pair.second);
+      for (auto &version_pair : name_pair.second) {
+        if (graph->Nodes().count(version_pair->Node())) {
+          graphs[dev_id]->AddNode(
+              graph->ReleaseNode(version_pair->Node()).release());
+        }
+      }
+    }
+  }
+
+  return graphs;
+}
+
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,

@@ -37,7 +100,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
           << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
+        strategy_, local_scopes_, {places_[i]}, std::move(graphs_.at(i))));
   }
 }
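SeparateMultiDevicesGraph is the new splitter: it walks the combined graph's op handles, reads each op's device from its DeviceContext map (hence the new accessor in op_handle_base.h above), and moves each node into the matching per-device graph. Stripped of the ir::Graph machinery, the bucketing step reduces to the following sketch (toy Op type, illustrative only):

#include <cstdio>
#include <vector>

// Hypothetical op handle already tagged with the device it runs on,
// as the real code derives from op->DeviceContext().
struct Op { int dev_id; };

int main() {
  std::vector<Op> combined{{0}, {1}, {0}, {1}, {1}};  // the merged graph's ops
  const int num_places = 2;

  // One bucket per place, like the per-device ir::Graph instances.
  std::vector<std::vector<Op>> per_device(num_places);
  for (const Op &op : combined) per_device[op.dev_id].push_back(op);

  for (int d = 0; d < num_places; ++d)
    std::printf("graph %d owns %zu ops\n", d, per_device[d].size());
}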
paddle/fluid/framework/details/parallel_ssa_graph_executor.h

@@ -14,16 +14,24 @@
 #pragma once

+#include <fstream>
+#include <sstream>
 #include <string>
 #include <vector>

 #include "ThreadPool.h"
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
+#include "paddle/fluid/framework/ir/graph.h"

 namespace paddle {
 namespace framework {
 namespace details {

+std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
+    const std::vector<platform::Place> &places,
+    std::unique_ptr<ir::Graph> graph);
+
 class ParallelSSAGraphExecutor : public SSAGraphExecutor {
  public:
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,

@@ -31,11 +39,14 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
                            const std::vector<platform::Place> &places,
                            std::vector<std::unique_ptr<ir::Graph>> &&graphs);
   ~ParallelSSAGraphExecutor() final = default;

   const ir::Graph &Graph() const override { return *graphs_[0]; }

   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

  private:
+  // std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph();
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

@@ -56,10 +56,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
       }
     }
   }
-  for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
+  for (auto &var :
+       graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
     InsertPendingVar(&pending_vars, ready_vars.get(), var);
   }

   for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
     if (op->Inputs().empty()) {  // Special case, Op has no input.
       ready_ops.insert(op);

@@ -219,7 +219,7 @@ void ThreadedSSAGraphExecutor::RunOp(
       VLOG(10) << op << " " << op->Name() << " Done ";
       running_ops_--;
       ready_var_q->Extend(op->Outputs());
-      VLOG(10) << op << " " << op->Name() << "Signal posted";
+      VLOG(10) << op << " " << op->Name() << " Signal posted";
     } catch (...) {
       exception_holder_.Catch(std::current_exception());
     }
paddle/fluid/framework/ir/graph.h

@@ -167,6 +167,14 @@ class Graph {
     return ret;
   }

+  std::unique_ptr<ir::Node> ReleaseNode(ir::Node *node) {
+    std::unique_ptr<ir::Node> ret;
+    ret.reset(nodes_.at(node).release());
+    nodes_.erase(node);
+    node_set_.erase(node);
+    return ret;
+  }
+
   void RemoveNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
     node_set_.erase(node);

@@ -183,13 +191,6 @@ class Graph {
     return nullptr;
   }

-  void ResolveHazard(
-      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
-
- private:
-  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
-      const ProgramDesc &program);
-
   // This method takes ownership of `node`.
   ir::Node *AddNode(ir::Node *node) {
     PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());

@@ -198,6 +199,17 @@ class Graph {
     return node;
   }

+  bool ContainNode(ir::Node *node) {
+    return node_set_.find(node) != node_set_.end();
+  }
+
+  void ResolveHazard(
+      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
+
+ private:
+  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
+      const ProgramDesc &program);
+
   // NOTE: program_ shouldn't be exposed to user.
   const ProgramDesc program_;
   std::map<std::string, boost::any> attrs_;
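ReleaseNode is what enables the splitter above: unlike RemoveNode, it hands the node's ownership out instead of destroying the node, so a second graph can adopt the same object. A self-contained sketch of that contract with a toy node pool (not Paddle's ir::Graph):

#include <cassert>
#include <map>
#include <memory>

struct Node { int id; };

// Hypothetical miniature of ir::Graph's node ownership pool.
struct Graph {
  std::map<Node *, std::unique_ptr<Node>> nodes_;

  Node *AddNode(Node *n) {  // takes ownership of `n`
    nodes_[n].reset(n);
    return n;
  }

  std::unique_ptr<Node> ReleaseNode(Node *n) {
    std::unique_ptr<Node> ret(nodes_.at(n).release());  // keep object alive
    nodes_.erase(n);                                    // drop bookkeeping only
    return ret;
  }
};

int main() {
  Graph a, b;
  Node *n = a.AddNode(new Node{7});
  b.AddNode(a.ReleaseNode(n).release());  // same object, new owner
  assert(a.nodes_.empty() && b.nodes_.size() == 1);
}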
paddle/fluid/framework/ir/graph_helper.h

@@ -59,7 +59,9 @@ template <typename T>
 std::vector<T *> FilterByNodeWrapper(const Graph &graph) {
   std::vector<T *> ret;
   for (ir::Node *n : graph.Nodes()) {
-    if (n->IsWrappedBy<T>()) ret.push_back(&n->Wrapper<T>());
+    if (n->IsWrappedBy<T>()) {
+      ret.push_back(&n->Wrapper<T>());
+    }
   }
   return ret;
 }
paddle/fluid/framework/parallel_executor.cc

@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
+#include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
 #include "paddle/fluid/platform/profiler.h"

@@ -201,7 +202,6 @@ ParallelExecutor::ParallelExecutor(
   member_->use_all_reduce_ =
       build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
   member_->nranks_ = build_strategy.num_trainers_ * places.size();
-
   if (!member_->use_all_reduce_) {
     PADDLE_ENFORCE(places.size() > 1,
                    "If you set build_strategy.reduce with 'Reduce',"

@@ -229,9 +229,10 @@ ParallelExecutor::ParallelExecutor(
   // choice the execution strategy.
   build_strategy.enable_parallel_graph_ =
       EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
   if (build_strategy.enable_parallel_graph_)
-    VLOG(1) << "Enable ParallelGraph Execution: "
-            << build_strategy.enable_parallel_graph_;
+    VLOG(0) << "The Executor would execute the graph by ParallelGraph "
+               "Execution which can get better performance,"
+            << "you can force it off by env FLAGS_enable_parallel_graph=0";

   if (member_->use_cuda_) {
     // Bcast Parameters to all GPUs

@@ -265,58 +266,42 @@ ParallelExecutor::ParallelExecutor(
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
-  std::vector<std::unique_ptr<ir::Graph>> graphs;
+  std::unique_ptr<ir::Graph> graph;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  if (build_strategy.enable_parallel_graph_) {
-    for (size_t i = 0; i < member_->places_.size(); ++i) {
-      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-          main_program, {member_->places_[i]}, loss_var_name,
-          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
-          member_->nccl_ctxs_.get());
-      graphs.push_back(std::move(graph));
-    }
-  } else {
-    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-        main_program, member_->places_, loss_var_name, member_->local_scopes_,
-        member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
-    graphs.push_back(std::move(graph));
-  }
+  graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
+                               member_->local_scopes_, member_->nranks_,
+                               member_->use_cuda_, member_->nccl_ctxs_.get());
 #else
-  std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-      main_program, member_->places_, loss_var_name, member_->local_scopes_,
-      member_->nranks_, member_->use_cuda_);
-  graphs.push_back(std::move(graph));
+  graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
+                               member_->local_scopes_, member_->nranks_,
+                               member_->use_cuda_);
 #endif
   auto max_memory_size = GetEagerDeletionThreshold();
   if (max_memory_size >= 0) {
-    for (size_t i = 0; i < graphs.size(); ++i) {
-      graphs[i] = member_->PrepareGCAndRefCnts(
-          std::move(graphs[i]), static_cast<size_t>(max_memory_size));
-    }
+    graph = member_->PrepareGCAndRefCnts(std::move(graph),
+                                         static_cast<size_t>(max_memory_size));
   }

   // Step 3. Create vars in each scope. Passes may also create new vars.
   // skip control vars and empty vars
   std::vector<details::VariableInfo> var_infos;
-  for (auto &graph : graphs) {
-    for (auto &node : graph->Nodes()) {
-      if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
-        var_infos.emplace_back();
-        var_infos.back().name_ = node->Var()->Name();
-        var_infos.back().type_ = node->Var()->GetType();
-        var_infos.back().persistable_ = node->Var()->Persistable();
-      }
+  for (auto &node : graph->Nodes()) {
+    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
+      var_infos.emplace_back();
+      var_infos.back().name_ = node->Var()->Name();
+      var_infos.back().type_ = node->Var()->GetType();
+      var_infos.back().persistable_ = node->Var()->Persistable();
     }
   }

   // If the loss_var_name is given, the number of graph should be only one.
   if (loss_var_name.size()) {
-    size_t graph_num = ir::GraphNum(*graphs[0]);
+    size_t graph_num = ir::GraphNum(*graph);
     if (graph_num > 1) {
       LOG(WARNING)
           << "The number of graph should be only one, "
              "but the current graph has "
-          << ir::GraphNum(*graphs[0])
+          << ir::GraphNum(*graph)
           << " sub_graphs. If you want to see the nodes of the "
              "sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
              "to specify the output dir. NOTES: if you not do training, "

@@ -325,18 +310,30 @@ ParallelExecutor::ParallelExecutor(
   }

   if (build_strategy.enable_parallel_graph_) {
+    auto parallel_graph =
+        details::SeparateMultiDevicesGraph(member_->places_, std::move(graph));
+    auto seq_allreduce_pass =
+        ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
+    seq_allreduce_pass->Erase(details::kAllOpDescs);
+    seq_allreduce_pass->Set<const std::vector<OpDesc *>>(
+        details::kAllOpDescs,
+        new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
+    for (size_t i = 0; i < parallel_graph.size(); ++i) {
+      parallel_graph[i] =
+          seq_allreduce_pass->Apply(std::move(parallel_graph[i]));
+    }
     member_->executor_.reset(new details::ParallelSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_,
-        std::move(graphs)));
+        std::move(parallel_graph)));
   } else {
     if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
       member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_,
-          std::move(graphs[0])));
+          std::move(graph)));
     } else {
       member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
           exec_strategy, member_->local_scopes_, member_->places_,
-          std::move(graphs[0])));
+          std::move(graph)));
     }
   }

@@ -487,8 +484,8 @@ bool ParallelExecutor::EnableParallelGraphExecution(
     }
   }

-  if (!member_->use_all_reduce_ || !member_->use_cuda_)
-    enable_parallel_graph = false;
+  // if (!member_->use_all_reduce_ || !member_->use_cuda_)
+  if (!member_->use_all_reduce_) enable_parallel_graph = false;

   if (build_strategy.enable_sequential_execution_ ||
       exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
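After this change the constructor always builds one combined graph; only when parallel-graph mode is on does it split that graph per device with SeparateMultiDevicesGraph and re-apply all_reduce_deps_pass to each sub-graph before handing the pieces to ParallelSSAGraphExecutor. The dispatch reduces to, roughly (a sketch with strings standing in for the real executor classes):

#include <iostream>
#include <string>

// Which executor the constructor ends up instantiating, per this hunk.
std::string PickExecutor(bool enable_parallel_graph, bool use_default_type) {
  if (enable_parallel_graph) return "ParallelSSAGraphExecutor";  // split graph
  return use_default_type ? "ThreadedSSAGraphExecutor"           // whole graph
                          : "FastThreadedSSAGraphExecutor";      // whole graph
}

int main() {
  std::cout << PickExecutor(true, true) << "\n";    // parallel-graph path
  std::cout << PickExecutor(false, true) << "\n";   // default threaded path
  std::cout << PickExecutor(false, false) << "\n";  // experimental path
}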
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py

@@ -72,6 +72,7 @@ class TestParallelExecutorBase(unittest.TestCase):
             exe.run(startup)
             exec_strategy = fluid.ExecutionStrategy()
             exec_strategy.allow_op_delay = allow_op_delay
+            exec_strategy.num_threads = 1
             if use_fast_executor:
                 exec_strategy.use_experimental_executor = True
             build_strategy = fluid.BuildStrategy()

@@ -99,7 +100,7 @@ class TestParallelExecutorBase(unittest.TestCase):
             first_loss, = run_executor(
                 exe=exe, binary=binary, feed=feed_dict, fetch_list=[loss.name])
-            for i in range(iter):
+            for _ in range(iter):
                 run_executor(
                     exe=exe, binary=binary, feed=feed_dict, fetch_list=[])
python/paddle/fluid/tests/unittests/test_parallel_executor_pg.py  (new file, mode 100644)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import numpy as np
import os
os.environ['FLAGS_enable_parallel_graph'] = str(1)
import paddle.fluid.core as core
import paddle.fluid as fluid
from parallel_executor_test_base import TestParallelExecutorBase


def simple_fc_net(use_feed):
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    hidden = img
    for _ in range(4):
        hidden = fluid.layers.fc(
            hidden,
            size=200,
            act='tanh',
            bias_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Constant(value=1.0)))
    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    loss = fluid.layers.mean(loss)
    return loss


class TestMNIST(TestParallelExecutorBase):
    @classmethod
    def setUpClass(cls):
        os.environ['CPU_NUM'] = str(4)

    def _init_data(self):
        np.random.seed(5)
        img = np.random.random(size=[32, 784]).astype(np.float32)
        label = np.ones(shape=[32, 1], dtype='int64')
        return img, label

    # simple_fc
    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
        if use_cuda and not core.is_compiled_with_cuda():
            return

        img, label = self._init_data()

        self.check_network_convergence(
            simple_fc_net,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            use_reduce=use_reduce)

    def test_simple_fc(self):
        # use_cuda
        self.check_simple_fc_convergence(True)

    def check_simple_fc_parallel_accuracy(self, use_cuda):
        if use_cuda and not core.is_compiled_with_cuda():
            return

        img, label = self._init_data()

        single_first_loss, single_last_loss = self.check_network_convergence(
            method=simple_fc_net,
            seed=1,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            use_parallel_executor=False)
        parallel_first_loss, parallel_last_loss = self.check_network_convergence(
            method=simple_fc_net,
            seed=1,
            feed_dict={"image": img,
                       "label": label},
            use_cuda=use_cuda,
            use_parallel_executor=True)

        self.assertAlmostEquals(
            np.mean(parallel_first_loss), single_first_loss, delta=1e-6)
        self.assertAlmostEquals(
            np.mean(parallel_last_loss), single_last_loss, delta=1e-6)

    def test_simple_fc_parallel_accuracy(self):
        self.check_simple_fc_parallel_accuracy(True)


if __name__ == '__main__':
    unittest.main()