Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
751497db
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
751497db
编写于
6月 14, 2019
作者:
G
gongweibao
提交者:
GitHub
6月 14, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cherrpick fixncclid 18025 test=release/1.5 (#18093)
上级
d263238f
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
110 addition
and
82 deletion
+110
-82
paddle/fluid/framework/details/all_reduce_op_handle.cc
paddle/fluid/framework/details/all_reduce_op_handle.cc
+1
-1
paddle/fluid/framework/details/all_reduce_op_handle.h
paddle/fluid/framework/details/all_reduce_op_handle.h
+1
-1
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+14
-12
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+1
-1
paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
+1
-1
paddle/fluid/framework/details/fused_all_reduce_op_handle.h
paddle/fluid/framework/details/fused_all_reduce_op_handle.h
+1
-1
paddle/fluid/framework/details/nccl_op_handle.h
paddle/fluid/framework/details/nccl_op_handle.h
+2
-2
paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
...le/fluid/framework/details/sparse_all_reduce_op_handle.cc
+1
-1
paddle/fluid/framework/details/sparse_all_reduce_op_handle.h
paddle/fluid/framework/details/sparse_all_reduce_op_handle.h
+1
-1
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
...rk/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
+8
-8
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
...k/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
+1
-1
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
...rk/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+38
-40
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+0
-7
paddle/fluid/framework/var_type_traits.cc
paddle/fluid/framework/var_type_traits.cc
+2
-0
paddle/fluid/framework/var_type_traits.h
paddle/fluid/framework/var_type_traits.h
+2
-1
paddle/fluid/framework/var_type_traits_test.cc
paddle/fluid/framework/var_type_traits_test.cc
+1
-0
paddle/fluid/platform/nccl_helper.h
paddle/fluid/platform/nccl_helper.h
+25
-3
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+9
-0
未找到文件。
paddle/fluid/framework/details/all_reduce_op_handle.cc
浏览文件 @
751497db
...
@@ -35,7 +35,7 @@ namespace details {
...
@@ -35,7 +35,7 @@ namespace details {
AllReduceOpHandle
::
AllReduceOpHandle
(
ir
::
Node
*
node
,
AllReduceOpHandle
::
AllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
MultiNCCLContextMap
*
ctxs
)
const
platform
::
NCCLCommunicator
*
ctxs
)
:
NCCLOpHandleBase
(
node
,
places
,
ctxs
),
local_scopes_
(
local_scopes
)
{
:
NCCLOpHandleBase
(
node
,
places
,
ctxs
),
local_scopes_
(
local_scopes
)
{
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
}
}
...
...
paddle/fluid/framework/details/all_reduce_op_handle.h
浏览文件 @
751497db
...
@@ -34,7 +34,7 @@ class AllReduceOpHandle : public NCCLOpHandleBase {
...
@@ -34,7 +34,7 @@ class AllReduceOpHandle : public NCCLOpHandleBase {
public:
public:
AllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
AllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
MultiNCCLContextMap
*
ctxs
);
const
platform
::
NCCLCommunicator
*
ctxs
);
#else
#else
class
AllReduceOpHandle
:
public
OpHandleBase
{
class
AllReduceOpHandle
:
public
OpHandleBase
{
public:
public:
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
751497db
...
@@ -266,14 +266,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
...
@@ -266,14 +266,16 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
return
framework
::
ir
::
MultiDevSSAGraphBuilder
().
count
(
pass_name
)
>
0
;
return
framework
::
ir
::
MultiDevSSAGraphBuilder
().
count
(
pass_name
)
>
0
;
}
}
ir
::
Graph
*
BuildStrategy
::
Apply
(
ir
::
Graph
*
BuildStrategy
::
Apply
(
ir
::
Graph
*
graph
,
ir
::
Graph
*
graph
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
string
&
loss_var_name
,
const
size_t
&
nranks
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
size_t
&
nranks
,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const
bool
use_cuda
,
platform
::
MultiNCCLContextMap
*
nccl_ctxs
)
const
{
const
bool
use_cuda
,
platform
::
NCCLCommunicator
*
nccl_ctxs
)
const
{
#else
#else
const
bool
use_cuda
)
const
{
const
bool
use_cuda
)
const
{
#endif
#endif
VLOG
(
3
)
<<
"apply all passes"
;
VLOG
(
3
)
<<
"apply all passes"
;
// Create a default one if not finalized by user.
// Create a default one if not finalized by user.
...
@@ -293,9 +295,9 @@ ir::Graph *BuildStrategy::Apply(
...
@@ -293,9 +295,9 @@ ir::Graph *BuildStrategy::Apply(
pass
->
Set
<
size_t
>
(
ir
::
kNRanks
,
new
size_t
(
nranks
));
pass
->
Set
<
size_t
>
(
ir
::
kNRanks
,
new
size_t
(
nranks
));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
MultiNCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLCommunicator
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
pass
->
Erase
(
kNCCLCtxs
);
pass
->
SetNotOwned
<
platform
::
MultiNCCLContextMap
>
(
kNCCLCtxs
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
#endif
#endif
}
else
if
(
pass
->
Type
()
==
"alloc_continuous_space_for_grad_pass"
||
}
else
if
(
pass
->
Type
()
==
"alloc_continuous_space_for_grad_pass"
||
pass
->
Type
()
==
"fuse_adam_op_pass"
||
pass
->
Type
()
==
"fuse_adam_op_pass"
||
...
@@ -309,9 +311,9 @@ ir::Graph *BuildStrategy::Apply(
...
@@ -309,9 +311,9 @@ ir::Graph *BuildStrategy::Apply(
&
local_scopes
);
&
local_scopes
);
if
(
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
if
(
pass
->
Type
()
==
"fuse_all_reduce_op_pass"
)
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
MultiNCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLCommunicator
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
pass
->
Erase
(
kNCCLCtxs
);
pass
->
SetNotOwned
<
platform
::
MultiNCCLContextMap
>
(
kNCCLCtxs
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
pass
->
Erase
(
kUseHierarchicalAllReduce
);
pass
->
Erase
(
kUseHierarchicalAllReduce
);
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
new
bool
(
use_hierarchical_allreduce_
));
new
bool
(
use_hierarchical_allreduce_
));
...
@@ -328,9 +330,9 @@ ir::Graph *BuildStrategy::Apply(
...
@@ -328,9 +330,9 @@ ir::Graph *BuildStrategy::Apply(
<<
enable_sequential_execution_
;
<<
enable_sequential_execution_
;
}
else
if
(
pass
->
Type
()
==
"all_reduce_deps_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"all_reduce_deps_pass"
)
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
MultiNCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLCommunicator
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
pass
->
Erase
(
kNCCLCtxs
);
pass
->
SetNotOwned
<
platform
::
MultiNCCLContextMap
>
(
kNCCLCtxs
,
nctx
);
pass
->
SetNotOwned
<
platform
::
NCCLCommunicator
>
(
kNCCLCtxs
,
nctx
);
pass
->
Erase
(
kUseHierarchicalAllReduce
);
pass
->
Erase
(
kUseHierarchicalAllReduce
);
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
pass
->
Set
<
bool
>
(
kUseHierarchicalAllReduce
,
new
bool
(
use_hierarchical_allreduce_
));
new
bool
(
use_hierarchical_allreduce_
));
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
751497db
...
@@ -149,7 +149,7 @@ struct BuildStrategy {
...
@@ -149,7 +149,7 @@ struct BuildStrategy {
const
size_t
&
nranks
,
const
size_t
&
nranks
,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const
bool
use_cuda
,
const
bool
use_cuda
,
platform
::
MultiNCCLContextMap
*
nccl_ctxs
)
const
;
platform
::
NCCLCommunicator
*
nccl_ctxs
)
const
;
#else
#else
const
bool
use_cuda
)
const
;
const
bool
use_cuda
)
const
;
#endif
#endif
...
...
paddle/fluid/framework/details/fused_all_reduce_op_handle.cc
浏览文件 @
751497db
...
@@ -44,7 +44,7 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
...
@@ -44,7 +44,7 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
FusedAllReduceOpHandle
::
FusedAllReduceOpHandle
(
FusedAllReduceOpHandle
::
FusedAllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
size_t
num_of_all_reduce
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
size_t
num_of_all_reduce
,
const
platform
::
MultiNCCLContextMap
*
ctxs
)
const
platform
::
NCCLCommunicator
*
ctxs
)
:
NCCLOpHandleBase
(
node
,
places
,
ctxs
),
:
NCCLOpHandleBase
(
node
,
places
,
ctxs
),
local_scopes_
(
local_scopes
),
local_scopes_
(
local_scopes
),
num_of_all_reduce_
(
num_of_all_reduce
)
{
num_of_all_reduce_
(
num_of_all_reduce
)
{
...
...
paddle/fluid/framework/details/fused_all_reduce_op_handle.h
浏览文件 @
751497db
...
@@ -35,7 +35,7 @@ struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
...
@@ -35,7 +35,7 @@ struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
size_t
num_of_all_reduce
,
const
size_t
num_of_all_reduce
,
const
platform
::
MultiNCCLContextMap
*
ctxs
);
const
platform
::
NCCLCommunicator
*
ctxs
);
#else
#else
struct
FusedAllReduceOpHandle
:
public
OpHandleBase
{
struct
FusedAllReduceOpHandle
:
public
OpHandleBase
{
FusedAllReduceOpHandle
(
ir
::
Node
*
node
,
FusedAllReduceOpHandle
(
ir
::
Node
*
node
,
...
...
paddle/fluid/framework/details/nccl_op_handle.h
浏览文件 @
751497db
...
@@ -33,7 +33,7 @@ namespace details {
...
@@ -33,7 +33,7 @@ namespace details {
class
NCCLOpHandleBase
:
public
OpHandleBase
{
class
NCCLOpHandleBase
:
public
OpHandleBase
{
public:
public:
NCCLOpHandleBase
(
ir
::
Node
*
node
,
const
std
::
vector
<
platform
::
Place
>&
places
,
NCCLOpHandleBase
(
ir
::
Node
*
node
,
const
std
::
vector
<
platform
::
Place
>&
places
,
const
platform
::
MultiNCCLContextMap
*
nccl_ctxs
)
const
platform
::
NCCLCommunicator
*
nccl_ctxs
)
:
OpHandleBase
(
node
),
places_
(
places
),
nccl_ctxs_
(
nccl_ctxs
)
{
:
OpHandleBase
(
node
),
places_
(
places
),
nccl_ctxs_
(
nccl_ctxs
)
{
if
(
nccl_ctxs
==
nullptr
)
{
if
(
nccl_ctxs
==
nullptr
)
{
return
;
return
;
...
@@ -215,7 +215,7 @@ class NCCLOpHandleBase : public OpHandleBase {
...
@@ -215,7 +215,7 @@ class NCCLOpHandleBase : public OpHandleBase {
protected:
protected:
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
platform
::
Place
>
places_
;
const
platform
::
MultiNCCLContextMap
*
nccl_ctxs_
{
nullptr
};
const
platform
::
NCCLCommunicator
*
nccl_ctxs_
{
nullptr
};
// When multi trainer call collective function, they need run the same order.
// When multi trainer call collective function, they need run the same order.
// Or the program will hang.So we use allreduce_deps_pass to set this
// Or the program will hang.So we use allreduce_deps_pass to set this
// run_order_.
// run_order_.
...
...
paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc
浏览文件 @
751497db
...
@@ -30,7 +30,7 @@ namespace details {
...
@@ -30,7 +30,7 @@ namespace details {
SparseAllReduceOpHandle
::
SparseAllReduceOpHandle
(
SparseAllReduceOpHandle
::
SparseAllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
MultiNCCLContextMap
*
ctxs
,
bool
is_encoded
,
int
nranks
)
const
platform
::
NCCLCommunicator
*
ctxs
,
bool
is_encoded
,
int
nranks
)
:
AllReduceOpHandle
(
node
,
local_scopes
,
places
,
ctxs
),
:
AllReduceOpHandle
(
node
,
local_scopes
,
places
,
ctxs
),
is_encoded_
(
is_encoded
),
is_encoded_
(
is_encoded
),
nranks_
(
nranks
)
{
nranks_
(
nranks
)
{
...
...
paddle/fluid/framework/details/sparse_all_reduce_op_handle.h
浏览文件 @
751497db
...
@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
...
@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
SparseAllReduceOpHandle
(
ir
::
Node
*
node
,
SparseAllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
MultiNCCLContextMap
*
ctxs
,
const
platform
::
NCCLCommunicator
*
ctxs
,
bool
is_encoded
=
false
,
int
nranks
=
-
1
);
bool
is_encoded
=
false
,
int
nranks
=
-
1
);
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
浏览文件 @
751497db
...
@@ -35,7 +35,7 @@ class FuseAllReduceOpPass : public ir::Pass {
...
@@ -35,7 +35,7 @@ class FuseAllReduceOpPass : public ir::Pass {
auto
&
local_scopes
=
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
);
auto
&
local_scopes
=
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto
*
multi_nccl_ctxs
=
auto
*
multi_nccl_ctxs
=
&
Get
<
platform
::
MultiNCCLContextMap
>
(
details
::
kNCCLCtxs
);
&
Get
<
platform
::
NCCLCommunicator
>
(
details
::
kNCCLCtxs
);
#endif
#endif
std
::
unordered_set
<
std
::
string
>
grads
;
std
::
unordered_set
<
std
::
string
>
grads
;
...
@@ -103,14 +103,14 @@ class FuseAllReduceOpPass : public ir::Pass {
...
@@ -103,14 +103,14 @@ class FuseAllReduceOpPass : public ir::Pass {
}
}
}
}
void
InsertFusedAllReduce
(
void
InsertFusedAllReduce
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
plac
es
,
const
std
::
vector
<
Scope
*>
&
local_scop
es
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
size_t
num_of_all_reduce
,
const
size_t
num_of_all_reduce
,
const
std
::
vector
<
ir
::
Node
*>
&
all_reduce_ops
,
const
std
::
vector
<
ir
::
Node
*>
&
all_reduce_ops
,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const
platform
::
MultiNCCLContextMap
*
multi_nccl_ctxs
,
const
platform
::
NCCLCommunicator
*
multi_nccl_ctxs
,
#endif
#endif
ir
::
Graph
*
result
)
const
{
ir
::
Graph
*
result
)
const
{
std
::
vector
<
details
::
VarHandleBase
*>
inputs
;
std
::
vector
<
details
::
VarHandleBase
*>
inputs
;
std
::
vector
<
details
::
VarHandleBase
*>
outputs
;
std
::
vector
<
details
::
VarHandleBase
*>
outputs
;
for
(
auto
&
op
:
all_reduce_ops
)
{
for
(
auto
&
op
:
all_reduce_ops
)
{
...
@@ -151,7 +151,7 @@ class FuseAllReduceOpPass : public ir::Pass {
...
@@ -151,7 +151,7 @@ class FuseAllReduceOpPass : public ir::Pass {
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const
platform
::
MultiNCCLContextMap
*
multi_nccl_ctxs
,
const
platform
::
NCCLCommunicator
*
multi_nccl_ctxs
,
#endif
#endif
ir
::
Graph
*
result
)
const
{
ir
::
Graph
*
result
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
浏览文件 @
751497db
...
@@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
...
@@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
);
strategy_
=
Get
<
const
details
::
BuildStrategy
>
(
kStrategy
);
strategy_
=
Get
<
const
details
::
BuildStrategy
>
(
kStrategy
);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
multi_nccl_ctxs_
=
&
Get
<
platform
::
MultiNCCLContextMap
>
(
details
::
kNCCLCtxs
);
multi_nccl_ctxs_
=
&
Get
<
platform
::
NCCLCommunicator
>
(
details
::
kNCCLCtxs
);
nccl_ctxs_
=
nullptr
;
nccl_ctxs_
=
nullptr
;
if
(
multi_nccl_ctxs_
)
{
if
(
multi_nccl_ctxs_
)
{
nccl_ctxs_
=
multi_nccl_ctxs_
->
DefaultFlatCtx
();
nccl_ctxs_
=
multi_nccl_ctxs_
->
DefaultFlatCtx
();
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
浏览文件 @
751497db
...
@@ -97,7 +97,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
...
@@ -97,7 +97,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
{
nullptr
};
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
{
nullptr
};
mutable
platform
::
MultiNCCLContextMap
*
multi_nccl_ctxs_
{
nullptr
};
mutable
platform
::
NCCLCommunicator
*
multi_nccl_ctxs_
{
nullptr
};
#endif
#endif
mutable
std
::
string
loss_var_name_
;
mutable
std
::
string
loss_var_name_
;
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
751497db
...
@@ -111,8 +111,8 @@ class ParallelExecutorPrivate {
...
@@ -111,8 +111,8 @@ class ParallelExecutorPrivate {
std
::
vector
<
ncclUniqueId
*>
flat_nccl_ids
;
std
::
vector
<
ncclUniqueId
*>
flat_nccl_ids
;
if
(
nranks_
==
1
)
{
if
(
nranks_
==
1
)
{
// FIXME(gongwb): need not to create ncclid when nranks==1
// FIXME(gongwb): need not to create ncclid when nranks==1
nccl_ctxs_
.
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
nccl_ctxs_
->
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
bst
.
trainer_id_
);
bst
.
trainer_id_
);
return
;
return
;
}
}
...
@@ -132,16 +132,16 @@ class ParallelExecutorPrivate {
...
@@ -132,16 +132,16 @@ class ParallelExecutorPrivate {
flat_nccl_ids
.
push_back
(
nccl_id
);
flat_nccl_ids
.
push_back
(
nccl_id
);
nccl_ctxs_
.
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
nccl_ctxs_
->
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
bst
.
trainer_id_
);
bst
.
trainer_id_
);
VLOG
(
1
)
<<
"init bst nccl context complete!"
;
VLOG
(
1
)
<<
"init bst nccl context complete!"
;
return
;
return
;
}
}
// num_trainers ==1 && places > 1
// num_trainers ==1 && places > 1
if
(
bst
.
num_trainers_
==
1
)
{
if
(
bst
.
num_trainers_
==
1
)
{
nccl_ctxs_
.
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
nccl_ctxs_
->
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
bst
.
trainer_id_
);
bst
.
trainer_id_
);
return
;
return
;
}
}
...
@@ -153,8 +153,8 @@ class ParallelExecutorPrivate {
...
@@ -153,8 +153,8 @@ class ParallelExecutorPrivate {
flat_nccl_ids
.
push_back
(
nccl_id
);
flat_nccl_ids
.
push_back
(
nccl_id
);
}
}
nccl_ctxs_
.
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
nccl_ctxs_
->
InitFlatCtxs
(
places_
,
flat_nccl_ids
,
bst
.
num_trainers_
,
bst
.
trainer_id_
);
bst
.
trainer_id_
);
if
(
bst
.
use_hierarchical_allreduce_
)
{
if
(
bst
.
use_hierarchical_allreduce_
)
{
std
::
vector
<
ncclUniqueId
*>
inter_nccl_ids
;
std
::
vector
<
ncclUniqueId
*>
inter_nccl_ids
;
...
@@ -175,12 +175,30 @@ class ParallelExecutorPrivate {
...
@@ -175,12 +175,30 @@ class ParallelExecutorPrivate {
exter_nccl_ids
.
push_back
(
nccl_id
);
exter_nccl_ids
.
push_back
(
nccl_id
);
}
}
nccl_ctxs_
.
InitHierarchicalCtxs
(
places_
,
inter_nccl_ids
,
exter_nccl_ids
,
nccl_ctxs_
->
InitHierarchicalCtxs
(
bst
.
num_trainers_
,
bst
.
trainer_id
_
,
places_
,
inter_nccl_ids
,
exter_nccl_ids
,
bst
.
num_trainers
_
,
bst
.
hierarchical_allreduce_inter_nranks_
,
bst
.
trainer_id_
,
bst
.
hierarchical_allreduce_inter_nranks_
,
bst
.
hierarchical_allreduce_exter_nranks_
);
bst
.
hierarchical_allreduce_exter_nranks_
);
}
}
}
}
void
InitOrGetNCCLCommunicator
(
framework
::
Scope
*
scope
,
const
BuildStrategy
&
bst
)
{
const
std
::
string
var_name
=
"NCCLCommunicator"
;
auto
var
=
scope
->
FindVar
(
var_name
);
if
(
var
!=
nullptr
)
{
PADDLE_ENFORCE
(
var
->
IsInitialized
(),
"if %s exists, it must be initialized"
,
var_name
);
VLOG
(
1
)
<<
"find "
<<
var_name
<<
" in scope, so use it and does not recreate!"
;
nccl_ctxs_
=
var
->
GetMutable
<
platform
::
NCCLCommunicator
>
();
return
;
}
VLOG
(
1
)
<<
"not find "
<<
var_name
<<
" in scope, so recreate it!"
;
nccl_ctxs_
=
scope
->
Var
(
var_name
)
->
GetMutable
<
platform
::
NCCLCommunicator
>
();
InitNCCLCtxs
(
scope
,
bst
);
}
#endif
#endif
BuildStrategy
build_strategy_
;
BuildStrategy
build_strategy_
;
...
@@ -190,7 +208,7 @@ class ParallelExecutorPrivate {
...
@@ -190,7 +208,7 @@ class ParallelExecutorPrivate {
std
::
unique_ptr
<
details
::
SSAGraphExecutor
>
executor_
;
std
::
unique_ptr
<
details
::
SSAGraphExecutor
>
executor_
;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
MultiNCCLContextMap
nccl_ctxs_
;
platform
::
NCCLCommunicator
*
nccl_ctxs_
{
nullptr
}
;
#endif
#endif
bool
own_local_scope_
;
bool
own_local_scope_
;
bool
use_cuda_
;
bool
use_cuda_
;
...
@@ -281,27 +299,6 @@ bool ParallelExecutor::NeedCreateLocalExeScope() {
...
@@ -281,27 +299,6 @@ bool ParallelExecutor::NeedCreateLocalExeScope() {
return
executor
&&
executor
->
NeedCreateLocalExeScope
();
return
executor
&&
executor
->
NeedCreateLocalExeScope
();
}
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
/*
* When nccl inits nccl comm using ncclCommInitAll, it meets error when
* allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
* create a new nccl comm for sync_batch_norm_op. And these codes should be
* polished with a unified nccl management.
*/
platform
::
NCCLContextMap
*
ParallelExecutor
::
GetNCCLContextForSyncbatchNomrOp
(
framework
::
Scope
*
scope
)
{
auto
*
nccl_id_var
=
scope
->
FindVar
(
NCCL_ID_VARNAME
);
if
(
nccl_id_var
!=
nullptr
)
{
return
member_
->
nccl_ctxs_
.
DefaultFlatCtx
();
}
if
(
dev_nccl_ctxs_
.
get
()
==
nullptr
)
{
dev_nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
));
}
return
dev_nccl_ctxs_
.
get
();
}
#endif
ParallelExecutor
::
ParallelExecutor
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
ParallelExecutor
::
ParallelExecutor
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
std
::
string
>
&
bcast_vars
,
const
std
::
vector
<
std
::
string
>
&
bcast_vars
,
const
std
::
string
&
loss_var_name
,
const
std
::
string
&
loss_var_name
,
...
@@ -369,7 +366,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -369,7 +366,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
if
(
member_
->
use_cuda_
)
{
if
(
member_
->
use_cuda_
)
{
// Bcast Parameters to all GPUs
// Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
member_
->
Init
NCCLCtxs
(
scope
,
build_strategy
);
member_
->
Init
OrGetNCCLCommunicator
(
scope
,
build_strategy
);
// Initialize device context's nccl comm, will be used by normal
// Initialize device context's nccl comm, will be used by normal
// Operators like sync_batch_norm, and collective ops.
// Operators like sync_batch_norm, and collective ops.
...
@@ -378,7 +375,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -378,7 +375,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
// NOTE: NCCL group-calls and non-group-calls can not use the same
// NOTE: NCCL group-calls and non-group-calls can not use the same
// NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
// NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
// same communicators.
// same communicators.
auto
*
nccl_ctxs
=
GetNCCLContextForSyncbatchNomrOp
(
scope
);
auto
*
nccl_ctxs
=
member_
->
nccl_ctxs_
->
GetSyncBatchNormCtx
(
scope
,
member_
->
places_
);
for
(
size_t
dev_id
=
0
;
dev_id
<
member_
->
places_
.
size
();
++
dev_id
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
member_
->
places_
.
size
();
++
dev_id
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
::
Instance
();
...
@@ -415,18 +413,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -415,18 +413,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
VLOG
(
3
)
<<
"use local async mode"
;
VLOG
(
3
)
<<
"use local async mode"
;
graph
=
build_strategy
.
Apply
(
graph
,
{
member_
->
places_
[
0
]},
loss_var_name
,
graph
=
build_strategy
.
Apply
(
graph
,
{
member_
->
places_
[
0
]},
loss_var_name
,
{
member_
->
local_scopes_
[
0
]},
1
,
{
member_
->
local_scopes_
[
0
]},
1
,
member_
->
use_cuda_
,
&
member_
->
nccl_ctxs_
);
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
);
for
(
size_t
i
=
1
;
i
<
member_
->
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
member_
->
places_
.
size
();
++
i
)
{
graphs
[
i
]
=
graphs
[
i
]
=
build_strategy
.
Apply
(
graphs
[
i
],
{
member_
->
places_
[
i
]},
loss_var_name
,
build_strategy
.
Apply
(
graphs
[
i
],
{
member_
->
places_
[
i
]},
loss_var_name
,
{
member_
->
local_scopes_
[
i
]},
1
,
{
member_
->
local_scopes_
[
i
]},
1
,
member_
->
use_cuda_
,
&
member_
->
nccl_ctxs_
);
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
);
async_graphs
[
i
]
=
graphs
[
i
];
async_graphs
[
i
]
=
graphs
[
i
];
}
}
}
else
{
}
else
{
graph
=
build_strategy
.
Apply
(
graph
,
member_
->
places_
,
loss_var_name
,
graph
=
build_strategy
.
Apply
(
graph
,
member_
->
places_
,
loss_var_name
,
member_
->
local_scopes_
,
member_
->
nranks_
,
member_
->
local_scopes_
,
member_
->
nranks_
,
member_
->
use_cuda_
,
&
member_
->
nccl_ctxs_
);
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
);
}
}
#else
#else
if
(
build_strategy
.
async_mode_
)
{
if
(
build_strategy
.
async_mode_
)
{
...
@@ -559,7 +557,7 @@ void ParallelExecutor::BCastParamsToDevices(
...
@@ -559,7 +557,7 @@ void ParallelExecutor::BCastParamsToDevices(
PADDLE_ENFORCE_EQ
(
member_
->
places_
.
size
(),
buffers
.
size
(),
PADDLE_ENFORCE_EQ
(
member_
->
places_
.
size
(),
buffers
.
size
(),
"variables' buffer size to bcast NOT equal to places"
);
"variables' buffer size to bcast NOT equal to places"
);
{
{
auto
*
nccl_ctxs
=
member_
->
nccl_ctxs_
.
DefaultFlatCtx
();
auto
*
nccl_ctxs
=
member_
->
nccl_ctxs_
->
DefaultFlatCtx
();
platform
::
NCCLGroupGuard
guard
;
platform
::
NCCLGroupGuard
guard
;
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
auto
&
nccl_ctx
=
nccl_ctxs
->
at
(
member_
->
places_
[
i
]);
auto
&
nccl_ctx
=
nccl_ctxs
->
at
(
member_
->
places_
[
i
]);
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
751497db
...
@@ -87,13 +87,6 @@ class ParallelExecutor {
...
@@ -87,13 +87,6 @@ class ParallelExecutor {
ParallelExecutorPrivate
*
member_
;
ParallelExecutorPrivate
*
member_
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
async_graphs_
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
async_graphs_
;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// used for compatible with syncbatch norm op
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
dev_nccl_ctxs_
;
platform
::
NCCLContextMap
*
GetNCCLContextForSyncbatchNomrOp
(
framework
::
Scope
*
scope
);
#endif
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/var_type_traits.cc
浏览文件 @
751497db
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include <unordered_map>
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -22,6 +23,7 @@
...
@@ -22,6 +23,7 @@
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#endif
#include <cudnn.h>
#include <cudnn.h>
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
...
...
paddle/fluid/framework/var_type_traits.h
浏览文件 @
751497db
...
@@ -36,6 +36,7 @@ namespace platform {
...
@@ -36,6 +36,7 @@ namespace platform {
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#ifndef _WIN32
class
Communicator
;
class
Communicator
;
class
NCCLCommunicator
;
#endif
#endif
#endif
#endif
}
// namespace platform
}
// namespace platform
...
@@ -140,7 +141,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
...
@@ -140,7 +141,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
std
::
map
<
size_t
,
Tensor
>
,
operators
::
reader
::
LoDTensorBlockingQueueHolder
,
std
::
map
<
size_t
,
Tensor
>
,
operators
::
reader
::
LoDTensorBlockingQueueHolder
,
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#ifndef _WIN32
ncclUniqueId
,
platform
::
Communicator
,
ncclUniqueId
,
platform
::
Communicator
,
platform
::
NCCLCommunicator
,
#endif
#endif
operators
::
CudnnRNNCache
,
operators
::
CudnnRNNCache
,
#endif
#endif
...
...
paddle/fluid/framework/var_type_traits_test.cc
浏览文件 @
751497db
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
...
...
paddle/fluid/platform/nccl_helper.h
浏览文件 @
751497db
...
@@ -176,10 +176,10 @@ inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
...
@@ -176,10 +176,10 @@ inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
static_cast
<
int
>
(
pos
));
static_cast
<
int
>
(
pos
));
}
}
class
MultiNCCLContextMap
{
class
NCCLCommunicator
{
public:
public:
MultiNCCLContextMap
()
{}
NCCLCommunicator
()
{}
virtual
~
MultiNCCLContextMap
()
{}
virtual
~
NCCLCommunicator
()
{}
NCCLContextMap
*
DefaultFlatCtx
()
const
{
NCCLContextMap
*
DefaultFlatCtx
()
const
{
if
(
flat_ctxs_
.
size
()
==
0
)
{
if
(
flat_ctxs_
.
size
()
==
0
)
{
...
@@ -206,6 +206,25 @@ class MultiNCCLContextMap {
...
@@ -206,6 +206,25 @@ class MultiNCCLContextMap {
return
GetHierarchicalInterCtx
(
run_order
);
return
GetHierarchicalInterCtx
(
run_order
);
}
}
/*
*When nccl inits nccl comm using ncclCommInitAll, it meets error when
*allreduce ophandle and sync_batch_norm_op use ncclallreduce parallelly. So
*create a new nccl comm for sync_batch_norm_op. And these codes should be
*polished with a unified nccl management.
*/
NCCLContextMap
*
GetSyncBatchNormCtx
(
framework
::
Scope
*
scope
,
const
std
::
vector
<
platform
::
Place
>
&
places
)
{
auto
*
nccl_id_var
=
scope
->
FindVar
(
NCCL_ID_VARNAME
);
if
(
nccl_id_var
!=
nullptr
)
{
return
DefaultFlatCtx
();
}
if
(
sync_batch_norm_ctx_
.
get
()
==
nullptr
)
{
sync_batch_norm_ctx_
.
reset
(
new
NCCLContextMap
(
places
));
}
return
sync_batch_norm_ctx_
.
get
();
}
void
InitFlatCtxs
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
void
InitFlatCtxs
(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
ncclUniqueId
*>
&
nccl_ids
,
const
std
::
vector
<
ncclUniqueId
*>
&
nccl_ids
,
size_t
trainers_num
,
size_t
trainer_id
)
{
size_t
trainers_num
,
size_t
trainer_id
)
{
...
@@ -290,6 +309,9 @@ class MultiNCCLContextMap {
...
@@ -290,6 +309,9 @@ class MultiNCCLContextMap {
// And h_exter_ctxs_ can support multi comm too.
// And h_exter_ctxs_ can support multi comm too.
std
::
vector
<
std
::
unique_ptr
<
NCCLContextMap
>>
h_inter_ctxs_
;
std
::
vector
<
std
::
unique_ptr
<
NCCLContextMap
>>
h_inter_ctxs_
;
std
::
vector
<
std
::
unique_ptr
<
NCCLContextMap
>>
h_exter_ctxs_
;
std
::
vector
<
std
::
unique_ptr
<
NCCLContextMap
>>
h_exter_ctxs_
;
// just used for sync_batch_norm op.
std
::
unique_ptr
<
NCCLContextMap
>
sync_batch_norm_ctx_
;
};
};
}
// namespace platform
}
// namespace platform
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
751497db
...
@@ -167,6 +167,15 @@ class TestDistRunnerBase(object):
...
@@ -167,6 +167,15 @@ class TestDistRunnerBase(object):
build_strategy
=
build_stra
,
build_strategy
=
build_stra
,
exec_strategy
=
exec_strategy
)
exec_strategy
=
exec_strategy
)
if
args
.
use_cuda
and
args
.
update_method
==
"nccl2"
:
# it just for test share_vars_from feature.
test_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_stra
,
main_program
=
test_program
,
share_vars_from
=
binary
.
_executor
)
feed_var_list
=
[
feed_var_list
=
[
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
var
for
var
in
trainer_prog
.
global_block
().
vars
.
values
()
if
var
.
is_data
if
var
.
is_data
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录