PaddlePaddle/Paddle · commit 65bbf950 (unverified)
Authored May 27, 2019 by gongweibao; committed via GitHub on May 27, 2019.
Add multi-ncclcomm and 2D ncclallreduce support. (#17263)
Parent: b1bd483a

Showing 27 changed files with 862 additions and 166 deletions (+862 -166)
paddle/fluid/framework/details/all_reduce_op_handle.cc  +8 -23
paddle/fluid/framework/details/all_reduce_op_handle.h  +12 -5
paddle/fluid/framework/details/build_strategy.cc  +21 -12
paddle/fluid/framework/details/build_strategy.h  +12 -1
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc  +1 -0
paddle/fluid/framework/details/fused_all_reduce_op_handle.cc  +5 -17
paddle/fluid/framework/details/fused_all_reduce_op_handle.h  +8 -5
paddle/fluid/framework/details/multi_devices_helper.h  +1 -0
paddle/fluid/framework/details/nccl_op_handle.h  +234 -0
paddle/fluid/framework/details/op_handle_base.cc  +5 -0
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc  +1 -0
paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc  +3 -2
paddle/fluid/framework/details/sparse_all_reduce_op_handle.h  +1 -1
paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc  +27 -13
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc  +13 -12
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc  +8 -4
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h  +2 -1
paddle/fluid/framework/parallel_executor.cc  +97 -35
paddle/fluid/framework/parallel_executor.h  +0 -3
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc  +166 -26
paddle/fluid/platform/nccl_helper.h  +130 -2
paddle/fluid/pybind/pybind.cc  +28 -0
python/paddle/fluid/compiler.py  +8 -0
python/paddle/fluid/framework.py  +4 -0
python/paddle/fluid/tests/unittests/test_dist_base.py  +13 -1
python/paddle/fluid/tests/unittests/test_dist_mnist.py  +14 -0
python/paddle/fluid/transpiler/distribute_transpiler.py  +40 -3
paddle/fluid/framework/details/all_reduce_op_handle.cc

@@ -35,16 +35,9 @@ namespace details {
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
                                      const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places,
-                                     const platform::NCCLContextMap *ctxs)
-    : OpHandleBase(node),
-      local_scopes_(local_scopes),
-      places_(places),
-      nccl_ctxs_(ctxs) {
-  if (nccl_ctxs_) {
-    for (auto &p : places_) {
-      this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
-    }
-  }
+                                     const platform::MultiNCCLContextMap *ctxs)
+    : NCCLOpHandleBase(node, places, ctxs), local_scopes_(local_scopes) {
+  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 }
 #else
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,

@@ -71,7 +64,9 @@ void AllReduceOpHandle::RunAllReduceFuncs(
   if (FLAGS_sync_nccl_allreduce) {
     for (auto &p : places_) {
       int dev_id = boost::get<platform::CUDAPlace>(p).device;
-      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+      auto *nccl_ctxs =
+          nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
+      auto &nccl_ctx = nccl_ctxs->at(dev_id);
       auto stream = nccl_ctx.stream();
       cudaError_t e_sync = cudaStreamSynchronize(stream);
       if (e_sync != 0) {

@@ -134,19 +129,9 @@ void AllReduceOpHandle::RunImpl() {
       numel = static_cast<size_t>(lod_tensor.numel());
     }

-    int dev_id = boost::get<platform::CUDAPlace>(p).device;
-    auto &nccl_ctx = nccl_ctxs_->at(dev_id);
-    auto stream = nccl_ctx.stream();
-    auto comm = nccl_ctx.comm_;
-
-    VLOG(10) << "before all reduce buffer:" << buffer << ", numel:" << numel
-             << ", dev_id:" << dev_id << ", dtype:" << dtype
-             << ", place:" << p;
-
     all_reduce_calls.emplace_back([=] {
-      PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-          buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-          comm, stream));
+      NCCLAllReduce(p, buffer, buffer, numel,
+                    static_cast<ncclDataType_t>(dtype), ncclSum);
     });
   }

   RunAllReduceFuncs(all_reduce_calls);
paddle/fluid/framework/details/all_reduce_op_handle.h

@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#include "paddle/fluid/framework/details/nccl_op_handle.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif

@@ -28,13 +29,15 @@ namespace paddle {
 namespace framework {
 namespace details {

-class AllReduceOpHandle : public OpHandleBase {
- public:
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+class AllReduceOpHandle : public NCCLOpHandleBase {
+ public:
   AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places,
-                    const platform::NCCLContextMap *ctxs);
+                    const platform::MultiNCCLContextMap *ctxs);
 #else
+class AllReduceOpHandle : public OpHandleBase {
+ public:
   AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);
 #endif

@@ -46,13 +49,17 @@ class AllReduceOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;

   std::vector<Scope *> local_scopes_;
+
+#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
+  // NCCLOpHandleBase already have these attributes.
+  // Will polish it by class inheritance framework.
   std::vector<platform::Place> places_;
+#endif

 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   void RunAllReduceFuncs(
       const std::vector<std::function<void()>> &all_reduce_calls);
-  const platform::NCCLContextMap *nccl_ctxs_;
 #endif
 };
paddle/fluid/framework/details/build_strategy.cc

@@ -256,14 +256,12 @@ bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
   return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
 }

 ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                 const std::vector<platform::Place> &places,
                                 const std::string &loss_var_name,
                                 const std::vector<Scope *> &local_scopes,
                                 const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                                 const bool use_cuda,
-                                platform::NCCLContextMap *nccl_ctxs) const {
+                                platform::MultiNCCLContextMap *nccl_ctxs) const {
 #else
                                 const bool use_cuda) const {
 #endif

@@ -285,9 +283,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       pass->Set<size_t>(ir::kNRanks, new size_t(nranks));

 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+      platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
       pass->Erase(kNCCLCtxs);
-      pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
+      pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx);
 #endif
     } else if (pass->Type() == "alloc_continuous_space_for_grad_pass" ||
                pass->Type() == "fuse_adam_op_pass" ||

@@ -301,9 +299,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                                &local_scopes);
       if (pass->Type() == "fuse_all_reduce_op_pass") {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-        platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+        platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
         pass->Erase(kNCCLCtxs);
-        pass->SetNotOwned<platform::NCCLContextMap>(kNCCLCtxs, nctx);
+        pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx);
+        pass->Erase(kUseHierarchicalAllReduce);
+        pass->Set<bool>(kUseHierarchicalAllReduce,
+                        new bool(use_hierarchical_allreduce_));
 #endif
       }
     } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") {

@@ -316,6 +317,14 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
       LOG(INFO) << "set enable_sequential_execution:"
                 << enable_sequential_execution_;
     } else if (pass->Type() == "all_reduce_deps_pass") {
+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+      platform::MultiNCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
+      pass->Erase(kNCCLCtxs);
+      pass->SetNotOwned<platform::MultiNCCLContextMap>(kNCCLCtxs, nctx);
+      pass->Erase(kUseHierarchicalAllReduce);
+      pass->Set<bool>(kUseHierarchicalAllReduce,
+                      new bool(use_hierarchical_allreduce_));
+#endif
       LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
                 << ", num_trainers:" << num_trainers_;
     } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
paddle/fluid/framework/details/build_strategy.h

@@ -111,6 +111,17 @@ struct BuildStrategy {
   bool cache_runtime_context_{false};
   std::unordered_set<std::string> mkldnn_enabled_op_types_;

+  size_t nccl_comm_num_{1};
+  // The picture is here:
+  // https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
+  bool use_hierarchical_allreduce_{false};
+  // Nccl ranks in a node when use hierarchical allreduce, it's setted to gpu
+  // cards' number in most cases.
+  size_t hierarchical_allreduce_inter_nranks_{0};
+  // Nccl ranks bewteen nodes when use hierarchical allreduce, it's setted to
+  // nodes number.
+  size_t hierarchical_allreduce_exter_nranks_{0};
+
   // NOTE:
   // Before you add new options, think if it's a general strategy that works
   // with other strategy. If not, the strategy should be created through

@@ -136,7 +147,7 @@ struct BuildStrategy {
                   const size_t &nranks,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
                   const bool use_cuda,
-                  platform::NCCLContextMap *nccl_ctxs) const;
+                  platform::MultiNCCLContextMap *nccl_ctxs) const;
 #else
                   const bool use_cuda) const;
 #endif
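The three hierarchical knobs describe a nodes-by-cards grid. A minimal sketch (not from the commit) of how they might be filled for a hypothetical 4-node, 8-GPU-per-node job running one trainer process per GPU; the values are illustrative, not defaults:

#include "paddle/fluid/framework/details/build_strategy.h"

// Illustrative only: wire the new BuildStrategy knobs for an assumed
// 4-node x 8-GPU job with one trainer process per GPU card.
paddle::framework::details::BuildStrategy MakeStrategy() {
  paddle::framework::details::BuildStrategy bst;
  bst.nccl_comm_num_ = 2;                        // two parallel flat NCCL rings
  bst.use_hierarchical_allreduce_ = true;        // enable the 2D scheme
  bst.hierarchical_allreduce_inter_nranks_ = 8;  // GPU cards per node
  bst.hierarchical_allreduce_exter_nranks_ = 4;  // number of nodes
  return bst;
}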
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

@@ -49,6 +49,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
 FeedFetchList FastThreadedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
+  VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
   std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
       op_deps = atomic_op_deps_.get();
   PrepareAtomicOpDeps();
paddle/fluid/framework/details/fused_all_reduce_op_handle.cc

@@ -44,17 +44,10 @@ typedef std::vector<std::vector<std::pair<std::string, const LoDTensor *>>>
 FusedAllReduceOpHandle::FusedAllReduceOpHandle(
     ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
-    const platform::NCCLContextMap *ctxs)
-    : OpHandleBase(node),
+    const platform::MultiNCCLContextMap *ctxs)
+    : NCCLOpHandleBase(node, places, ctxs),
       local_scopes_(local_scopes),
-      places_(places),
-      num_of_all_reduce_(num_of_all_reduce),
-      nccl_ctxs_(ctxs) {
-  if (nccl_ctxs_) {
-    for (auto &p : places_) {
-      this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
-    }
-  }
+      num_of_all_reduce_(num_of_all_reduce) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 }
 #else

@@ -167,14 +160,9 @@ void FusedAllReduceOpHandle::RunImpl() {
       auto &p = places_[i];
       void *buffer = const_cast<void *>(lod_tensor_data.at(i));
-      int dev_id = boost::get<platform::CUDAPlace>(p).device;
-      auto &nccl_ctx = nccl_ctxs_->at(dev_id);
-      auto stream = nccl_ctx.stream();
-      auto comm = nccl_ctx.comm_;
-
       all_reduce_calls.emplace_back([=] {
-        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-            buffer, buffer, numel, static_cast<ncclDataType_t>(nccl_dtype),
-            ncclSum, comm, stream));
+        NCCLAllReduce(p, buffer, buffer, numel,
+                      static_cast<ncclDataType_t>(nccl_dtype), ncclSum);
       });
     }
paddle/fluid/framework/details/fused_all_reduce_op_handle.h

@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#include "paddle/fluid/framework/details/nccl_op_handle.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif

@@ -28,14 +29,15 @@ namespace paddle {
 namespace framework {
 namespace details {

-struct FusedAllReduceOpHandle : public OpHandleBase {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+struct FusedAllReduceOpHandle : public NCCLOpHandleBase {
   FusedAllReduceOpHandle(ir::Node *node,
                          const std::vector<Scope *> &local_scopes,
                          const std::vector<platform::Place> &places,
                          const size_t num_of_all_reduce,
-                         const platform::NCCLContextMap *ctxs);
+                         const platform::MultiNCCLContextMap *ctxs);
 #else
+struct FusedAllReduceOpHandle : public OpHandleBase {
   FusedAllReduceOpHandle(ir::Node *node,
                          const std::vector<Scope *> &local_scopes,
                          const std::vector<platform::Place> &places,

@@ -52,11 +54,12 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
  private:
   std::vector<Scope *> local_scopes_;
+#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
+  // NCCLOpHandleBase already have these attributes.
+  // Will polish it by class inheritance framework.
   std::vector<platform::Place> places_;
-  size_t num_of_all_reduce_;
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  const platform::NCCLContextMap *nccl_ctxs_;
 #endif
+  size_t num_of_all_reduce_;

   // Check the dtype of the input
   void GetDTypeAndNumel(
paddle/fluid/framework/details/multi_devices_helper.h

@@ -45,6 +45,7 @@ constexpr char kGraphVars[] = "vars";
 constexpr char kPlaces[] = "places";
 constexpr char kLocalScopes[] = "local_scopes";
 constexpr char kNCCLCtxs[] = "nccl_ctxs";
+constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";

 // aux variables to represent dependency. Useful to resolve data hazard.
 typedef std::unordered_set<VarHandleBase *> GraphDepVars;
paddle/fluid/framework/details/nccl_op_handle.h (new file, 0 → 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/nccl_helper.h"

DECLARE_bool(sync_nccl_allreduce);

namespace paddle {
namespace framework {
namespace details {

class NCCLOpHandleBase : public OpHandleBase {
 public:
  NCCLOpHandleBase(ir::Node* node, const std::vector<platform::Place>& places,
                   const platform::MultiNCCLContextMap* nccl_ctxs)
      : OpHandleBase(node), places_(places), nccl_ctxs_(nccl_ctxs) {
    if (nccl_ctxs == nullptr) {
      return;
    }
    // init device context
    auto default_nccl_ctxs = nccl_ctxs_->DefaultFlatCtx();
    for (auto& p : places_) {
      this->SetDeviceContext(p, default_nccl_ctxs->DevCtx(p));
    }
  }

  virtual ~NCCLOpHandleBase() {
    for (auto& ev : inter_events_) {
      PADDLE_ENFORCE(cudaEventDestroy(ev.second));
    }
    for (auto& ev : exter_events_) {
      PADDLE_ENFORCE(cudaEventDestroy(ev.second));
    }
  }

  void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
    PADDLE_ENFORCE(run_order >= 0, "run_order must >= 0");
    run_order_ = run_order;
    use_hierarchical_allreduce_ = use_hierarchical_allreduce;

    VLOG(10) << "SetRunEnv "
             << " run_order:" << run_order
             << ", use_hierarchical_allreduce:" << use_hierarchical_allreduce;

    if (nccl_ctxs_ == nullptr) {
      return;
    }

    if (!use_hierarchical_allreduce_) {
      auto ctxs = nccl_ctxs_->GetFlatCtx(run_order);
      for (auto& p : places_) {
        this->SetDeviceContext(p, ctxs->DevCtx(p));
      }
      return;
    }

    PADDLE_ENFORCE(places_.size() == 1,
                   "HierarchicalAllReduce run one proc with one card mode.");

    for (auto& p : places_) {
      auto ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order);
      this->SetDeviceContext(p, ctxs->DevCtx(p));
    }

    for (auto& p : dev_ctxes_) {
      int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
      if (inter_events_.find(dev_id) != inter_events_.end()) {
        continue;
      }

      PADDLE_ENFORCE(cudaSetDevice(dev_id));
      PADDLE_ENFORCE(cudaEventCreateWithFlags(&inter_events_[dev_id],
                                              cudaEventDisableTiming));
      PADDLE_ENFORCE(cudaEventCreateWithFlags(&exter_events_[dev_id],
                                              cudaEventDisableTiming));
      VLOG(10) << "Create events on dev_id:" << dev_id
               << ", inter_event:" << &inter_events_[dev_id]
               << ", exter_event:" << &exter_events_[dev_id];
    }
  }

  void FlatNCCLAllReduce(platform::Place place, const void* sendbuff,
                         void* recvbuff, size_t count, ncclDataType_t datatype,
                         ncclRedOp_t op) {
    PADDLE_ENFORCE(run_order_ >= 0, "run_order must > 0");
    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
    int dev_id = boost::get<platform::CUDAPlace>(place).device;
    auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place;

    PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, datatype, op, comm, stream));
  }

  void NCCLAllReduce(platform::Place place, const void* sendbuff,
                     void* recvbuff, size_t count, ncclDataType_t datatype,
                     ncclRedOp_t op) {
    PADDLE_ENFORCE(run_order_ >= 0, "run_order must > 0");
    if (!use_hierarchical_allreduce_) {
      FlatNCCLAllReduce(place, sendbuff, recvbuff, count, datatype, op);
      return;
    }

    HierarchicalAllReduce(place, sendbuff, recvbuff, count, datatype, op);
  }

  void HierarchicalAllReduce(platform::Place place, const void* sendbuff,
                             void* recvbuff, size_t count,
                             ncclDataType_t datatype, ncclRedOp_t op) {
    PADDLE_ENFORCE(run_order_ >= 0, "run_order must > 0");
    InterReduce(place, sendbuff, recvbuff, count, datatype, op);
    // When a trainer is not in exter allreduce ring
    // they need not to call this.
    if (nccl_ctxs_->NeedExterAllReduce()) {
      ExterAllReduce(place, recvbuff, recvbuff, count, datatype, op);
    }
    InterBroadCast(place, recvbuff, count, datatype, op);
  }

 protected:
  void InterReduce(platform::Place place, const void* sendbuff, void* recvbuff,
                   size_t count, ncclDataType_t datatype, ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = boost::get<platform::CUDAPlace>(place).device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce"
             << " run_order:" << run_order_ << ", buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;

    PADDLE_ENFORCE(platform::dynload::ncclReduce(
        sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream));

    cudaEventRecord(inter_events_.at(dev_id), stream);

    if (FLAGS_sync_nccl_allreduce) {
      PADDLE_ENFORCE(cudaStreamSynchronize(stream),
                     "sync HierarchicalAllReduce inter stream error");
    }
  }

  void ExterAllReduce(platform::Place place, const void* sendbuff,
                      void* recvbuff, size_t count, ncclDataType_t datatype,
                      ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalExterCtx(run_order_);
    PADDLE_ENFORCE(nccl_ctxs_, "can't get exter %d nccl_ctxs", run_order_);
    int dev_id = boost::get<platform::CUDAPlace>(place).device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce run_order:" << run_order_
             << "buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place << ", stream:" << stream;

    cudaStreamWaitEvent(stream, inter_events_.at(dev_id), 0);

    PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, datatype, op, comm, stream));

    cudaEventRecord(exter_events_.at(dev_id), stream);

    if (FLAGS_sync_nccl_allreduce) {
      PADDLE_ENFORCE(cudaStreamSynchronize(stream),
                     "sync HierarchicalAllReduce exter stream error");
    }
  }

  void InterBroadCast(platform::Place place, void* sendbuff, size_t count,
                      ncclDataType_t datatype, ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = boost::get<platform::CUDAPlace>(place).device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before InterBroadCast buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;

    cudaStreamWaitEvent(stream, exter_events_.at(dev_id), 0);
    PADDLE_ENFORCE(platform::dynload::ncclBcast(sendbuff, count, datatype, 0,
                                                comm, stream));
  }

 protected:
  std::vector<platform::Place> places_;
  const platform::MultiNCCLContextMap* nccl_ctxs_{nullptr};
  // When multi trainer call collective function, they need run the same order.
  // Or the program will hang.So we use allreduce_deps_pass to set this
  // run_order_.
  int run_order_{0};
  // Use 2d allreduce or not.
  bool use_hierarchical_allreduce_{false};

 private:
  // hierarchical needed events
  std::unordered_map<int, cudaEvent_t> inter_events_;
  std::unordered_map<int, cudaEvent_t> exter_events_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
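The hierarchical path above runs in three steps: an intra-node (inter ring) reduce to each node's root, an allreduce across node roots (exter ring), and an intra-node broadcast back out. A minimal NCCL-free sketch that simulates this flow on plain arrays; the node/rank layout and float payloads are illustrative assumptions, not part of the commit:

#include <cassert>
#include <vector>

// Simulate hierarchical (2D) allreduce: data[node][rank] holds one float per
// (node, local rank). Step 1: reduce inside each node to local root 0;
// step 2: allreduce across the node roots; step 3: broadcast inside each node.
void HierarchicalAllReduceSim(std::vector<std::vector<float>>* data) {
  // Step 1: inter (intra-node) reduce to each node's rank 0.
  for (auto& node : *data) {
    for (size_t r = 1; r < node.size(); ++r) node[0] += node[r];
  }
  // Step 2: exter allreduce among the node roots.
  float total = 0;
  for (auto& node : *data) total += node[0];
  for (auto& node : *data) node[0] = total;
  // Step 3: inter broadcast from each node's root to its local ranks.
  for (auto& node : *data) {
    for (size_t r = 1; r < node.size(); ++r) node[r] = node[0];
  }
}

int main() {
  // 2 nodes x 3 local ranks, each contributing 1.0f.
  std::vector<std::vector<float>> grads(2, std::vector<float>(3, 1.0f));
  HierarchicalAllReduceSim(&grads);
  for (auto& node : grads)
    for (float v : node) assert(v == 6.0f);  // sum over all 6 ranks
  return 0;
}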
paddle/fluid/framework/details/op_handle_base.cc

@@ -187,6 +187,11 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
     std::function<void()> method = callback;
     for (auto &p : dev_ctxes_) {
       method = [method, p, this]() {
+        VLOG(10) << "cudadevicecontext:"
+                 << static_cast<platform::CUDADeviceContext *>(p.second)
+                 << ", dev_id:"
+                 << boost::get<platform::CUDAPlace>(p.first).device;
+
         static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
             events_.at(boost::get<platform::CUDAPlace>(p.first).device),
             method);
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc

@@ -95,6 +95,7 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
   auto seq_allreduce_pass =
       ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
+  seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
   for (size_t i = 0; i < graphs_.size(); ++i) {
     graphs_[i].reset(seq_allreduce_pass->Apply(graphs_[i].release()));
   }
paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc

@@ -30,7 +30,7 @@ namespace details {
 SparseAllReduceOpHandle::SparseAllReduceOpHandle(
     ir::Node *node, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    const platform::NCCLContextMap *ctxs, bool is_encoded, int nranks)
+    const platform::MultiNCCLContextMap *ctxs, bool is_encoded, int nranks)
     : AllReduceOpHandle(node, local_scopes, places, ctxs),
       is_encoded_(is_encoded),
       nranks_(nranks) {

@@ -102,7 +102,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
   out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;

   int dev_id = boost::get<platform::CUDAPlace>(place).device;
-  auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+  auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false);
+  auto &nccl_ctx = nccl_ctxs->at(dev_id);
   auto stream = nccl_ctx.stream();
   auto comm = nccl_ctx.comm_;
paddle/fluid/framework/details/sparse_all_reduce_op_handle.h

@@ -32,7 +32,7 @@ class SparseAllReduceOpHandle : public AllReduceOpHandle {
   SparseAllReduceOpHandle(ir::Node *node,
                           const std::vector<Scope *> &local_scopes,
                           const std::vector<platform::Place> &places,
-                          const platform::NCCLContextMap *ctxs,
+                          const platform::MultiNCCLContextMap *ctxs,
                           bool is_encoded = false, int nranks = -1);
   std::string Name() const override;
paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc

@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"

@@ -35,9 +36,20 @@ namespace ir {
 class AllReduceDepsPass : public ir::Pass {
  protected:
   void ApplyImpl(ir::Graph* graph) const override {
-    std::vector<details::AllReduceOpHandle*> all_reduce_op_handles =
+    std::vector<details::OpHandleBase*> all_reduce_op_handles =
         GetSortedAllReduceOps(*graph);

+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+    auto use_hierarchical_allreduce =
+        Get<bool>(details::kUseHierarchicalAllReduce);
+    for (size_t i = 0; i < all_reduce_op_handles.size(); ++i) {
+      auto op_handle =
+          dynamic_cast<details::NCCLOpHandleBase*>(all_reduce_op_handles[i]);
+      PADDLE_ENFORCE(op_handle, "op_handle must be NCCLOpHandleBase");
+      op_handle->SetRunEnv(i, use_hierarchical_allreduce);
+    }
+#endif
+
     for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
       auto* dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
       graph->Get<details::GraphDepVars>(details::kGraphDepVars)

@@ -51,13 +63,12 @@ class AllReduceDepsPass : public ir::Pass {
     }
   }

-  std::vector<details::AllReduceOpHandle*> GetSortedAllReduceOps(
+  std::vector<details::OpHandleBase*> GetSortedAllReduceOps(
       const ir::Graph& graph) const {
-    std::vector<details::AllReduceOpHandle*> all_reduce_op_handles;
+    std::vector<details::OpHandleBase*> all_reduce_op_handles;
     std::unordered_map<details::OpHandleBase*, size_t> pending_ops;
     std::unordered_set<details::OpHandleBase*> ready_ops;
     std::unordered_set<details::OpHandleBase*> next_ready_ops;

     auto op_handles = ir::FilterByNodeWrapper<details::OpHandleBase>(graph);
     size_t num_of_ops = op_handles.size();
     for (details::OpHandleBase* op : op_handles) {

@@ -95,13 +106,16 @@ class AllReduceDepsPass : public ir::Pass {
   void GetSortedAllReduceOps(
       const std::unordered_set<details::OpHandleBase*>& ready_ops,
-      std::vector<details::AllReduceOpHandle*>* all_reduce_op_handles) const {
-    std::vector<details::AllReduceOpHandle*> current_all_reduce_op_handles;
+      std::vector<details::OpHandleBase*>* all_reduce_op_handles) const {
+    std::vector<details::OpHandleBase*> current_all_reduce_op_handles;
     for (auto& op_handle : ready_ops) {
       auto all_reduce_op_handle =
           dynamic_cast<details::AllReduceOpHandle*>(op_handle);
-      if (all_reduce_op_handle) {
-        current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
+      auto fused_all_reduce_op_handle =
+          dynamic_cast<details::FusedAllReduceOpHandle*>(op_handle);
+
+      if (all_reduce_op_handle || fused_all_reduce_op_handle) {
+        current_all_reduce_op_handles.emplace_back(op_handle);
       }
     }

@@ -110,8 +124,8 @@ class AllReduceDepsPass : public ir::Pass {
     // Sort the current_all_reduce_op_handles according to the name of input.
     sort(current_all_reduce_op_handles.begin(),
          current_all_reduce_op_handles.end(),
-         [](const details::AllReduceOpHandle* left,
-            const details::AllReduceOpHandle* right) -> bool {
+         [](const details::OpHandleBase* left,
+            const details::OpHandleBase* right) -> bool {
           auto left_in_vars =
               details::DynamicCast<details::VarHandle>(left->Inputs());
           auto right_in_vars =

@@ -126,9 +140,9 @@ class AllReduceDepsPass : public ir::Pass {
                                   current_all_reduce_op_handles.end());
   }

-  void DebugString(const ir::Graph& graph,
-                   const std::vector<details::AllReduceOpHandle*>&
-                       all_reduce_op_handles) const {
+  void DebugString(
+      const ir::Graph& graph,
+      const std::vector<details::OpHandleBase*>& all_reduce_op_handles) const {
     // get vars order
     std::map<int, std::vector<std::string>> vars =
         GetSoredGradientsFromStaleProgram(graph);
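The pass gives every all-reduce a fixed position i (its run order), so all trainers issue collectives in the same sequence and, with multiple communicators, the same op lands on the same ring everywhere. A toy sketch of that dispatch; the modulo ring selection mirrors what GetFlatCtx(run_order) is assumed to do, and the gradient names are hypothetical:

#include <cstddef>
#include <iostream>
#include <vector>

// Toy dispatch: run_order modulo the communicator count picks the ring for
// each all-reduce, so every trainer uses the same ring for the same op.
// Assumption: GetFlatCtx(run_order) selects contexts this way; the gradient
// names below are hypothetical.
int main() {
  const std::size_t nccl_comm_num = 3;  // hypothetical number of flat rings
  std::vector<const char*> grads = {"w0@GRAD", "w1@GRAD", "w2@GRAD",
                                    "w3@GRAD", "w4@GRAD"};
  for (std::size_t run_order = 0; run_order < grads.size(); ++run_order) {
    std::cout << grads[run_order] << " -> comm " << run_order % nccl_comm_num
              << "\n";
  }
  return 0;
}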
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc

@@ -34,7 +34,8 @@ class FuseAllReduceOpPass : public ir::Pass {
     auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
     auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    auto *nccl_ctxs = &Get<platform::NCCLContextMap>(details::kNCCLCtxs);
+    auto *multi_nccl_ctxs =
+        &Get<platform::MultiNCCLContextMap>(details::kNCCLCtxs);
 #endif

     std::unordered_set<std::string> grads;

@@ -94,7 +95,7 @@ class FuseAllReduceOpPass : public ir::Pass {
       }
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       InsertFusedAllReduce(places, local_scopes, group_size,
-                           group_all_reduce_ops, nccl_ctxs, &result);
+                           group_all_reduce_ops, multi_nccl_ctxs, &result);
 #else
       InsertFusedAllReduce(places, local_scopes, group_size,
                            group_all_reduce_ops, &result);

@@ -102,12 +103,12 @@ class FuseAllReduceOpPass : public ir::Pass {
     }
   }

   void InsertFusedAllReduce(const std::vector<platform::Place> &places,
                             const std::vector<Scope *> &local_scopes,
                             const size_t num_of_all_reduce,
                             const std::vector<ir::Node *> &all_reduce_ops,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-                            const platform::NCCLContextMap *nccl_ctxs,
+                            const platform::MultiNCCLContextMap *multi_nccl_ctxs,
 #endif
                             ir::Graph *result) const {
     std::vector<details::VarHandleBase *> inputs;

@@ -135,7 +136,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
-                           local_scopes, nccl_ctxs, result);
+                           local_scopes, multi_nccl_ctxs, result);
 #else
     CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
                            local_scopes, result);

@@ -150,13 +151,13 @@ class FuseAllReduceOpPass : public ir::Pass {
       const std::vector<platform::Place> &places,
       const std::vector<Scope *> &local_scopes,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-      const platform::NCCLContextMap *nccl_ctxs,
+      const platform::MultiNCCLContextMap *multi_nccl_ctxs,
 #endif
       ir::Graph *result) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
     auto *op_handle = new details::FusedAllReduceOpHandle(
         result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
-        local_scopes, places, num_of_all_reduce, nccl_ctxs);
+        local_scopes, places, num_of_all_reduce, multi_nccl_ctxs);
 #else
     auto *op_handle = new details::FusedAllReduceOpHandle(
         result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),

@@ -172,7 +173,7 @@ class FuseAllReduceOpPass : public ir::Pass {
     }

 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    if (!nccl_ctxs) {
+    if (!multi_nccl_ctxs) {
       SetCommunicationContext(places, op_handle);
     }
 #else
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc

@@ -157,7 +157,11 @@ void MultiDevSSAGraphBuilderBase::Init() const {
   local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes);
   strategy_ = Get<const details::BuildStrategy>(kStrategy);
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  nccl_ctxs_ = &Get<platform::NCCLContextMap>(details::kNCCLCtxs);
+  multi_nccl_ctxs_ = &Get<platform::MultiNCCLContextMap>(details::kNCCLCtxs);
+  nccl_ctxs_ = nullptr;
+  if (multi_nccl_ctxs_) {
+    nccl_ctxs_ = multi_nccl_ctxs_->DefaultFlatCtx();
+  }
 #endif
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 }

@@ -460,20 +464,20 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
     result->Get<GraphOps>(kGraphOps).emplace_back(
         new details::SparseAllReduceOpHandle(
             result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-            scopes, places, nccl_ctxs_, is_encoded,
+            scopes, places, multi_nccl_ctxs_, is_encoded,
             static_cast<int>(strategy_.trainers_endpoints_.size()) *
                 places_.size()));
   } else {
     result->Get<GraphOps>(kGraphOps).emplace_back(
         new details::AllReduceOpHandle(
             result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-            scopes, places, nccl_ctxs_));
+            scopes, places, multi_nccl_ctxs_));
   }
 #elif defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new details::AllReduceOpHandle(
           result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-          scopes, places, nccl_ctxs_));
+          scopes, places, multi_nccl_ctxs_));
 #else
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new details::AllReduceOpHandle(
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h

@@ -96,7 +96,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
                              size_t device_id) const;

 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  mutable platform::NCCLContextMap *nccl_ctxs_;
+  mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
+  mutable platform::MultiNCCLContextMap *multi_nccl_ctxs_{nullptr};
 #endif

   mutable std::string loss_var_name_;
paddle/fluid/framework/parallel_executor.cc

@@ -94,6 +94,89 @@ class ParallelExecutorPrivate {
     }
   }

+#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+  void InitNCCLCtxs(framework::Scope *scope, const BuildStrategy &bst) {
+    VLOG(1) << "nccl comm num:" << bst.nccl_comm_num_ << ", nranks:" << nranks_
+            << ", num_trainers:" << bst.num_trainers_
+            << ", trainer_id:" << bst.trainer_id_;
+
+    if (bst.use_hierarchical_allreduce_) {
+      VLOG(1) << ", use_hierarchical_allreduce:"
+              << bst.use_hierarchical_allreduce_ << ", inter_trainers_num:"
+              << bst.hierarchical_allreduce_inter_nranks_
+              << ", exter_trainers_num:"
+              << bst.hierarchical_allreduce_exter_nranks_;
+    }
+
+    std::vector<ncclUniqueId *> flat_nccl_ids;
+    if (nranks_ == 1) {
+      // FIXME(gongwb): need not to create ncclid when nranks==1
+      nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                              bst.trainer_id_);
+      return;
+    }
+
+    if (bst.enable_parallel_graph_) {
+      VLOG(1) << "use only one ncclid in pg model";
+
+      ncclUniqueId *nccl_id = nullptr;
+
+      std::string var_name = platform::GetFlatNCCLVarName(0);
+      auto nccl_id_var = scope->FindVar(var_name);
+      if (nccl_id_var) {
+        nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+      } else {
+        nccl_id = new ncclUniqueId();
+        PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id));
+      }
+
+      flat_nccl_ids.push_back(nccl_id);
+
+      nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                              bst.trainer_id_);
+      VLOG(1) << "init bst nccl context complete!";
+      return;
+    }
+
+    // num_trainers ==1 && places > 1
+    if (bst.num_trainers_ == 1) {
+      nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                              bst.trainer_id_);
+      return;
+    }
+
+    for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
+      std::string var_name = platform::GetFlatNCCLVarName(i);
+      auto nccl_id_var = scope->FindVar(var_name);
+      PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name);
+      auto nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+      flat_nccl_ids.push_back(nccl_id);
+    }
+
+    nccl_ctxs_.InitFlatCtxs(places_, flat_nccl_ids, bst.num_trainers_,
+                            bst.trainer_id_);
+
+    if (bst.use_hierarchical_allreduce_) {
+      std::string var_name = platform::GetHierarchicalInterNCCLVarName();
+      auto nccl_id_var = scope->FindVar(var_name);
+      auto inter_nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+
+      std::vector<ncclUniqueId *> exter_nccl_ids;
+      for (int i = 0; i < static_cast<int>(bst.nccl_comm_num_); i++) {
+        std::string var_name = platform::GetHierarchicalExterNCCLVarName(i);
+        auto nccl_id_var = scope->FindVar(var_name);
+        PADDLE_ENFORCE(nccl_id_var, "can't find %s nccl_id_var", var_name);
+        auto nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
+        exter_nccl_ids.push_back(nccl_id);
+      }
+      nccl_ctxs_.InitHierarchicalCtxs(places_, inter_nccl_id, exter_nccl_ids,
+                                      bst.num_trainers_, bst.trainer_id_,
+                                      bst.hierarchical_allreduce_inter_nranks_,
+                                      bst.hierarchical_allreduce_exter_nranks_);
+    }
+  }
+#endif
+
   BuildStrategy build_strategy_;
   std::vector<platform::Place> places_;
   std::vector<Scope *> local_scopes_;
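InitNCCLCtxs builds one flat communicator set per ring plus, when hierarchical allreduce is on, inter (intra-node) and exter (cross-node) sets. A sketch, under assumed semantics, of how a trainer id splits across that grid; the exter-ring membership rule (local rank 0 per node, cf. NeedExterAllReduce() in nccl_op_handle.h) is an assumption here, not confirmed by this page:

#include <cassert>
#include <iostream>

// Assumed rank layout for InitHierarchicalCtxs: trainers on one node share an
// inter (intra-node) ring; one representative per node joins the exter
// (cross-node) ring. inter_nranks/exter_nranks correspond to
// bst.hierarchical_allreduce_inter_nranks_ / _exter_nranks_.
int main() {
  const int inter_nranks = 8;  // one-proc-one-card trainers per node
  const int exter_nranks = 4;  // number of nodes
  const int nranks = inter_nranks * exter_nranks;
  assert(nranks % inter_nranks == 0);

  for (int trainer_id = 0; trainer_id < nranks; ++trainer_id) {
    int node = trainer_id / inter_nranks;        // which inter ring
    int local_rank = trainer_id % inter_nranks;  // rank inside that ring
    // Assumption: local rank 0 of each node is that node's exter-ring member.
    if (local_rank == 0) {
      std::cout << "trainer " << trainer_id << " is node " << node
                << "'s exter-ring member\n";
    }
  }
  return 0;
}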
...
@@ -101,7 +184,7 @@ class ParallelExecutorPrivate {
...
@@ -101,7 +184,7 @@ class ParallelExecutorPrivate {
std
::
unique_ptr
<
details
::
SSAGraphExecutor
>
executor_
;
std
::
unique_ptr
<
details
::
SSAGraphExecutor
>
executor_
;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
nccl_ctxs_
;
platform
::
MultiNCCLContextMap
nccl_ctxs_
;
#endif
#endif
bool
own_local_scope_
;
bool
own_local_scope_
;
bool
use_cuda_
;
bool
use_cuda_
;
...
@@ -254,24 +337,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -254,24 +337,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
if
(
member_
->
use_cuda_
)
{
if
(
member_
->
use_cuda_
)
{
// Bcast Parameters to all GPUs
// Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
ncclUniqueId
*
nccl_id
=
nullptr
;
member_
->
InitNCCLCtxs
(
scope
,
build_strategy
);
// gen_nccl_id operator can broadcast the ncclUniqueId for nccl2 collective
// distributed training
auto
*
nccl_id_var
=
scope
->
FindVar
(
NCCL_ID_VARNAME
);
if
(
nccl_id_var
!=
nullptr
)
{
nccl_id
=
nccl_id_var
->
GetMutable
<
ncclUniqueId
>
();
}
if
(
build_strategy
.
enable_parallel_graph_
&&
member_
->
nranks_
>
1UL
)
{
if
(
nccl_id
==
nullptr
)
{
local_nccl_id_
.
reset
(
new
ncclUniqueId
());
platform
::
dynload
::
ncclGetUniqueId
(
local_nccl_id_
.
get
());
nccl_id
=
local_nccl_id_
.
get
();
}
}
member_
->
nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
,
nccl_id
,
build_strategy
.
num_trainers_
,
build_strategy
.
trainer_id_
));
// Initialize device context's nccl comm, will be used by normal
// Initialize device context's nccl comm, will be used by normal
// Operators like sync_batch_norm, and collective ops.
// Operators like sync_batch_norm, and collective ops.
...
@@ -280,23 +346,15 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
...
@@ -280,23 +346,15 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
// NOTE: NCCL group-calls and non-group-calls can not use the same
// NOTE: NCCL group-calls and non-group-calls can not use the same
// NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
// NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use
// same communicators.
// same communicators.
std
::
unique_ptr
<
platform
::
NCCLContextMap
>
dev_nccl_ctxs
;
if
(
nccl_id
==
nullptr
)
{
dev_nccl_ctxs
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
));
}
for
(
size_t
dev_id
=
0
;
dev_id
<
member_
->
places_
.
size
();
++
dev_id
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
member_
->
places_
.
size
();
++
dev_id
)
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
auto
*
dev_ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
pool
.
Get
(
member_
->
places_
[
dev_id
]));
pool
.
Get
(
member_
->
places_
[
dev_id
]));
if
(
nccl_id
!=
nullptr
)
{
auto
&
nccl_ctx
=
auto
&
nccl_ctx
=
member_
->
nccl_ctxs_
->
at
(
member_
->
places_
[
dev_id
]);
member_
->
nccl_ctxs_
.
DefaultFlatCtx
()
->
at
(
member_
->
places_
[
dev_id
]);
dev_ctx
->
set_nccl_comm
(
nccl_ctx
.
comm
());
}
else
{
auto
&
nccl_ctx
=
dev_nccl_ctxs
->
at
(
member_
->
places_
[
dev_id
]);
dev_ctx
->
set_nccl_comm
(
nccl_ctx
.
comm
());
dev_ctx
->
set_nccl_comm
(
nccl_ctx
.
comm
());
}
}
}
#else
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
#endif
...
@@ -327,18 +385,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     VLOG(3) << "use local async mode";
     graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
                                  {member_->local_scopes_[0]}, 1,
-                                 member_->use_cuda_, member_->nccl_ctxs_.get());
+                                 member_->use_cuda_, &member_->nccl_ctxs_);
     for (size_t i = 1; i < member_->places_.size(); ++i) {
       graphs[i] =
           build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name,
                                {member_->local_scopes_[i]}, 1,
-                               member_->use_cuda_, member_->nccl_ctxs_.get());
+                               member_->use_cuda_, &member_->nccl_ctxs_);
       async_graphs[i] = graphs[i];
     }
   } else {
     graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
                                  member_->local_scopes_, member_->nranks_,
-                                 member_->use_cuda_, member_->nccl_ctxs_.get());
+                                 member_->use_cuda_, &member_->nccl_ctxs_);
   }
 #else
   if (build_strategy.async_mode_) {
...
@@ -471,13 +529,14 @@ void ParallelExecutor::BCastParamsToDevices(
       PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(),
                         "variables' buffer size to bcast NOT equal to places");
       {
+        auto *nccl_ctxs = member_->nccl_ctxs_.DefaultFlatCtx();
         platform::NCCLGroupGuard guard;
         for (size_t i = 0; i < member_->places_.size(); ++i) {
-          auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
+          auto &nccl_ctx = nccl_ctxs->at(member_->places_[i]);
           platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
                                        nccl_ctx.comm_, nccl_ctx.stream());
         }
-        member_->nccl_ctxs_->WaitAll();
+        nccl_ctxs->WaitAll();
       }
 #else
       PADDLE_THROW("Not compiled with CUDA");
...
@@ -512,6 +571,7 @@ void ParallelExecutor::BCastParamsToDevices(
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
+  VLOG(3) << "enter ParallelExecutor Run";
 #ifdef WITH_GPERFTOOLS
   if (gProfileStarted) {
     ProfilerFlush();
...
@@ -522,6 +582,8 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   if (member_->HasGarbageCollectors()) {
     member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
   }
+
+  VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
   auto fetch_data = member_->executor_->Run(fetch_tensors);
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
       fetch_data;
...
paddle/fluid/framework/parallel_executor.h
@@ -87,9 +87,6 @@ class ParallelExecutor {
   ParallelExecutorPrivate *member_;
   std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  std::unique_ptr<ncclUniqueId> local_nccl_id_;
-#endif
 };
 }  // namespace framework
...
paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc
@@ -41,31 +41,129 @@ class GenNCCLIdOp : public framework::OperatorBase {
     // put nccl id in CPUPlace
     auto &dev_ctx = *pool.Get(platform::CPUPlace());
     int trainer_id = Attr<int>("trainer_id");
+
+    std::vector<std::string> trainers =
+        Attr<std::vector<std::string>>("trainers");
+    PADDLE_ENFORCE(
+        trainer_id >= 0 && trainer_id < static_cast<int>(trainers.size()),
+        "trainer_id:%d must be in trainers.size range", trainer_id);
+    std::string endpoint = trainers[trainer_id];
+
     framework::Scope &local_scope = scope.NewScope();
+
+    int nccl_comm_num = Attr<int>("nccl_comm_num");
+    int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
+    int inter_nranks = Attr<int>("hierarchical_allreduce_inter_nranks");
+
+    int inter_trainer_id = -1;
+    int exter_trainer_id = -1;
+    if (use_hierarchical_allreduce) {
+      PADDLE_ENFORCE(trainers.size() > 1, "trainers.size():%llu < 1",
+                     trainers.size());
+      PADDLE_ENFORCE(inter_nranks > 1, "inter_nranks:%d < 1", inter_nranks);
+      PADDLE_ENFORCE((trainers.size() % inter_nranks == 0),
+                     "trainers.size():%llu mod inter_nranks:%d != 0",
+                     trainers.size(), inter_nranks);
+
+      inter_trainer_id = trainer_id % inter_nranks;
+      if (trainer_id % inter_nranks == 0) {
+        exter_trainer_id = trainer_id / inter_nranks;
+      }
+    }
+
+    if (trainer_id != 0) {
+      GetIdByServer(endpoint, &local_scope, dev_ctx, nccl_comm_num,
+                    use_hierarchical_allreduce, trainer_id, inter_trainer_id,
+                    exter_trainer_id);
+    }
+
+    std::ostringstream ss;
+    for (size_t i = 0; i < trainers.size(); i++) {
+      ss << trainers[i] << ",";
+    }
+
+    VLOG(1) << "trainer_id:" << trainer_id
+            << ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
+            << ", inter_nranks:" << inter_nranks
+            << ", inter_trainer_id:" << inter_trainer_id
+            << ", exter_trainer_id:" << exter_trainer_id
+            << ", trainers:" << ss.str();
+
+    // init flat
     if (trainer_id == 0) {
-      GenerateAndSend(&local_scope, dev_ctx);
-    } else {
-      GetIdByServer(&local_scope, dev_ctx);
+      std::vector<std::string> flat_endpoints;
+      flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
+                            trainers.end());
+      // flat nccl_id
+      for (int i = 0; i < nccl_comm_num; i++) {
+        std::string var_name = platform::GetFlatNCCLVarName(i);
+        GenerateAndSend(&local_scope, dev_ctx, var_name, flat_endpoints);
+      }
+    }
+
+    if (!use_hierarchical_allreduce) {
+      return;
+    }
+
+    PADDLE_ENFORCE(trainers.size() % inter_nranks == 0,
+                   "enpoints.size:%llu mod inter_nranks:%d should ==0",
+                   trainers.size(), inter_nranks);
+    PADDLE_ENFORCE(inter_nranks > 1, "inter_nranks:%d must > 1", inter_nranks);
+
+    // hierarchical inter ncclid
+    if (inter_trainer_id == 0) {
+      std::ostringstream ss;
+      ss << endpoint;
+      std::vector<std::string> inter_endpoints;
+      for (int i = trainer_id + 1; i < trainer_id + inter_nranks &&
+                                   i < static_cast<int>(trainers.size());
+           i++) {
+        ss << ",";
+        inter_endpoints.push_back(trainers[i]);
+        ss << trainers[i];
+      }
+      VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
+
+      std::string nccl_var_name = platform::GetHierarchicalInterNCCLVarName();
+      GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, inter_endpoints);
+    }
+
+    // hierarchical exter ncclid
+    if (exter_trainer_id == 0) {
+      std::ostringstream ss;
+      std::vector<std::string> exter_endpoints;
+      ss << endpoint;
+      for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) {
+        ss << ",";
+        exter_endpoints.push_back(trainers[i]);
+        ss << trainers[i];
+      }
+      VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str();
+
+      for (int i = 0; i < nccl_comm_num; i++) {
+        std::string nccl_var_name =
+            platform::GetHierarchicalExterNCCLVarName(i);
+        GenerateAndSend(&local_scope, dev_ctx, nccl_var_name, exter_endpoints);
+      }
     }
   }

  private:
-  void GenerateAndSend(framework::Scope *scope,
-                       const platform::DeviceContext &dev_ctx) const {
-    auto var = scope->FindVar(NCCL_ID_VARNAME);
-    PADDLE_ENFORCE_NOT_NULL(var);
+  void GenerateAndSend(framework::Scope *scope,
+                       const platform::DeviceContext &dev_ctx,
+                       const std::string &nccl_id_name,
+                       const std::vector<std::string> &endpoint_list) const {
+    auto var = scope->FindVar(nccl_id_name);
+    PADDLE_ENFORCE_NOT_NULL(var, "can't find nccl_id_var_name:%s",
+                            nccl_id_name);
     auto id = var->GetMutable<ncclUniqueId>();
     PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));

-    std::vector<std::string> endpoint_list =
-        Attr<std::vector<std::string>>("endpoint_list");
     distributed::RPCClient *client =
         distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

     for (auto &ep : endpoint_list) {
-      VLOG(3) << "sending nccl id to " << ep;
-      client->AsyncSendVar(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
+      VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep;
+      client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name);
     }
     client->Wait();
     for (auto &ep : endpoint_list) {
@@ -75,9 +173,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
...
@@ -75,9 +173,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"sending completed..."
;
VLOG
(
3
)
<<
"sending completed..."
;
}
}
void
GetIdByServer
(
framework
::
Scope
*
scope
,
void
GetIdByServer
(
const
std
::
string
&
endpoint
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
const
platform
::
DeviceContext
&
dev_ctx
,
int
nccl_comm_num
,
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
bool
use_hierarchical_allreduce
,
int
trainer_id
,
int
inter_trainer_id
,
int
exter_trainer_id
)
const
{
// std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
// that will cause a wired crash.
...
@@ -98,10 +198,42 @@ class GenNCCLIdOp : public framework::OperatorBase {
     std::thread server_thread(
         std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));

-    rpc_service->SetCond(distributed::kRequestSend);
-    VLOG(3) << "start getting nccl id from trainer 0...";
-    rpc_service->WaitBarrier(distributed::kRequestSend);
-    VLOG(3) << "got nccl id and stop server...";
+    for (int i = 0; i < nccl_comm_num; i++) {
+      rpc_service->SetCond(distributed::kRequestSend);
+      VLOG(3) << "trainer_id:" << trainer_id
+              << " start getting nccl id from trainer 0, nccl_comm_no:" << i;
+      rpc_service->WaitBarrier(distributed::kRequestSend);
+      rpc_service->ResetBarrierCounter();
+    }
+
+    if (use_hierarchical_allreduce) {
+      if (inter_trainer_id > 0) {
+        rpc_service->SetCond(distributed::kRequestSend);
+        VLOG(3) << "trainer_id:" << trainer_id
+                << ", inter_trainer_id:" << inter_trainer_id
+                << " start getting nccl id from inter_trainer 0";
+        rpc_service->WaitBarrier(distributed::kRequestSend);
+        rpc_service->ResetBarrierCounter();
+      }
+
+      if (exter_trainer_id > 0) {
+        for (int i = 0; i < nccl_comm_num; i++) {
+          rpc_service->SetCond(distributed::kRequestSend);
+          VLOG(3) << "trainer_id:" << trainer_id
+                  << ", exter_trainer_id:" << exter_trainer_id
+                  << " start getting nccl id from exter_trainer 0, "
+                     "nccl_comm_no:"
+                  << i;
+          rpc_service->WaitBarrier(distributed::kRequestSend);
+          rpc_service->ResetBarrierCounter();
+        }
+      }
+    }
+
+    VLOG(3) << "traier_id:" << trainer_id
+            << ", inter_trainer_id:" << inter_trainer_id
+            << ", exter_trainer_id:" << exter_trainer_id
+            << " got nccl id and stop server...";
     rpc_service->ShutDown();
     VLOG(3) << "rpc server stopped";
     server_thread.join();
...
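On the receive side, each trainer now sits behind the kRequestSend barrier once per expected id instead of once in total. A bookkeeping sketch of how many ids a given trainer waits for under this sequence (pure Python, mirroring the control flow above; the function name is mine):

    def expected_id_messages(trainer_id, nccl_comm_num,
                             use_hierarchical_allreduce, inter_nranks):
        if trainer_id == 0:
            return 0  # trainer 0 only generates and sends
        n = nccl_comm_num  # one flat id per communicator ring
        if use_hierarchical_allreduce:
            if trainer_id % inter_nranks > 0:
                n += 1  # the id of its node-local (inter) ring
            elif trainer_id // inter_nranks > 0:
                n += nccl_comm_num  # a non-root block leader: all exter ring ids
        return n

    print(expected_id_messages(5, 3, True, 4))  # 3 flat + 1 inter -> 4
    print(expected_id_messages(4, 3, True, 4))  # 3 flat + 3 exter -> 6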
@@ -118,18 +250,26 @@ GenNCCLId operator
 For trainer 0: generate a new UniqueId and send it to all the other trainers.
 For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
 )DOC");
     AddAttr<std::string>("endpoint",
                          "(string), e.g. 127.0.0.1:6175 "
                          "current listen endpoint");
     AddAttr<std::vector<std::string>>(
-        "endpoint_list",
-        "['trainer1_ip:port', 'trainer2_ip:port', ...] "
-        "list of trainer endpoints start from trainer 1")
+        "trainers",
+        "['trainer0_ip:port', 'trainer1_ip:port', ...] "
+        "list of all trainer endpoints")
         .SetDefault({});
     AddAttr<int>("trainer_id",
-                 "(int default 0) "
-                 "The index of the trainer in distributed training.")
-        .SetDefault(0);
+                 "(int) "
+                 "The index of the trainer in distributed training.");
+    AddAttr<int>("nccl_comm_num",
+                 "(int default 1) "
+                 "The number of nccl communicator num.")
+        .SetDefault(1);
+    AddAttr<bool>("use_hierarchical_allreduce",
+                  "(bool default false) "
+                  "Wheter to use hierarchical allreduce.")
+        .SetDefault(false);
+    AddAttr<int>("hierarchical_allreduce_inter_nranks",
+                 "(int default 1) "
+                 "Wheter to use hierarchical allreduce.")
+        .SetDefault(-1);
   }
 };
...
paddle/fluid/platform/nccl_helper.h
@@ -124,8 +124,8 @@ struct NCCLContextMap {
       } else {
         rank = trainer_id;
       }
-      VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks
-              << " gpu id: " << gpu_id;
+      VLOG(1) << "init nccl rank:" << rank << ", nranks:" << nranks
+              << ", gpu_id:" << gpu_id << ", dev_id:" << order_[i];
       PADDLE_ENFORCE(cudaSetDevice(gpu_id));
       PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
           comms.get() + i, nranks, *nccl_id, rank));
...
@@ -160,6 +160,134 @@ struct NCCLContextMap {
   }
 };

+inline std::string GetFlatNCCLVarName(size_t pos) {
+  if (pos == 0) {
+    return NCCL_ID_VARNAME;
+  }
+  return string::Sprintf("%s_%d", NCCL_ID_VARNAME, static_cast<int>(pos));
+}
+
+inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
+  return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME,
+                         static_cast<int>(pos));
+}
+inline std::string GetHierarchicalInterNCCLVarName() {
+  return string::Sprintf("Hierarchical_inter_%s", NCCL_ID_VARNAME);
+}
+
+class MultiNCCLContextMap {
+ public:
+  MultiNCCLContextMap() {}
+  virtual ~MultiNCCLContextMap() {}
+
+  NCCLContextMap *DefaultFlatCtx() const {
+    if (flat_ctxs_.size() == 0) {
+      return nullptr;
+    }
+
+    return flat_ctxs_[0].get();
+  }
+
+  std::vector<std::unique_ptr<NCCLContextMap>> *GetFlatCtxs() {
+    return &flat_ctxs_;
+  }
+
+  NCCLContextMap *GetFlatCtx(size_t run_order) const {
+    return flat_ctxs_[run_order % flat_ctxs_.size()].get();
+  }
+
+  NCCLContextMap *GetRunEnvNCCLCtx(size_t run_order,
+                                   bool use_hierarchical_allreduce) const {
+    if (!use_hierarchical_allreduce) {
+      return GetFlatCtx(run_order);
+    }
+
+    return GetHierarchicalInterCtx(run_order);
+  }
+
+  void InitFlatCtxs(const std::vector<platform::Place> &places,
+                    const std::vector<ncclUniqueId *> &nccl_ids,
+                    size_t trainers_num, size_t trainer_id) {
+    if (nccl_ids.size() == 0) {
+      auto ptr = new platform::NCCLContextMap(places);
+      VLOG(1) << "init local trainer";
+      flat_ctxs_.emplace_back(ptr);
+      return;
+    }
+
+    for (size_t i = 0; i < nccl_ids.size(); i++) {
+      auto ptr = new platform::NCCLContextMap(places, nccl_ids[i],
+                                              trainers_num, trainer_id);
+      VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
+      flat_ctxs_.emplace_back(ptr);
+    }
+  }
+
+  void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
+                            ncclUniqueId *inter_nccl_id,
+                            const std::vector<ncclUniqueId *> &exter_nccl_id,
+                            size_t trainers_num, size_t trainer_id,
+                            size_t inter_trainers_num,
+                            size_t exter_trainers_num) {
+    PADDLE_ENFORCE(trainers_num == inter_trainers_num * exter_trainers_num,
+                   "trainers_num:%llu != inter_trainers_num:%llu * "
+                   "exter_trainers_num:%llu",
+                   trainers_num, inter_trainers_num, exter_trainers_num);
+
+    PADDLE_ENFORCE(inter_trainers_num > 1, "inter_trainers_num:%llu must > 1",
+                   inter_trainers_num);
+
+    int inter_trainer_id = trainer_id % inter_trainers_num;
+    VLOG(1) << "init inter_trainer_id:" << inter_trainer_id;
+    auto local = new NCCLContextMap(places, inter_nccl_id, inter_trainers_num,
+                                    inter_trainer_id);
+    h_inter_ctxs_.emplace_back(local);
+
+    int exter_trainer_id = -1;
+    if (trainer_id % inter_trainers_num == 0) {
+      exter_trainer_id = trainer_id / inter_trainers_num;
+    }
+
+    if (exter_trainer_id >= 0) {
+      for (size_t i = 0; i < exter_nccl_id.size(); i++) {
+        auto ex = new NCCLContextMap(places, exter_nccl_id[i],
+                                     exter_trainers_num, exter_trainer_id);
+        VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
+                << ", comm no:" << i;
+        h_exter_ctxs_.emplace_back(ex);
+      }
+    }
+  }
+
+  bool NeedExterAllReduce() const { return h_exter_ctxs_.size() > 0; }
+
+  NCCLContextMap *GetHierarchicalInterCtx(size_t run_order) const {
+    return h_inter_ctxs_[run_order % h_inter_ctxs_.size()].get();
+  }
+
+  NCCLContextMap *GetHierarchicalExterCtx(size_t run_order) const {
+    return h_exter_ctxs_[run_order % h_exter_ctxs_.size()].get();
+  }
+
+  std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalInterCtxs() {
+    return &h_inter_ctxs_;
+  }
+
+  std::vector<std::unique_ptr<NCCLContextMap>> *GetHierarchicalExterCtxs() {
+    return &h_exter_ctxs_;
+  }
+
+ protected:
+  // Support multi nccl comm on default nccl ring while NCCLContextMap can't.
+  std::vector<std::unique_ptr<NCCLContextMap>> flat_ctxs_;
+
+  // h_inter_ctxs_ and h_exter_ctxs_ are for 2d allreduce.
+  // And h_exter_ctxs_ can support multi comm too.
+  std::vector<std::unique_ptr<NCCLContextMap>> h_inter_ctxs_;
+  std::vector<std::unique_ptr<NCCLContextMap>> h_exter_ctxs_;
+};
+
 }  // namespace platform
 }  // namespace paddle

 #endif
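The inter/exter split is what makes the 2D allreduce correct: reduce inside each inter ring, allreduce the per-ring partials across the exter ring, then broadcast back down. The phase ordering itself lives in the allreduce op handles (see nccl_op_handle.h in this commit), not in this header; the integer simulation below only checks that the hierarchical form reproduces a flat allreduce:

    import random

    def flat_allreduce(vals):
        total = sum(vals)
        return [total] * len(vals)

    def hierarchical_allreduce(vals, inter_nranks):
        # Phase 1: reduce inside each inter ring.
        blocks = [vals[i:i + inter_nranks]
                  for i in range(0, len(vals), inter_nranks)]
        partials = [sum(b) for b in blocks]
        # Phase 2: allreduce the partials across the exter ring of leaders.
        total = sum(partials)
        # Phase 3: broadcast the result back inside each inter ring.
        return [total] * len(vals)

    vals = [random.randint(0, 100) for _ in range(8)]  # ints avoid FP reorder noise
    assert flat_allreduce(vals) == hierarchical_allreduce(vals, inter_nranks=4)

The payoff is bandwidth: only the block leaders talk across nodes, so the slow inter-node links carry one message per node instead of one per GPU.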
paddle/fluid/pybind/pybind.cc
@@ -1483,6 +1483,34 @@ All parameter, weight, gradient are variables in Paddle.
           [](BuildStrategy &self, int trainer_id) {
             self.trainer_id_ = trainer_id;
           })
+      .def_property(
+          "nccl_comm_num",
+          [](const BuildStrategy &self) { return self.nccl_comm_num_; },
+          [](BuildStrategy &self, int nccl_comm_num) {
+            self.nccl_comm_num_ = nccl_comm_num;
+          })
+      .def_property(
+          "use_hierarchical_allreduce_",
+          [](const BuildStrategy &self) {
+            return self.use_hierarchical_allreduce_;
+          },
+          [](BuildStrategy &self, bool use) {
+            self.use_hierarchical_allreduce_ = use;
+          })
+      .def_property(
+          "hierarchical_allreduce_inter_nranks_",
+          [](const BuildStrategy &self) {
+            return self.hierarchical_allreduce_inter_nranks_;
+          },
+          [](BuildStrategy &self, int nranks) {
+            self.hierarchical_allreduce_inter_nranks_ = nranks;
+          })
+      .def_property(
+          "hierarchical_allreduce_exter_nranks_",
+          [](const BuildStrategy &self) {
+            return self.hierarchical_allreduce_exter_nranks_;
+          },
+          [](BuildStrategy &self, int nranks) {
+            self.hierarchical_allreduce_exter_nranks_ = nranks;
+          })
       .def_property(
           "fuse_elewise_add_act_ops",
           [](const BuildStrategy &self) {
...
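With these bindings the knobs can be set directly on a BuildStrategy from Python. A usage sketch (values are illustrative; note that three of the properties keep their C++ trailing underscore, exactly as bound above):

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    build_strategy.nccl_comm_num = 2                         # two flat rings
    build_strategy.use_hierarchical_allreduce_ = True        # enable 2D allreduce
    build_strategy.hierarchical_allreduce_inter_nranks_ = 8  # ranks per node
    build_strategy.hierarchical_allreduce_exter_nranks_ = 4  # number of nodes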
python/paddle/fluid/compiler.py
@@ -98,6 +98,7 @@ class CompiledProgram(object):
     def __init__(self, program_or_graph):
         if isinstance(program_or_graph, core.Graph):
             self._graph = program_or_graph
+            # don't not create a new program here.
             self._program = None
         elif isinstance(program_or_graph, framework.Program):
             self._graph = core.Graph(program_or_graph.desc)
...
@@ -299,6 +300,7 @@ class CompiledProgram(object):
         # TODO(wuyi): trainer endpoings should be passed in through
         # build_strategy, not program.xxx.
+        # TODO(gongwb): let user to set them once.
         if self._program and self._build_strategy.num_trainers > 1 and \
                 self._program._trainers_endpoints:
             tps = self._program._trainers_endpoints
...
@@ -307,6 +309,12 @@ class CompiledProgram(object):
                 tps), "num_trainers == len(end_points)"
             self._build_strategy.trainers_endpoints = tps

+        if self._program:
+            self._build_strategy.nccl_comm_num = self._program._nccl_comm_num
+            self._build_strategy.use_hierarchical_allreduce_ = self._program._use_hierarchical_allreduce
+            self._build_strategy.hierarchical_allreduce_inter_nranks_ = self._program._hierarchical_allreduce_inter_nranks
+            self._build_strategy.hierarchical_allreduce_exter_nranks_ = self._program._hierarchical_allreduce_exter_nranks
+
         if self._build_strategy.sync_batch_norm:
             self._build_strategy.enable_sequential_execution = True
...
python/paddle/fluid/framework.py
@@ -2762,6 +2762,10 @@ class Program(object):
         # use Deep gradient comrepssion or not
         self._enable_dgc = False
+        self._nccl_comm_num = 1
+        self._use_hierarchical_allreduce = False
+        self._hierarchical_allreduce_inter_nranks = 0
+        self._hierarchical_allreduce_exter_nranks = 0

         # @deprecated(the python memory optimize transpiler is deprecated)
         # whether the program is optimized by memory_optimize_transpiler
...
python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -51,11 +51,14 @@ class TestDistRunnerBase(object):
                        trainers,
                        sync_mode,
                        dc_asgd=False,
-                       current_endpoint=None):
+                       current_endpoint=None,
+                       nccl_comm_num=1):
         # NOTE: import fluid until runtime, or else forking processes will cause error.
         config = fluid.DistributeTranspilerConfig()
         config.enable_dc_asgd = dc_asgd
         config.sync_mode = sync_mode
+        if nccl_comm_num > 1:
+            config.nccl_comm_num = nccl_comm_num
         # config.runtime_split_send_recv = True
         t = fluid.DistributeTranspiler(config=config)
         t.transpile(
...
@@ -106,6 +109,7 @@ class TestDistRunnerBase(object):
             # transpile for nccl2
             config = fluid.DistributeTranspilerConfig()
             config.mode = "nccl2"
+            config.nccl_comm_num = args.nccl_comm_num
             nccl2_t = fluid.DistributeTranspiler(config=config)
             nccl2_t.transpile(
                 args.trainer_id,
...
@@ -113,6 +117,7 @@ class TestDistRunnerBase(object):
                 startup_program=fluid.default_startup_program(),
                 trainers=args.endpoints,
                 current_endpoint=args.current_endpoint)
+
             trainer_prog = fluid.default_main_program()
         else:
             trainer_prog = fluid.default_main_program()
...
@@ -268,6 +273,7 @@ def runtime_main(test_class):
         choices=["pserver", "nccl2", "local", "nccl2_reduce_layer"])
     parser.add_argument('--trainer_id', type=int, required=False, default=0)
     parser.add_argument('--trainers', type=int, required=False, default=1)
+    parser.add_argument('--nccl_comm_num', type=int, required=False, default=1)
     parser.add_argument(
         '--current_endpoint', type=str, required=False, default="")
     parser.add_argument('--sync_mode', action='store_true')
...
@@ -345,6 +351,7 @@ class TestDistBase(unittest.TestCase):
         self._lr = 0.001
         self._use_dgc = False
         self._dygraph = False
+        self._nccl_comm_num = 1
         self._setup_config()
         self._after_setup_config()
...
@@ -590,6 +597,11 @@ class TestDistBase(unittest.TestCase):
         if self._use_dgc:
             tr0_cmd += " --use_dgc"
             tr1_cmd += " --use_dgc"
+
+        if self._nccl_comm_num > 1:
+            tr0_cmd += " --nccl_comm_num {}".format(self._nccl_comm_num)
+            tr1_cmd += " --nccl_comm_num {}".format(self._nccl_comm_num)
+
         if self._mp_mode:
             env0 = {"FLAGS_selected_gpus": "0"}
             env1 = {"FLAGS_selected_gpus": "1"}
...
python/paddle/fluid/tests/unittests/test_dist_mnist.py
@@ -39,6 +39,20 @@ class TestDistMnistNCCL2(TestDistBase):
             self.check_with_place("dist_mnist.py", delta=1e-5)


+class TestDistMnistNCCL2MultiNCCLComm(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = True
+        self._use_reduce = False
+        self._use_reader_alloc = False
+        self._nccl2_mode = True
+        self._nccl_comm_num = 3
+
+    def test_dist_train(self):
+        import paddle.fluid as fluid
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place("dist_mnist.py", delta=1e-5)
+
+
 class TestDistMnistNCCL2DGC(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
...
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -165,6 +165,15 @@ class DistributeTranspilerConfig(object):
     runtime_split_send_recv = False
     sync_mode = True

+    nccl_comm_num = 1
+    # The picture here illustrates the principle:
+    # https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
+    use_hierarchical_allreduce = False
+    # NCCL ranks in a node when using hierarchical allreduce; in most cases it is set to the number of GPU cards.
+    hierarchical_allreduce_inter_nranks = 0
+    # NCCL ranks between nodes when using hierarchical allreduce; it is set to the number of nodes.
+    hierarchical_allreduce_exter_nranks = 0
+

 class DistributeTranspiler(object):
     """
...
@@ -261,14 +270,36 @@ class DistributeTranspiler(object):
             nccl_id_var = startup_program.global_block().create_var(
                 name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+
+            for i in range(1, self.config.nccl_comm_num):
+                startup_program.global_block().create_var(
+                    name="NCCLID_{}".format(i),
+                    persistable=True,
+                    type=core.VarDesc.VarType.RAW)
+
+            if self.config.use_hierarchical_allreduce:
+                startup_program.global_block().create_var(
+                    name="Hierarchical_inter_NCCLID",
+                    persistable=True,
+                    type=core.VarDesc.VarType.RAW)
+                for i in range(0, self.config.nccl_comm_num):
+                    startup_program.global_block().create_var(
+                        name="Hierarchical_exter_NCCLID_{}".format(i),
+                        persistable=True,
+                        type=core.VarDesc.VarType.RAW)
+
             startup_program.global_block().append_op(
                 type="gen_nccl_id",
                 inputs={},
                 outputs={"NCCLID": nccl_id_var},
                 attrs={
                     "endpoint": current_endpoint,
-                    "endpoint_list": worker_endpoints,
-                    "trainer_id": trainer_id
+                    "trainers": trainers.split(","),
+                    "trainer_id": trainer_id,
+                    "nccl_comm_num": self.config.nccl_comm_num,
+                    "use_hierarchical_allreduce":
+                    self.config.use_hierarchical_allreduce,
+                    "hierarchical_allreduce_inter_nranks":
+                    self.config.hierarchical_allreduce_inter_nranks
                 })
             return nccl_id_var
         else:
...
@@ -350,6 +381,12 @@ class DistributeTranspiler(object):
         if self.config.mode == "nccl2":
             assert (isinstance(trainers, str))
             self.origin_program._trainers_endpoints = trainers.split(",")
+            self.origin_program._nccl_comm_num = self.config.nccl_comm_num
+            self.origin_program._use_hierarchical_allreduce = self.config.use_hierarchical_allreduce
+            self.origin_program._hierarchical_allreduce_inter_nranks = \
+                int(self.config.hierarchical_allreduce_inter_nranks)
+            self.origin_program._hierarchical_allreduce_exter_nranks = \
+                int(self.config.hierarchical_allreduce_exter_nranks)
             self._transpile_nccl2(
                 trainer_id,
                 trainers,
...
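Putting the transpiler-side pieces together, a sketch of an NCCL2-mode setup for, say, 4 nodes with 8 GPUs each (endpoints abbreviated; assumes a build containing this commit, and mirrors how test_dist_base.py drives the transpiler):

    import paddle.fluid as fluid

    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    config.nccl_comm_num = 2                        # two independent flat rings
    config.use_hierarchical_allreduce = True
    config.hierarchical_allreduce_inter_nranks = 8  # GPU cards per node
    config.hierarchical_allreduce_exter_nranks = 4  # number of nodes

    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id=0,
        trainers="192.168.0.1:6170,192.168.0.2:6170",  # full endpoint list in practice
        current_endpoint="192.168.0.1:6170",
        startup_program=fluid.default_startup_program())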