PaddlePaddle/Paddle - commit ca8c77d9
Authored Dec 28, 2018 by Yancey1989
Parent: 4743c9cd

select execution according to strategy

test=develop
Showing 12 changed files with 86 additions and 101 deletions (+86, -101):
paddle/fluid/framework/details/build_strategy.cc                           +3   -4
paddle/fluid/framework/details/build_strategy.h                            +8   -3
paddle/fluid/framework/details/multi_devices_graph_pass.cc                 +6   -6
paddle/fluid/framework/parallel_executor.cc                                +51  -26
paddle/fluid/framework/parallel_executor.h                                 +3   -0
paddle/fluid/pybind/pybind.cc                                              +0   -8
python/paddle/fluid/__init__.py                                            +2   -1
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py         +0   -2
python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py          +1   -7
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py        +9   -30
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py    +3   -12
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py  +0   -2
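Net effect for Python users: ParallelGraph execution is now selected automatically from the program and the strategies, so the enable_parallel_graph knob disappears from fluid.BuildStrategy (see the pybind.cc and test diffs below). A minimal sketch of the resulting user-side setup, assuming a program whose mean loss variable is avg_cost:

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
    exec_strategy = fluid.ExecutionStrategy()

    # No build_strategy.enable_parallel_graph = True anymore: ParallelExecutor
    # calls EnableParallelGraphExecution() internally and decides by itself.
    pe = fluid.ParallelExecutor(use_cuda=True,
                                loss_name=avg_cost.name,  # avg_cost assumed
                                build_strategy=build_strategy,
                                exec_strategy=exec_strategy)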
paddle/fluid/framework/details/build_strategy.cc

@@ -134,7 +134,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
 std::unique_ptr<ir::Graph> BuildStrategy::Apply(
     const ProgramDesc &main_program,
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
-    const size_t &num_parallel_devices,
+    const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
 #else
@@ -153,9 +153,8 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
       pass->Erase("local_scopes");
       pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
                                                     &local_scopes);
-      pass->Erase("num_parallel_devices");
-      pass->Set<size_t>("num_parallel_devices",
-                        new size_t(num_parallel_devices));
+      pass->Erase("nranks");
+      pass->Set<size_t>("nranks", new size_t(nranks));
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
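The renamed parameter counts global ranks rather than local devices; parallel_executor.cc below sets it as num_trainers * places.size(). A quick illustration with hypothetical values:

    # What the new `nranks` argument carries (values are made up):
    num_trainers = 2                               # distributed trainer processes
    places = ['gpu:0', 'gpu:1', 'gpu:2', 'gpu:3']  # devices local to one trainer
    nranks = num_trainers * len(places)
    assert nranks == 8                             # total parallel ranks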
paddle/fluid/framework/details/build_strategy.h

@@ -84,8 +84,6 @@ struct BuildStrategy {
   bool fuse_broadcast_op_{false};
 
-  bool enable_parallel_graph_{false};
-
   int num_trainers_{1};
   int trainer_id_{0};
   std::vector<std::string> trainers_endpoints_;
@@ -112,7 +110,7 @@ struct BuildStrategy {
                                     const std::vector<platform::Place> &places,
                                     const std::string &loss_var_name,
                                     const std::vector<Scope *> &local_scopes,
-                                    const size_t &num_parallel_devices_,
+                                    const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                     const bool use_cuda,
                                     platform::NCCLContextMap *nccl_ctxs) const;
@@ -120,6 +118,13 @@ struct BuildStrategy {
                              const bool use_cuda) const;
 #endif
 
+  // If set true, ParallelExecutor would build the main_program into multiple
+  // graphs,
+  // each of the graphs would run with one device. This approach can achieve
+  // better performance
+  // on some scenarios.
+  mutable bool enable_parallel_graph_ = false;
+
  private:
   mutable bool is_finalized_ = false;
   mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
paddle/fluid/framework/details/multi_devices_graph_pass.cc

@@ -138,7 +138,7 @@ static const char kLossVarName[] = "loss_var_name";
 static const char kPlaces[] = "places";
 static const char kLocalScopes[] = "local_scopes";
 static const char kStrategy[] = "strategy";
-static const char kNumParallelDevices[] = "num_parallel_devices";
+static const char kNRanks[] = "nranks";
 
 void MultiDevSSAGraphBuilder::Init() const {
   all_vars_.clear();
@@ -174,7 +174,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   auto nodes = graph->ReleaseNodes();
   ir::Graph &result = *graph;
 
-  size_t num_parallel_devices = Get<size_t>(kNumParallelDevices);
+  size_t nranks = Get<size_t>(kNRanks);
 
   for (auto &node : nodes) {
     if (node->IsVar() && node->Var()) {
@@ -251,7 +251,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
         CreateComputationalOps(&result, node, places_.size());
       }
 
-      if (!is_forwarding && num_parallel_devices > 1UL) {
+      if (!is_forwarding && nranks > 1UL) {
         bool is_bk_op =
             static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                   OpProtoAndCheckerMaker::OpRoleAttrName())) &
@@ -649,13 +649,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
 void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
     ir::Graph *result, const std::string &loss_grad_name,
     ir::Node *out_var_node, proto::VarType::Type dtype) const {
-  size_t num_parallel_devices = Get<size_t>("num_parallel_devices");
+  size_t nranks = Get<size_t>("nranks");
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
     auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
     auto *op_handle = new ScaleLossGradOpHandle(
         result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
-        num_parallel_devices, local_scopes_[i], places_[i], dev_ctx, dtype);
+        nranks, local_scopes_[i], places_[i], dev_ctx, dtype);
     result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
     // FIXME: Currently ScaleLossGradOp only use device_count as scale
@@ -888,4 +888,4 @@ REGISTER_PASS(multi_devices_pass,
     .RequirePassAttr(paddle::framework::details::kPlaces)
     .RequirePassAttr(paddle::framework::details::kLocalScopes)
     .RequirePassAttr(paddle::framework::details::kStrategy)
-    .RequirePassAttr(paddle::framework::details::kNumParallelDevices);
+    .RequirePassAttr(paddle::framework::details::kNRanks);
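Since ScaleLossGradOpHandle now receives nranks, each device's loss-gradient seed is scaled against the global rank count instead of only the local device count (the FIXME above records that the old code used device_count as the scale). A toy calculation, assuming the handle applies a 1/nranks factor:

    nranks = 8                        # e.g. 2 trainers x 4 GPUs each
    loss_grad_scale = 1.0 / nranks    # assumed scale seeded for d(loss)
    assert loss_grad_scale == 0.125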
paddle/fluid/framework/parallel_executor.cc

@@ -107,7 +107,7 @@ class ParallelExecutorPrivate {
   bool own_local_scope_;
   bool use_cuda_;
   bool use_all_reduce_;
-  size_t num_parallel_devices_;
+  size_t nranks_;
 
   // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
   // then keeps unchanged
@@ -203,7 +203,7 @@ ParallelExecutor::ParallelExecutor(
   member_->build_strategy_ = build_strategy;
   member_->use_all_reduce_ =
       build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
-  member_->num_parallel_devices_ = num_trainers * places.size();
+  member_->nranks_ = num_trainers * places.size();
 
   if (!member_->use_all_reduce_) {
     PADDLE_ENFORCE(places.size() > 1,
@@ -211,16 +211,14 @@ ParallelExecutor::ParallelExecutor(
                    "the number of places must be greater than 1.");
   }
 
-  if (build_strategy.enable_parallel_graph_) {
-    PADDLE_ENFORCE(
-        member_->use_all_reduce_,
-        "build_strategy.reduce should be `AllReduce` if you want to enable"
-        "ParallelGraph.");
-    PADDLE_ENFORCE(
-        member_->use_cuda_,
-        "execution_strategy.use_cuda should be True if you want to enable "
-        "ParallelGraph.");
-  }
+  // FIXME(Yancey1989): parallel graph mode get better performance
+  // in GPU allreduce distributed training. Need an elegant way to
+  // choice the execution strategy.
+  build_strategy.enable_parallel_graph_ =
+      EnableParallelGraphExecution(main_program, exec_strategy, build_strategy);
+
+  VLOG(1) << "Enable ParallelGraph Execution: "
+          << build_strategy.enable_parallel_graph_;
 
   // Step 1. Bcast the bcast_vars to devs.
   // Create local scopes
@@ -242,20 +240,20 @@ ParallelExecutor::ParallelExecutor(
   // Bcast Parameters to all GPUs
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
-  ncclUniqueId *nccl_id = nullptr;
-  // nccl collective would broadcast ncclid by gen_nccl_id operator.
+  std::unique_ptr<ncclUniqueId> nccl_id;
+  // nccl collective would broadcast ncclUniqueId by gen_nccl_id operator.
   if (nccl_id_var != nullptr) {
-    nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+    nccl_id.reset(nccl_id_var->GetMutable<ncclUniqueId>());
   }
-  if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
-    if (nccl_id == nullptr) {
-      nccl_id = new ncclUniqueId();
-      PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
+  if (build_strategy.enable_parallel_graph_ && member_->nranks_ > 1UL) {
+    if (nccl_id.get() == nullptr) {
+      nccl_id.reset(new ncclUniqueId());
+      PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id.get()));
     }
   }
   member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
-      member_->places_, nccl_id, num_trainers, trainer_id));
+      member_->places_, nccl_id.get(), num_trainers, trainer_id));
 #else
   PADDLE_THROW("Not compiled with CUDA");
 #endif
@@ -268,27 +266,25 @@ ParallelExecutor::ParallelExecutor(
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
   std::vector<std::unique_ptr<ir::Graph>> graphs;
-  member_->num_parallel_devices_ = member_->places_.size() * num_trainers;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (build_strategy.enable_parallel_graph_) {
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
           main_program, {member_->places_[i]}, loss_var_name,
-          {member_->local_scopes_[i]}, member_->num_parallel_devices_,
+          {member_->local_scopes_[i]}, member_->nranks_,
           member_->use_cuda_,
           member_->nccl_ctxs_.get());
       graphs.push_back(std::move(graph));
     }
   } else {
     std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
         main_program, member_->places_, loss_var_name, member_->local_scopes_,
-        member_->num_parallel_devices_, member_->use_cuda_,
+        member_->nranks_, member_->use_cuda_,
        member_->nccl_ctxs_.get());
     graphs.push_back(std::move(graph));
   }
 #else
   std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
       main_program, member_->places_, loss_var_name, member_->local_scopes_,
-      member_->num_parallel_devices_, member_->use_cuda_);
+      member_->nranks_, member_->use_cuda_);
   graphs.push_back(std::move(graph));
 #endif
 
   auto max_memory_size = GetEagerDeletionThreshold();
@@ -470,6 +466,35 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
     }
   }
 
+bool ParallelExecutor::EnableParallelGraphExecution(
+    const ProgramDesc &main_program, const ExecutionStrategy &exec_strategy,
+    const BuildStrategy &build_strategy) const {
+  bool enable_parallel_graph = true;
+  // TODO(Yancey1989): support sparse update in ParallelGraph mode.
+  for (auto &var_desc : main_program.Block(0).AllVars()) {
+    if (var_desc->GetType() == proto::VarType::SELECTED_ROWS) {
+      enable_parallel_graph = false;
+    }
+  }
+
+  // TODO(Yancey1989): support pserver mode
+  for (auto &op_desc : main_program.Block(0).AllOps()) {
+    if (op_desc->Type() == "send" || op_desc->Type() == "recv") {
+      enable_parallel_graph = false;
+      break;
+    }
+  }
+
+  if (!member_->use_all_reduce_ || !member_->use_cuda_)
+    enable_parallel_graph = false;
+
+  if (build_strategy.enable_sequential_execution_ ||
+      exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
+    enable_parallel_graph = false;
+  return enable_parallel_graph;
+}
+
 ParallelExecutor::~ParallelExecutor() {
   for (auto &p : member_->places_) {
     platform::DeviceContextPool::Instance().Get(p)->Wait();
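The added EnableParallelGraphExecution() is the core of the commit: ParallelGraph mode is kept only when nothing in the program or the strategies rules it out, and because the constructor receives build_strategy by const reference, enable_parallel_graph_ had to become a mutable member in build_strategy.h above. A Python paraphrase of the same checks (illustrative only; the helper name and flag plumbing here are ours, the real logic is the C++ above):

    from paddle.fluid import core

    def can_enable_parallel_graph(main_program, exec_strategy, build_strategy,
                                  use_all_reduce, use_cuda):
        enable = True
        # Sparse updates (SELECTED_ROWS variables) are not supported yet.
        for var in main_program.block(0).vars.values():
            if var.type == core.VarDesc.VarType.SELECTED_ROWS:
                enable = False
        # Pserver mode (send/recv ops) is not supported yet.
        for op in main_program.block(0).ops:
            if op.type in ('send', 'recv'):
                enable = False
                break
        # Only GPU all-reduce training qualifies.
        if not use_all_reduce or not use_cuda:
            enable = False
        # Sequential execution and the experimental executor opt out as well.
        if (build_strategy.enable_sequential_execution or
                exec_strategy.use_experimental_executor):
            enable = False
        return enable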
paddle/fluid/framework/parallel_executor.h

@@ -68,6 +68,9 @@ class ParallelExecutor {
  private:
   void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
 
+  bool EnableParallelGraphExecution(const ProgramDesc &main_program,
+                                    const ExecutionStrategy &exec_strategy,
+                                    const BuildStrategy &build_strategy) const;
+
   ParallelExecutorPrivate *member_;
 };
paddle/fluid/pybind/pybind.cc

@@ -980,14 +980,6 @@ All parameter, weight, gradient are variables in Paddle.
                     R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
                     to fuse elementwise_add_op and activation_op,
                     it may make the execution faster. Default False)DOC")
-      .def_property(
-          "enable_parallel_graph",
-          [](const BuildStrategy &self) { return self.enable_parallel_graph_; },
-          [](BuildStrategy &self, bool b) { self.enable_parallel_graph_ = b; },
-          R"DOC(The type is BOOL, if set True, ParallelExecutor would build the main_program into multiple graphs,
-          each of the graphs would run with one device. This approach can achieve better performance in
-          some scenarios. Please note, this approach only supports all-reduce mode
-          on GPU device)DOC")
       .def_property(
           "memory_optimize",
           [](const BuildStrategy &self) { return self.memory_optimize_; },
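With the property deleted from the binding, Python code that still assigns it should fail at attribute lookup (pybind11-bound classes reject unknown attributes by default); a minimal sketch of the expected behavior:

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    try:
        build_strategy.enable_parallel_graph = True  # removed by this commit
    except AttributeError as err:
        print("enable_parallel_graph is no longer user-settable:", err)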
python/paddle/fluid/__init__.py

@@ -156,7 +156,8 @@ def __bootstrap__():
         read_env_flags += [
             'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
             'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
-            'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus'
+            'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
+            'sync_nccl_allreduce'
         ]
     core.init_gflags([sys.argv[0]] +
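Listing 'sync_nccl_allreduce' in read_env_flags means __bootstrap__ hands it to core.init_gflags, so the flag becomes settable from the environment. A sketch, assuming it must be exported before paddle.fluid is first imported, since __bootstrap__ runs at import time:

    import os

    # Must happen before the first `import paddle.fluid`.
    os.environ['FLAGS_sync_nccl_allreduce'] = '1'

    import paddle.fluid as fluid  # __bootstrap__ now picks the flag up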
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py

@@ -39,7 +39,6 @@ class TestParallelExecutorBase(unittest.TestCase):
                                seed=None,
                                use_parallel_executor=True,
                                use_reduce=False,
-                               use_parallel_graph=False,
                                use_ir_memory_optimize=False,
                                fuse_elewise_add_act_ops=False,
                                optimizer=fluid.optimizer.Adam,
@@ -80,7 +79,6 @@ class TestParallelExecutorBase(unittest.TestCase):
         if use_fast_executor:
             exec_strategy.use_experimental_executor = True
         build_strategy = fluid.BuildStrategy()
-        build_strategy.enable_parallel_graph = use_parallel_graph
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
             if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
         build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py

@@ -175,14 +175,13 @@ class TestCRFModel(unittest.TestCase):
                 print(pe.run(feed=feeder.feed(cur_batch),
                              fetch_list=[avg_cost.name])[0])
 
-    def _new_build_strategy(self, use_reduce=False, use_parallel_graph=False):
+    def _new_build_strategy(self, use_reduce=False):
         build_strategy = fluid.BuildStrategy()
 
         if use_reduce:
             build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         else:
             build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
-        build_strategy.enable_parallel_graph = use_parallel_graph
 
         return build_strategy
@@ -204,11 +203,6 @@ class TestCRFModel(unittest.TestCase):
                 is_sparse=False,
                 build_strategy=self._new_build_strategy(),
                 use_cuda=True)
-            self.check_network_convergence(
-                is_sparse=False,
-                build_strategy=self._new_build_strategy(
-                    use_parallel_graph=True),
-                use_cuda=True)
 
             self.check_network_convergence(
                 is_sparse=False,
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py

@@ -100,10 +100,7 @@ class TestMNIST(TestParallelExecutorBase):
             self.assertAlmostEqual(loss[0], loss[1], delta=1e-4)
 
     # simple_fc
-    def check_simple_fc_convergence(self,
-                                    use_cuda,
-                                    use_reduce=False,
-                                    use_parallel_graph=False):
+    def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -114,15 +111,13 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_reduce=use_reduce,
-            use_parallel_graph=use_parallel_graph)
+            use_reduce=use_reduce)
 
     def test_simple_fc(self):
         # use_cuda
         if core.is_compiled_with_cuda():
             self.check_simple_fc_convergence(True)
-            self.check_simple_fc_convergence(
-                True, use_reduce=False, use_parallel_graph=True)
+            self.check_simple_fc_convergence(True, use_reduce=False)
         self.check_simple_fc_convergence(False)
 
     def test_simple_fc_with_new_strategy(self):
@@ -130,9 +125,7 @@ class TestMNIST(TestParallelExecutorBase):
         self._compare_reduce_and_allreduce(simple_fc_net, True)
         self._compare_reduce_and_allreduce(simple_fc_net, False)
 
-    def check_simple_fc_parallel_accuracy(self,
-                                          use_cuda,
-                                          use_parallel_graph=False):
+    def check_simple_fc_parallel_accuracy(self, use_cuda):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -144,16 +137,7 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_parallel_executor=False,
-            use_parallel_graph=use_parallel_graph)
-        parallel_first_loss, parallel_last_loss = self.check_network_convergence(
-            method=simple_fc_net,
-            seed=1,
-            feed_dict={"image": img,
-                       "label": label},
-            use_cuda=use_cuda,
-            use_parallel_executor=True,
-            use_parallel_graph=use_parallel_graph)
+            use_parallel_executor=False)
 
         self.assertAlmostEquals(
             np.mean(parallel_first_loss),
@@ -165,15 +149,11 @@ class TestMNIST(TestParallelExecutorBase):
     def test_simple_fc_parallel_accuracy(self):
         if core.is_compiled_with_cuda():
             self.check_simple_fc_parallel_accuracy(True)
-            self.check_simple_fc_parallel_accuracy(
-                True, use_parallel_graph=True)
+            self.check_simple_fc_parallel_accuracy(True)
         # FIXME(Yancey1989): ParallelGraph executor type support CPU mode
         self.check_simple_fc_parallel_accuracy(False)
 
-    def check_batchnorm_fc_convergence(self,
-                                       use_cuda,
-                                       use_fast_executor,
-                                       use_parallel_graph=False):
+    def check_batchnorm_fc_convergence(self, use_cuda, use_fast_executor):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -184,8 +164,7 @@ class TestMNIST(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=use_cuda,
-            use_fast_executor=use_fast_executor,
-            use_parallel_graph=use_parallel_graph)
+            use_fast_executor=use_fast_executor)
 
     def test_batchnorm_fc(self):
         for use_cuda in (False, True):
@@ -193,7 +172,7 @@ class TestMNIST(TestParallelExecutorBase):
                 self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor)
-        self.check_batchnorm_fc_convergence(
-            use_cuda=True, use_fast_executor=False, use_parallel_graph=True)
+        self.check_batchnorm_fc_convergence(
+            use_cuda=True, use_fast_executor=False)
 
     def test_batchnorm_fc_with_new_strategy(self):
         # FIXME(zcd): close this test temporally.
python/paddle/fluid/tests/unittests/test_parallel_executor_seresnext.py

@@ -277,9 +277,7 @@ class TestResnet(TestParallelExecutorBase):
                                   use_cuda=True,
                                   use_reduce=False,
                                   iter=20,
-                                  delta2=1e-6,
-                                  use_parallel_graph=False,
-                                  lr_scale=1.0):
+                                  delta2=1e-6):
         if use_cuda and not core.is_compiled_with_cuda():
             return
@@ -298,8 +296,7 @@ class TestResnet(TestParallelExecutorBase):
             use_cuda=use_cuda,
             use_reduce=use_reduce,
             optimizer=optimizer,
-            use_parallel_executor=False,
-            use_parallel_graph=use_parallel_graph)
+            use_parallel_executor=False)
         parallel_first_loss, parallel_last_loss = self.check_network_convergence(
             model,
             feed_dict={"image": img,
@@ -308,8 +305,7 @@ class TestResnet(TestParallelExecutorBase):
             batch_size=batch_size,
             use_cuda=use_cuda,
             use_reduce=use_reduce,
-            optimizer=optimizer,
-            use_parallel_graph=use_parallel_graph)
+            optimizer=optimizer)
 
         self.assertAlmostEquals(
             np.mean(parallel_first_loss), single_first_loss[0], delta=1e-6)
@@ -320,11 +316,6 @@ class TestResnet(TestParallelExecutorBase):
         if core.is_compiled_with_cuda():
             self._check_resnet_convergence(model=SE_ResNeXt50Small, use_cuda=True)
-            self._check_resnet_convergence(
-                model=SE_ResNeXt50Small,
-                use_cuda=True,
-                use_parallel_graph=True,
-                lr_scale=core.get_cuda_device_count())
         self._check_resnet_convergence(
             model=SE_ResNeXt50Small, use_cuda=False, iter=2, delta2=1e-3)
python/paddle/fluid/tests/unittests/test_parallel_executor_transformer.py

@@ -175,8 +175,6 @@ class TestTransformer(TestParallelExecutorBase):
         self.check_network_convergence(transformer, use_cuda=True)
         self.check_network_convergence(
             transformer, use_cuda=True, enable_sequential_execution=True)
-        self.check_network_convergence(
-            transformer, use_cuda=True, use_parallel_graph=True)
         self.check_network_convergence(transformer, use_cuda=False, iter=5)