Unverified commit 07741593 authored by kuizhiqing, committed by GitHub

new group (#31682)

* new group

* ci compatible fix

* assert nccl
Parent dbeb3ea4
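For orientation, a minimal usage sketch of the group API introduced by this commit, assembled from the docstring examples in the diff below; the rank list [2, 4, 6] and the 10x1000 tensor are illustrative:

import numpy as np
import paddle

paddle.distributed.init_parallel_env()
tindata = paddle.to_tensor(np.random.random([10, 1000]).astype('float32'))
# create a sub-group from a list of global ranks
gp = paddle.distributed.new_group([2, 4, 6])
# collectives now accept a Group instance instead of a raw ring id
paddle.distributed.all_reduce(tindata, group=gp, use_calc_stream=False)
# synchronize the communication stream bound to that group
paddle.distributed.wait(tindata, group=gp, use_calc_stream=False)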
......@@ -19,12 +19,11 @@
#include <utility>
#include <vector>
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/bkcl_helper.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/string/string_helper.h"
......@@ -77,7 +76,7 @@ void BKCLParallelContext::Init() {
bkcl_ids.resize(strategy_.nrings_);
if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
// generate the unique bkclid on the root worker
for (size_t i = 0; i < bkcl_ids.size(); ++i) {
auto ret = bkcl_get_unique_id(&bkcl_ids[i]);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
......@@ -99,6 +98,28 @@ void BKCLParallelContext::Init() {
}
}
void BKCLParallelContext::InitWithRingID(int ring_id) {
std::vector<BKCLUniqueId> bkcl_ids;
bkcl_ids.resize(1);
if (strategy_.local_rank_ == 0) {
// generate the unique bkclid on the root worker
auto ret = bkcl_get_unique_id(&bkcl_ids[0]);
PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret,
platform::errors::PreconditionNotMet(
"BKCL get unique id failed [%d]", ret));
}
BcastBKCLId(bkcl_ids, 0);
int xpu_id = BOOST_GET_CONST(platform::XPUPlace, place_).device;
VLOG(0) << "init BKCL context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
<< " ring id: " << ring_id;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform::BKCLCommContext::Instance().CreateBKCLComm(
&bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
}
void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id, bool use_calc_stream) {
......
......@@ -36,6 +36,8 @@ class BKCLParallelContext : public ParallelContext {
void Init() override;
void InitWithRingID(int ring_id) override;
void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;
......
......@@ -79,6 +79,30 @@ void NCCLParallelContext::Init() {
}
}
void NCCLParallelContext::InitWithRingID(int ring_id) {
std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);
if (strategy_.local_rank_ == 0) {
// generate the unique ncclid on the root worker
platform::dynload::ncclGetUniqueId(&nccl_ids[0]);
}
BcastNCCLId(nccl_ids, 0);
int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
VLOG(0) << "init nccl context nranks: " << strategy_.nranks_
<< " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
<< " ring id: " << ring_id;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform::NCCLCommContext::Instance().CreateNCCLComm(
&nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);
compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
}
void NCCLParallelContext::AllReduceByStream(const framework::Variable &src,
framework::Variable *dst,
int ring_id, bool use_calc_stream) {
......
......@@ -53,6 +53,8 @@ class NCCLParallelContext : public ParallelContext {
void Init() override;
void InitWithRingID(int ring_id) override;
void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) override;
......
......@@ -50,6 +50,8 @@ class ParallelContext {
virtual void Init() = 0;
virtual void InitWithRingID(int ring_id) = 0;
virtual void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) = 0;
......
......@@ -15,40 +15,20 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class CSyncCalcStreamOp : public framework::OperatorBase {
class CSyncCalcStreamOp : public framework::OperatorWithKernel {
public:
CSyncCalcStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"Sync stream op can run on gpu place only for now."));
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(dev_ctx->stream()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream()));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};
......@@ -65,10 +45,36 @@ Call calculation stream synchronization.
}
};
template <typename T>
class CSyncCalcStreamCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
auto place = ctx.GetPlace();
auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(dev_ctx->stream()));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream()));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_sync_calc_stream, ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamCudaKernel<float>);
......@@ -14,45 +14,25 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CSyncCommStreamOp : public framework::OperatorBase {
class CSyncCommStreamOp : public framework::OperatorWithKernel {
public:
CSyncCommStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"Sync stream op can run on gpu place only for now."));
using framework::OperatorWithKernel::OperatorWithKernel;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");
auto stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};
......@@ -72,10 +52,38 @@ Call communication stream synchronization.
}
};
template <typename T>
class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto place = ctx.GetPlace();
int ring_id = ctx.Attr<int>("ring_id");
auto stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_sync_comm_stream, ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream,
ops::CSyncCommStreamCudaKernel<float>);
......@@ -1578,7 +1578,10 @@ void BindImperative(py::module *m_ptr) {
m, "NCCLParallelContext")
.def(py::init<const imperative::ParallelStrategy &,
const platform::CUDAPlace &>())
.def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
.def("init", [](imperative::NCCLParallelContext &self) { self.Init(); })
.def("init_with_ring_id",
&imperative::NCCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
......@@ -1587,7 +1590,10 @@ void BindImperative(py::module *m_ptr) {
m, "BKCLParallelContext")
.def(py::init<const imperative::ParallelStrategy &,
const platform::XPUPlace &>())
.def("init", [](imperative::BKCLParallelContext &self) { self.Init(); });
.def("init", [](imperative::BKCLParallelContext &self) { self.Init(); })
.def("init_with_ring_id",
&imperative::BKCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif
}
......
......@@ -119,6 +119,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"fill_constant", {"Out"}},
{"matmul", {"Out"}},
{"c_broadcast", {"Out"}},
{"c_sync_calc_stream", {"Out"}},
{"c_sync_comm_stream", {"Out"}},
{"c_allreduce_sum", {"Out"}},
{"c_allreduce_max", {"Out"}},
{"c_allreduce_min", {"Out"}},
......
......@@ -26,6 +26,9 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
__all__ = [
'wait',
'new_group',
'get_group',
'broadcast',
'all_reduce',
'reduce',
......@@ -75,30 +78,225 @@ class ReduceOp:
PROD = 3
class _Group():
"""The abstract representation of group."""
class Group():
"""
The abstract representation of group.
"""
def __init__(self, rank, rank_num):
def __init__(self, rank, rank_num, id=0, ranks=[]):
self.rank = rank
self.nranks = rank_num
self.id = id
self.ranks = ranks
def is_member(self):
if self.rank < 0:
return False
if self.nranks < 2:
return False
return True
def get_group_rank(self, rank):
if self.id == 0:
return rank
if self.is_member() and rank in self.ranks:
return self.ranks.index(rank)
else:
return -1
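A small illustration of the global-to-group rank mapping implemented by get_group_rank above; the ranks and group id are hypothetical:

# hypothetical sub-group of global ranks [2, 4, 6], viewed from global rank 4
gp = Group(rank=1, rank_num=3, id=5, ranks=[2, 4, 6])
assert gp.is_member()
assert gp.get_group_rank(4) == 1   # position of global rank 4 inside the group
assert gp.get_group_rank(3) == -1  # ranks outside the group map to -1
# for the global group (id == 0) the global rank is returned unchanged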
_global_env = None
def _get_global_env():
global _global_env
if not _global_env:
_global_env = paddle.distributed.ParallelEnv()
return _global_env
# group map: the map of all groups, keyed by group id; 0 is the global group
# Dict[int, Group]
_group_map = {}
def _get_group_map():
global _group_map
if not _group_map:
genv = _get_global_env()
_group_map[0] = Group(genv.rank, genv.world_size, 0)
return _group_map
def _get_global_group():
return _get_group_map()[0]
def _new_ring_id():
return len(_get_group_map()) + max(_get_global_env().nrings, 9)
def get_group(id=0):
"""
Get group instance by group id.
Args:
id (int): the group id
Returns:
Group: the group instance.
Examples:
.. code-block:: python
...
gid = paddle.distributed.new_group([2,4,6])
paddle.distributed.get_group(gid.id)
"""
gm = _get_group_map()
return gm[id] if id in gm else None
def new_group(ranks=None, backend=None):
"""
Creates a new distributed communication group.
Args:
ranks (list): The global ranks of the group members; the list is sorted internally.
backend (str): The backend used to create the group; only nccl is supported now.
Returns:
Group: The group instance. Never returns None.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.distributed.init_parallel_env()
tindata = np.random.random([10, 1000]).astype('float32')
tindata = paddle.to_tensor(tindata)
gid = paddle.distributed.new_group([2,4,6])
paddle.distributed.all_reduce(tindata, group=gid, use_calc_stream=False)
"""
if not backend:
backend = 'nccl'
assert backend == 'nccl', ("backend other than nccl is not supported yet")
genv = _get_global_env()
global_rank = genv.rank
ring_id = _new_ring_id()
global _group_map
if global_rank not in ranks:
gp = Group(-1, -1, ring_id, ranks)
_group_map[ring_id] = gp
return gp
ranks = sorted(ranks)
group_rank = ranks.index(global_rank)
group_size = len(ranks)
gp = Group(group_rank, group_size, ring_id, ranks)
_group_map[ring_id] = gp
if group_size < 2:
return gp
strategy = core.ParallelStrategy()
strategy.nranks = group_size
strategy.local_rank = group_rank
strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks]
strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id)
else:
assert False
return gp
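A note on the non-member path above: a process whose global rank is not in ranks still records the group in _group_map but is given rank -1, so subsequent collectives called with that group return early via is_member(). A hedged sketch, assuming a multi-GPU launch where only ranks 0 and 1 join:

import paddle
paddle.distributed.init_parallel_env()
t = paddle.to_tensor([1.0, 2.0, 3.0])
gp = paddle.distributed.new_group([0, 1])
# on ranks outside [0, 1], gp.is_member() is False and the collective below
# returns immediately; on ranks 0 and 1 it runs on the ring identified by gp.id
paddle.distributed.all_reduce(t, group=gp, use_calc_stream=False)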
def wait(tensor, group=None, use_calc_stream=True):
"""
Synchronize the stream for the given group.
Args:
tensor (Tensor): The Tensor used before sync.
group (Group): The group instance on which to perform the sync.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False),
defaults to True.
Returns:
None.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.distributed.init_parallel_env()
tindata = np.random.random([10, 1000]).astype('float32')
tindata = paddle.to_tensor(tindata)
paddle.distributed.all_reduce(tindata, use_calc_stream=True)
paddle.distributed.wait(tindata)
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
if use_calc_stream:
_sync_calc_stream(tensor)
else:
_sync_comm_stream(tensor, ring_id)
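For reference, both branches of wait are exercised by the new dygraph test further down; a minimal sketch, assuming a tensor tindata and a group gp created with new_group:

# flush the calculation stream of the current device
paddle.distributed.wait(tindata, use_calc_stream=True)
# flush the communication stream bound to the group's ring id
paddle.distributed.wait(tindata, group=gp, use_calc_stream=False)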
def _sync_calc_stream(tensor):
if in_dygraph_mode():
return core.ops.c_sync_calc_stream(tensor, tensor)
op_type = 'c_sync_calc_stream'
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]}, )
# NOTE(chenweihang): Lazily initialized global group information
# If we initialize _default_group when import module, it will
# not update when we use spawn to run multi-process training
_default_group = None
def _sync_comm_stream(tensor, ring_id=0):
def _get_global_default_group():
global _default_group
if _default_group is None:
_default_group = _Group(
int(os.getenv("PADDLE_TRAINER_ID", "0")),
int(os.getenv("PADDLE_TRAINERS_NUM", "1")))
return _default_group
if in_dygraph_mode():
return core.ops.c_sync_comm_stream([tensor], [tensor], 'ring_id',
ring_id)
op_type = 'c_sync_comm_stream'
def broadcast(tensor, src, group=0):
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={'ring_id': ring_id}, )
def broadcast(tensor, src, group=None, use_calc_stream=True):
"""
Broadcast a tensor from the source to all others.
......@@ -107,7 +305,9 @@ def broadcast(tensor, src, group=0):
tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type
should be float16, float32, float64, int32 or int64.
src (int): The source rank.
group (int): The process group to work on. It is Optional.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False),
defaults to True.
Returns:
None.
......@@ -130,17 +330,25 @@ def broadcast(tensor, src, group=0):
out = data.numpy()
# [[1, 2, 3], [1, 2, 3]]
"""
if group is not None and not group.is_member():
return
if not isinstance(src, int):
raise ValueError("src should be int.")
ring_id = 0 if group is None else group.id
gsrc = src if group is None else group.get_group_rank(src)
if in_dygraph_mode():
return core.ops.c_broadcast(tensor, tensor, 'root', src,
'use_calc_stream', True, 'ring_id', group)
return core.ops.c_broadcast(tensor, tensor, 'root', gsrc,
'use_calc_stream', use_calc_stream,
'ring_id', ring_id)
op_type = 'c_broadcast'
check_variable_and_dtype(
tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
'broadcast')
if not isinstance(src, int) or not isinstance(group, int):
raise ValueError("Both the type of 'src' and 'group' for broadcast "
"should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
......@@ -148,13 +356,13 @@ def broadcast(tensor, src, group=0):
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'root': src,
'use_calc_stream': True,
'ring_id': group,
'root': gsrc,
'use_calc_stream': use_calc_stream,
'ring_id': ring_id,
})
def all_reduce(tensor, op=ReduceOp.SUM, group=0):
def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
"""
Reduce a tensor over all ranks so that all get the result.
......@@ -163,7 +371,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0):
tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
should be float16, float32, float64, int32 or int64.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
group (int): Optional. The process group to work on.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False),
defaults to True.
Returns:
None.
......@@ -187,19 +397,25 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0):
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
if in_dygraph_mode():
if op == ReduceOp.SUM:
return core.ops.c_allreduce_sum(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
use_calc_stream, 'ring_id', ring_id)
elif op == ReduceOp.MAX:
return core.ops.c_allreduce_max(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
use_calc_stream, 'ring_id', ring_id)
elif op == ReduceOp.MIN:
return core.ops.c_allreduce_min(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
use_calc_stream, 'ring_id', ring_id)
elif op == ReduceOp.PROD:
return core.ops.c_allreduce_prod(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group)
use_calc_stream, 'ring_id',
ring_id)
else:
raise ValueError("Unknown parameter: {}.".format(op))
......@@ -217,18 +433,18 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0):
op_type = 'c_allreduce_min'
elif op == ReduceOp.PROD:
op_type = 'c_allreduce_prod'
if not isinstance(group, int):
raise ValueError("The type of 'group' for all_reduce should be int.")
if not isinstance(ring_id, int):
raise ValueError("The type of 'ring_id' for all_reduce should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={'ring_id': group,
'use_calc_stream': True})
attrs={'ring_id': ring_id,
'use_calc_stream': use_calc_stream})
def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
"""
Reduce a tensor to the destination from all others.
......@@ -238,7 +454,9 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
should be float16, float32, float64, int32 or int64.
dst (int): The destination rank id.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used.
group (int): The id of the process group to work on.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False),
defaults to True.
Returns:
None.
......@@ -261,20 +479,32 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
"""
if group is not None and not group.is_member():
return
if not isinstance(dst, int):
raise ValueError("dst should be int.")
ring_id = 0 if group is None else group.id
gdst = dst if group is None else group.get_group_rank(dst)
if in_dygraph_mode():
if op == ReduceOp.SUM:
return core.ops.c_reduce_sum(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id', dst)
use_calc_stream, 'ring_id', ring_id,
'root_id', gdst)
elif op == ReduceOp.MAX:
return core.ops.c_reduce_max(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id', dst)
use_calc_stream, 'ring_id', ring_id,
'root_id', gdst)
elif op == ReduceOp.MIN:
return core.ops.c_reduce_min(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id', dst)
use_calc_stream, 'ring_id', ring_id,
'root_id', gdst)
elif op == ReduceOp.PROD:
return core.ops.c_reduce_prod(tensor, tensor, 'use_calc_stream',
True, 'ring_id', group, 'root_id',
dst)
use_calc_stream, 'ring_id', ring_id,
'root_id', gdst)
else:
raise ValueError("Unknown parameter: {}.".format(op))
......@@ -295,22 +525,19 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
elif op == ReduceOp.PROD:
op_type = 'c_reduce_prod'
if not isinstance(dst, int) or not isinstance(group, int):
raise ValueError("Both the type of 'dst' and 'group' for reduce "
"should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={
'ring_id': group,
'use_calc_stream': True,
'root_id': dst,
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
'root_id': gdst,
})
def all_gather(tensor_list, tensor, group=0):
def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
"""
Gather tensors from all participators and all get the result.
......@@ -320,7 +547,9 @@ def all_gather(tensor_list, tensor, group=0):
should be float16, float32, float64, int32 or int64.
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64.
group (int): The id of the process group to work on.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False),
defaults to True.
Returns:
None.
......@@ -348,13 +577,19 @@ def all_gather(tensor_list, tensor, group=0):
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_gather(tensor_list, data2)
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
nranks = _get_global_group().nranks if group is None else group.nranks
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
_default_group = _get_global_default_group()
if in_dygraph_mode():
core.ops.c_allgather(tensor, out, 'use_calc_stream', True, 'ring_id',
group, 'nranks', _default_group.nranks)
core.ops.c_allgather(tensor, out, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)
else:
if not isinstance(tensor_list, list):
raise ValueError("The type of 'tensor_list' for all_gather "
......@@ -367,23 +602,20 @@ def all_gather(tensor_list, tensor, group=0):
check_variable_and_dtype(
tensor, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather')
if not isinstance(group, int):
raise ValueError("The type of 'group' for all_gather "
"should be int.")
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'ring_id': group,
'use_calc_stream': True,
'nranks': _default_group.nranks
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
'nranks': nranks
})
tensor_list.extend(paddle.split(out, _default_group.nranks, 0))
tensor_list.extend(paddle.split(out, nranks, 0))
def scatter(tensor, tensor_list=None, src=0, group=0):
def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True):
"""
Scatter a tensor to all participators.
......@@ -394,7 +626,9 @@ def scatter(tensor, tensor_list=None, src=0, group=0):
tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
src (int): The source rank id.
group (int): The id of the process group to work on.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream (True) or the communication stream (False),
defaults to True.
Returns:
None.
......@@ -422,45 +656,51 @@ def scatter(tensor, tensor_list=None, src=0, group=0):
paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
out = data1.numpy()
"""
if group is not None and not group.is_member():
return
if not isinstance(src, int):
raise ValueError("src should be int.")
ring_id = 0 if group is None else group.id
gsrc = src if group is None else group.get_group_rank(src)
rank = _get_global_group().rank if group is None else group.rank
nranks = _get_global_group().nranks if group is None else group.nranks
op_type = 'c_scatter'
_default_group = _get_global_default_group()
rank = _default_group.rank
nranks = _default_group.nranks
if rank != src:
if rank != gsrc:
tensor_list = []
for _ in range(nranks):
tensor_list.append(tensor)
temp = paddle.concat(tensor_list, axis=0)
if in_dygraph_mode():
return core.ops.c_scatter(temp, tensor, 'use_calc_stream', True,
'ring_id', group, 'nranks',
_default_group.nranks, 'root', src)
return core.ops.c_scatter(temp, tensor, 'use_calc_stream',
use_calc_stream, 'ring_id', ring_id, 'nranks',
nranks, 'root', gsrc)
check_variable_and_dtype(
tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
'scatter')
if not isinstance(group, int) or not isinstance(src, int):
raise ValueError("Both the type of 'src' and 'group' for scatter "
"should be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [temp]},
outputs={'Out': [tensor]},
attrs={
'ring_id': group,
'root': src,
'use_calc_stream': True,
'ring_id': ring_id,
'root': gsrc,
'use_calc_stream': use_calc_stream,
'nranks': nranks,
})
def barrier(group=0):
def barrier(group=None):
"""
Barrier among all participators in the group.
Args:
group (int): The id of the process group to work on.
group (Group): The group instance returned by new_group, or None for the global default group.
Returns:
None.
......@@ -475,18 +715,23 @@ def barrier(group=0):
init_parallel_env()
paddle.distributed.barrier()
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
op_type = 'barrier'
temp = fill_constant([1], dtype="int32", value="1")
if in_dygraph_mode():
return core.ops.barrier(temp, temp, 'ring_id', group)
if not isinstance(group, int):
return core.ops.barrier(temp, temp, 'ring_id', ring_id)
if not isinstance(ring_id, int):
raise ValueError("The type of 'group' for barrier must be int.")
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [temp]},
outputs={'Out': [temp]},
attrs={'ring_id': group})
attrs={'ring_id': ring_id})
def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr,
......@@ -515,10 +760,10 @@ def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr,
if gather_out:
if axis == 0:
paddle.distributed.all_reduce(linear_out, group=0)
paddle.distributed.all_reduce(linear_out)
else:
output = []
paddle.distributed.all_gather(output, linear_out, group=0)
paddle.distributed.all_gather(output, linear_out)
linear_out = paddle.concat(output, axis=len(linear_out.shape) - 1)
return linear_out
......@@ -559,7 +804,7 @@ def _parallel_embedding(x, per_part_embeddings, origin_size, param_attr,
main_block = paddle.static.default_main_program().global_block()
startup_block.vars[embedding.weight.name].is_distributed = True
main_block.vars[embedding.weight.name].is_distributed = True
paddle.distributed.all_reduce(emb_out, group=0)
paddle.distributed.all_reduce(emb_out, group=None)
return emb_out
......@@ -584,7 +829,7 @@ def split(x,
With parallel embedding, the weight is split into num_partitions partitions, each
of which is a matrix with (N/num_partitions + 1) rows and M columns, where the last
row serves as the padding idx.
Suppose we split the NxM weight into two partitions on device_0 and device_1
respectively. Then, on each device, the final weight has (N/2 + 1) rows with the
index range from 0 to N/2. On device_0, all values in the input within [0, N/2 -1]
......
......@@ -82,6 +82,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
LIST(REMOVE_ITEM TEST_OPS test_new_group_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
......@@ -177,6 +178,7 @@ endif()
if ((NOT WITH_NCCL) AND (NOT WITH_RCCL))
list(REMOVE_ITEM TEST_OPS test_imperative_group)
LIST(REMOVE_ITEM TEST_OPS test_new_group_api)
endif()
if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
......@@ -518,6 +520,7 @@ if(WITH_DISTRIBUTE)
if(WITH_GPU OR WITH_ROCM)
bash_test_modules(test_c_comm_init_op START_BASH test_c_comm_init_op.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
py_test_modules(test_launch_coverage MODULES test_launch_coverage)
bash_test_modules(test_new_group START_BASH test_new_group.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
endif()
bash_test_modules(test_fleetrun START_BASH test_fleetrun.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
......@@ -831,6 +834,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_new_group_api PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE)
set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120)
endif()
......@@ -853,6 +857,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
test_collective_barrier_api
test_collective_reduce_api
test_collective_allreduce_api
test_new_group_api
test_collective_broadcast_api
test_collective_allgather_api
PROPERTIES LABELS "RUN_TYPE=DIST")
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveAllreduceNewGroupAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
gp = paddle.distributed.new_group([0, 1])
paddle.distributed.all_reduce(
tindata, group=gp, use_calc_stream=False)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveAllreduceNewGroupAPI, "allreduce")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import paddle
class TestNewGroupAPI(object):
def __init__(self):
paddle.distributed.init_parallel_env()
d1 = np.array([1, 2, 3])
d2 = np.array([2, 3, 4])
self.tensor1 = paddle.to_tensor(d1)
self.tensor2 = paddle.to_tensor(d2)
def test_all(self):
gp = paddle.distributed.new_group([0, 1])
print("test new group api ok")
tmp = np.array([0, 0, 0])
result = paddle.to_tensor(tmp)
paddle.distributed.scatter(
result, [self.tensor2, self.tensor1],
src=0,
group=gp,
use_calc_stream=True)
if gp.rank == 0:
assert np.array_equal(result, self.tensor2)
elif gp.rank == 1:
assert np.array_equal(result, self.tensor1)
print("test scatter api ok")
paddle.distributed.broadcast(
result, src=1, group=gp, use_calc_stream=True)
assert np.array_equal(result, self.tensor1)
print("test broadcast api ok")
paddle.distributed.reduce(result, dst=0, group=gp, use_calc_stream=True)
if gp.rank == 0:
assert np.array_equal(result,
paddle.add(self.tensor1, self.tensor1))
elif gp.rank == 1:
assert np.array_equal(result, self.tensor1)
print("test reduce api ok")
paddle.distributed.all_reduce(result, use_calc_stream=True)
assert np.array_equal(
result,
paddle.add(paddle.add(self.tensor1, self.tensor1), self.tensor1))
print("test all_reduce api ok")
paddle.distributed.wait(result, gp, use_calc_stream=True)
paddle.distributed.wait(result, gp, use_calc_stream=False)
print("test wait api ok")
result = []
paddle.distributed.all_gather(
result, self.tensor1, group=gp, use_calc_stream=True)
assert np.array_equal(result[0], self.tensor1)
assert np.array_equal(result[1], self.tensor1)
print("test all_gather api ok")
paddle.distributed.barrier(group=gp)
print("test barrier api ok")
return
if __name__ == "__main__":
gpt = TestNewGroupAPI()
gpt.test_all()
#!/bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 new_group.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestCollectiveAllreduceAPI(TestDistBase):
def _setup_config(self):
pass
def test_allreduce_nccl(self):
self.check_with_place("collective_allreduce_new_group_api.py",
"allreduce", "nccl")
if __name__ == '__main__':
unittest.main()