From 07741593fadbda741f8b0f9935add64901b93f3b Mon Sep 17 00:00:00 2001 From: kuizhiqing Date: Thu, 1 Apr 2021 11:19:43 +0800 Subject: [PATCH] new group (#31682) * new group * ci compatible fix * assert nccl --- paddle/fluid/imperative/bkcl_context.cc | 29 +- paddle/fluid/imperative/bkcl_context.h | 2 + paddle/fluid/imperative/nccl_context.cc | 24 + paddle/fluid/imperative/nccl_context.h | 2 + paddle/fluid/imperative/parallel_context.h | 2 + .../collective/c_sync_calc_stream_op.cc | 70 +-- .../collective/c_sync_comm_stream_op.cc | 74 ++-- paddle/fluid/pybind/imperative.cc | 10 +- paddle/fluid/pybind/op_function_generator.cc | 2 + python/paddle/distributed/collective.py | 411 ++++++++++++++---- .../fluid/tests/unittests/CMakeLists.txt | 5 + .../collective_allreduce_new_group_api.py | 56 +++ .../paddle/fluid/tests/unittests/new_group.py | 83 ++++ .../fluid/tests/unittests/test_new_group.sh | 19 + .../tests/unittests/test_new_group_api.py | 35 ++ 15 files changed, 670 insertions(+), 154 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/collective_allreduce_new_group_api.py create mode 100644 python/paddle/fluid/tests/unittests/new_group.py create mode 100755 python/paddle/fluid/tests/unittests/test_new_group.sh create mode 100644 python/paddle/fluid/tests/unittests/test_new_group_api.py diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc index 873068a0d31..886179feb19 100644 --- a/paddle/fluid/imperative/bkcl_context.cc +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -19,12 +19,11 @@ #include #include +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/bkcl_helper.h" #include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/gen_comm_id_helper.h" - -#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/split.h" #include "paddle/fluid/string/string_helper.h" @@ -77,7 +76,7 @@ void BKCLParallelContext::Init() { bkcl_ids.resize(strategy_.nrings_); if (strategy_.local_rank_ == 0) { - // generate the unique ncclid on the root worker + // generate the unique bkclid on the root worker for (size_t i = 0; i < bkcl_ids.size(); ++i) { auto ret = bkcl_get_unique_id(&bkcl_ids[i]); PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret, @@ -99,6 +98,28 @@ void BKCLParallelContext::Init() { } } +void BKCLParallelContext::InitWithRingID(int ring_id) { + std::vector bkcl_ids; + bkcl_ids.resize(1); + + if (strategy_.local_rank_ == 0) { + // generate the unique bkclid on the root worker + auto ret = bkcl_get_unique_id(&bkcl_ids[0]); + PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret, + platform::errors::PreconditionNotMet( + "BKCL get unique id failed [%d]", ret)); + } + BcastBKCLId(bkcl_ids, 0); + + int xpu_id = BOOST_GET_CONST(platform::XPUPlace, place_).device; + VLOG(0) << "init BKCL context nranks: " << strategy_.nranks_ + << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id + << " ring id: " << ring_id; + // it will assign bkcl_comm in XPUDeviceContext within ring_id + platform::BKCLCommContext::Instance().CreateBKCLComm( + &bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id); +} + void BKCLParallelContext::AllReduceByStream(const framework::Variable &src, framework::Variable *dst, int ring_id, bool use_calc_stream) { diff --git a/paddle/fluid/imperative/bkcl_context.h b/paddle/fluid/imperative/bkcl_context.h index 
d7d917f2008..86e4d97b3c7 100644 --- a/paddle/fluid/imperative/bkcl_context.h +++ b/paddle/fluid/imperative/bkcl_context.h @@ -36,6 +36,8 @@ class BKCLParallelContext : public ParallelContext { void Init() override; + void InitWithRingID(int ring_id) override; + void AllReduceByStream(const framework::Variable& src, framework::Variable* dst, int ring_id, bool use_calc_stream) override; diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index eb0135d15e0..7e7c4ceea0b 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -79,6 +79,30 @@ void NCCLParallelContext::Init() { } } +void NCCLParallelContext::InitWithRingID(int ring_id) { + std::vector nccl_ids; + nccl_ids.resize(1); + + if (strategy_.local_rank_ == 0) { + // generate the unique ncclid on the root worker + platform::dynload::ncclGetUniqueId(&nccl_ids[0]); + } + BcastNCCLId(nccl_ids, 0); + + int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device; + VLOG(0) << "init nccl context nranks: " << strategy_.nranks_ + << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id + << " ring id: " << ring_id; + // it will assign nccl_comm in CUDADeviceContext within ring_id + platform::NCCLCommContext::Instance().CreateNCCLComm( + &nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id); + + compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New( + BOOST_GET_CONST(platform::CUDAPlace, place_).device)); + comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New( + BOOST_GET_CONST(platform::CUDAPlace, place_).device)); +} + void NCCLParallelContext::AllReduceByStream(const framework::Variable &src, framework::Variable *dst, int ring_id, bool use_calc_stream) { diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h index 51e5743aebd..292ef1661c3 100644 --- a/paddle/fluid/imperative/nccl_context.h +++ b/paddle/fluid/imperative/nccl_context.h @@ -53,6 +53,8 @@ class NCCLParallelContext : public ParallelContext { void Init() override; + void InitWithRingID(int ring_id) override; + void AllReduceByStream(const framework::Variable& src, framework::Variable* dst, int ring_id, bool use_calc_stream) override; diff --git a/paddle/fluid/imperative/parallel_context.h b/paddle/fluid/imperative/parallel_context.h index ef0a9604092..9a76311f2ed 100644 --- a/paddle/fluid/imperative/parallel_context.h +++ b/paddle/fluid/imperative/parallel_context.h @@ -50,6 +50,8 @@ class ParallelContext { virtual void Init() = 0; + virtual void InitWithRingID(int ring_id) = 0; + virtual void AllReduceByStream(const framework::Variable& src, framework::Variable* dst, int ring_id, bool use_calc_stream) = 0; diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc index c4abe284d72..700d1173e2f 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -15,40 +15,20 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -} // namespace paddle - namespace paddle { namespace operators { -class CSyncCalcStreamOp : public framework::OperatorBase { +class CSyncCalcStreamOp : public framework::OperatorWithKernel { public: - CSyncCalcStreamOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - PADDLE_ENFORCE_EQ(is_gpu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync stream op can run on gpu place only for now.")); -#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32) - auto dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); -#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(dev_ctx->stream())); -#else - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream())); -#endif -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with GPU.")); -#endif + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); } }; @@ -65,10 +45,36 @@ Call calculation stream synchronization. } }; +template +class CSyncCalcStreamCudaKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32) + + auto place = ctx.GetPlace(); + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(dev_ctx->stream())); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream())); +#endif + +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU.")); +#endif + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(c_sync_calc_stream, ops::CSyncCalcStreamOp, - ops::CSyncCalcStreamOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp, + ops::CSyncCalcStreamOpMaker); + +REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, + ops::CSyncCalcStreamCudaKernel); diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index adf27069f52..95b9cd040fe 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -14,45 +14,25 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -} // namespace paddle + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" #endif namespace paddle { namespace operators { -class CSyncCommStreamOp : public framework::OperatorBase { +class CSyncCommStreamOp : public framework::OperatorWithKernel { public: - CSyncCommStreamOp(const std::string& type, - const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void RunImpl(const framework::Scope& scope, - const platform::Place& place) const override { - PADDLE_ENFORCE_EQ(is_gpu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync stream op can run on gpu place only for now.")); + using framework::OperatorWithKernel::OperatorWithKernel; -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - int ring_id = Attr("ring_id"); - auto stream = - platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); -#ifdef PADDLE_WITH_RCCL - PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); -#else - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); -#endif -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with GPU.")); -#endif + void InferShape(framework::InferShapeContext* ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); } }; @@ -72,10 +52,38 @@ Call communication stream synchronization. 
} }; +template +class CSyncCommStreamCudaKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + + auto place = ctx.GetPlace(); + + int ring_id = ctx.Attr("ring_id"); + auto stream = + platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); + +#ifdef PADDLE_WITH_RCCL + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); +#endif + +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU.")); +#endif + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(c_sync_comm_stream, ops::CSyncCommStreamOp, - ops::CSyncCommStreamOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp, + ops::CSyncCommStreamOpMaker); + +REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, + ops::CSyncCommStreamCudaKernel); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 38ba1dc0293..c1c1387a84c 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1578,7 +1578,10 @@ void BindImperative(py::module *m_ptr) { m, "NCCLParallelContext") .def(py::init()) - .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); }); + .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); }) + .def("init_with_ring_id", + &imperative::NCCLParallelContext::InitWithRingID, + py::arg("ring_id")); #endif #if defined(PADDLE_WITH_XPU_BKCL) @@ -1587,7 +1590,10 @@ void BindImperative(py::module *m_ptr) { m, "BKCLParallelContext") .def(py::init()) - .def("init", [](imperative::BKCLParallelContext &self) { self.Init(); }); + .def("init", [](imperative::BKCLParallelContext &self) { self.Init(); }) + .def("init_with_ring_id", + &imperative::BKCLParallelContext::InitWithRingID, + py::arg("ring_id")); #endif } diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 69856fa4fa1..282b0e1d81c 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -119,6 +119,8 @@ std::map> op_passing_outs_map = { {"fill_constant", {"Out"}}, {"matmul", {"Out"}}, {"c_broadcast", {"Out"}}, + {"c_sync_calc_stream", {"Out"}}, + {"c_sync_comm_stream", {"Out"}}, {"c_allreduce_sum", {"Out"}}, {"c_allreduce_max", {"Out"}}, {"c_allreduce_min", {"Out"}}, diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index a6eb896802f..8e5c35995b2 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -26,6 +26,9 @@ import paddle.fluid as fluid import paddle.fluid.core as core __all__ = [ + 'wait', + 'new_group', + 'get_group', 'broadcast', 'all_reduce', 'reduce', @@ -75,30 +78,225 @@ class ReduceOp: PROD = 3 -class _Group(): - """The abstract representation of group.""" +class Group(): + """ + The abstract representation of group. 
+ """ - def __init__(self, rank, rank_num): + def __init__(self, rank, rank_num, id=0, ranks=[]): self.rank = rank self.nranks = rank_num + self.id = id + self.ranks = ranks + + def is_member(self): + if self.rank < 0: + return False + if self.nranks < 2: + return False + return True + + def get_group_rank(self, rank): + if self.id == 0: + return rank + if self.is_member() and rank in self.ranks: + return self.ranks.index(rank) + else: + return -1 + + +_global_env = None + + +def _get_global_env(): + global _global_env + if not _global_env: + _global_env = paddle.distributed.ParallelEnv() + return _global_env + + +# group map : the map of all groups, 0 for GlobalGroup +# Dict[int, Group] +_group_map = {} + + +def _get_group_map(): + global _group_map + if not _group_map: + genv = _get_global_env() + _group_map[0] = Group(genv.rank, genv.world_size, 0) + return _group_map + + +def _get_global_group(): + return _get_group_map()[0] + + +def _new_ring_id(): + return len(_get_group_map()) + max(_get_global_env().nrings, 9) + + +def get_group(id=0): + """ + + Get group instance by group id. + + Args: + id (int): the group id + + Returns: + Group: the group instance. + + Examples: + .. code-block:: python + + ... + gid = paddle.distributed.new_group([2,4,6]) + paddle.distributed.get_group(gid.id) + + """ + + gm = _get_group_map() + return gm[id] if id in gm else None + + +def new_group(ranks=None, backend=None): + """ + + Creates a new distributed communication group. + + Args: + ranks (list): The global ranks of group members; the list will be sorted. + backend (str): The backend used to create group, only nccl is supported now. + + Returns: + Group: The group instance. Never returns None. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + paddle.distributed.init_parallel_env() + tindata = np.random.random([10, 1000]).astype('float32') + tindata = paddle.to_tensor(tindata) + gid = paddle.distributed.new_group([2,4,6]) + paddle.distributed.all_reduce(tindata, group=gid, use_calc_stream=False) + + """ + + if not backend: + backend = 'nccl' + assert backend == 'nccl', ("backend other than nccl is not supported yet") + + genv = _get_global_env() + global_rank = genv.rank + + ring_id = _new_ring_id() + + global _group_map + if global_rank not in ranks: + gp = Group(-1, -1, ring_id, ranks) + _group_map[ring_id] = gp + return gp + + ranks = sorted(ranks) + group_rank = ranks.index(global_rank) + group_size = len(ranks) + gp = Group(group_rank, group_size, ring_id, ranks) + _group_map[ring_id] = gp + + if group_size < 2: + return gp + + strategy = core.ParallelStrategy() + strategy.nranks = group_size + strategy.local_rank = group_rank + strategy.trainer_endpoints = [genv.trainer_endpoints[i] for i in ranks] + strategy.current_endpoint = genv.current_endpoint + strategy.nrings = 1 + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(genv.device_id) + core.NCCLParallelContext(strategy, place).init_with_ring_id(ring_id) + else: + assert False + + return gp + +def wait(tensor, group=None, use_calc_stream=True): + """ + + Wait to sync stream for group. + + Args: + tensor (Tensor): The Tensor used before sync. + group (Group): The Group instance to perform sync. + use_calc_stream (bool): Whether to use calculation stream (True) or communication stream (False), + default to True. + + Returns: + None. + + Examples: + ..
code-block:: python + + + import numpy as np + import paddle + + paddle.distributed.init_parallel_env() + tindata = np.random.random([10, 1000]).astype('float32') + tindata = paddle.to_tensor(tindata) + paddle.distributed.all_reduce(tindata, use_calc_stream=True) + paddle.distributed.wait(tindata) + + """ + + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + + if use_calc_stream: + _sync_calc_stream(tensor) + else: + _sync_comm_stream(tensor, ring_id) + + +def _sync_calc_stream(tensor): + + if in_dygraph_mode(): + return core.ops.c_sync_calc_stream(tensor, tensor) + + op_type = 'c_sync_calc_stream' + + helper = LayerHelper(op_type, **locals()) + helper.append_op( + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [tensor]}, ) -# NOTE(chenweihang): Lazily initialized global group information -# If we initialize _default_group when import module, it will -# not update when we use spawn to run multi-process training -_default_group = None +def _sync_comm_stream(tensor, ring_id=0): -def _get_global_default_group(): - global _default_group - if _default_group is None: - _default_group = _Group( - int(os.getenv("PADDLE_TRAINER_ID", "0")), - int(os.getenv("PADDLE_TRAINERS_NUM", "1"))) - return _default_group + if in_dygraph_mode(): + return core.ops.c_sync_comm_stream([tensor], [tensor], 'ring_id', + ring_id) + op_type = 'c_sync_comm_stream' -def broadcast(tensor, src, group=0): + helper = LayerHelper(op_type, **locals()) + helper.append_op( + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [tensor]}, + attrs={'ring_id': ring_id}, ) + + +def broadcast(tensor, src, group=None, use_calc_stream=True): """ Broadcast a tensor from the source to all others. @@ -107,7 +305,9 @@ def broadcast(tensor, src, group=0): tensor (Tensor): The Tensor to send if current rank is the source, or the tensor to receive otherwise. Its data type should be float16, float32, float64, int32 or int64. src (int): The source rank. - group (int): The process group to work on. It is Optional. + group (Group): The group instance return by new_group or None for global default group. + use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), + default to True. Returns: None. 
@@ -130,17 +330,25 @@ def broadcast(tensor, src, group=0): out = data.numpy() # [[1, 2, 3], [1, 2, 3]] """ + + if group is not None and not group.is_member(): + return + + if not isinstance(src, int): + raise ValueError("src should be int.") + + ring_id = 0 if group is None else group.id + gsrc = src if group is None else group.get_group_rank(src) + if in_dygraph_mode(): - return core.ops.c_broadcast(tensor, tensor, 'root', src, - 'use_calc_stream', True, 'ring_id', group) + return core.ops.c_broadcast(tensor, tensor, 'root', gsrc, + 'use_calc_stream', use_calc_stream, + 'ring_id', ring_id) op_type = 'c_broadcast' check_variable_and_dtype( tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], 'broadcast') - if not isinstance(src, int) or not isinstance(group, int): - raise ValueError("Both the type of 'src' and 'group' for broadcast " - "should be int.") helper = LayerHelper(op_type, **locals()) helper.append_op( @@ -148,13 +356,13 @@ def broadcast(tensor, src, group=0): inputs={'X': [tensor]}, outputs={'Out': [tensor]}, attrs={ - 'root': src, - 'use_calc_stream': True, - 'ring_id': group, + 'root': gsrc, + 'use_calc_stream': use_calc_stream, + 'ring_id': ring_id, }) -def all_reduce(tensor, op=ReduceOp.SUM, group=0): +def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True): """ Reduce a tensor over all ranks so that all get the result. @@ -163,7 +371,9 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0): tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type should be float16, float32, float64, int32 or int64. op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. - group (int): Optional. The process group to work on. + group (Group): The group instance return by new_group or None for global default group. + use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), + default to True. Returns: None. 
@@ -187,19 +397,25 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0): out = data.numpy() # [[5, 7, 9], [5, 7, 9]] """ + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + if in_dygraph_mode(): if op == ReduceOp.SUM: return core.ops.c_allreduce_sum(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group) + use_calc_stream, 'ring_id', ring_id) elif op == ReduceOp.MAX: return core.ops.c_allreduce_max(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group) + use_calc_stream, 'ring_id', ring_id) elif op == ReduceOp.MIN: return core.ops.c_allreduce_min(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group) + use_calc_stream, 'ring_id', ring_id) elif op == ReduceOp.PROD: return core.ops.c_allreduce_prod(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group) + use_calc_stream, 'ring_id', + ring_id) else: raise ValueError("Unknown parameter: {}.".format(op)) @@ -217,18 +433,18 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0): op_type = 'c_allreduce_min' elif op == ReduceOp.PROD: op_type = 'c_allreduce_prod' - if not isinstance(group, int): - raise ValueError("The type of 'group' for all_reduce should be int.") + if not isinstance(ring_id, int): + raise ValueError("The type of 'ring_id' for all_reduce should be int.") helper = LayerHelper(op_type, **locals()) helper.append_op( type=op_type, inputs={'X': [tensor]}, outputs={'Out': [tensor]}, - attrs={'ring_id': group, - 'use_calc_stream': True}) + attrs={'ring_id': ring_id, + 'use_calc_stream': use_calc_stream}) -def reduce(tensor, dst, op=ReduceOp.SUM, group=0): +def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): """ Reduce a tensor to the destination from all others. @@ -238,7 +454,9 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0): should be float16, float32, float64, int32 or int64. dst (int): The destination rank id. op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. - group (int): The id of the process group to work on. + group (Group): The group instance return by new_group or None for global default group. + use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), + default to True. Returns: None. 
@@ -261,20 +479,32 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0): out = data.numpy() # [[5, 7, 9], [5, 7, 9]] """ + if group is not None and not group.is_member(): + return + + if not isinstance(dst, int): + raise ValueError("dst should be int.") + + ring_id = 0 if group is None else group.id + gdst = dst if group is None else group.get_group_rank(dst) + if in_dygraph_mode(): if op == ReduceOp.SUM: return core.ops.c_reduce_sum(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group, 'root_id', dst) + use_calc_stream, 'ring_id', ring_id, + 'root_id', gdst) elif op == ReduceOp.MAX: return core.ops.c_reduce_max(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group, 'root_id', dst) + use_calc_stream, 'ring_id', ring_id, + 'root_id', gdst) elif op == ReduceOp.MIN: return core.ops.c_reduce_min(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group, 'root_id', dst) + use_calc_stream, 'ring_id', ring_id, + 'root_id', gdst) elif op == ReduceOp.PROD: return core.ops.c_reduce_prod(tensor, tensor, 'use_calc_stream', - True, 'ring_id', group, 'root_id', - dst) + use_calc_stream, 'ring_id', ring_id, + 'root_id', gdst) else: raise ValueError("Unknown parameter: {}.".format(op)) @@ -295,22 +525,19 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0): elif op == ReduceOp.PROD: op_type = 'c_reduce_prod' - if not isinstance(dst, int) or not isinstance(group, int): - raise ValueError("Both the type of 'dst' and 'group' for reduce " - "should be int.") helper = LayerHelper(op_type, **locals()) helper.append_op( type=op_type, inputs={'X': [tensor]}, outputs={'Out': [tensor]}, attrs={ - 'ring_id': group, - 'use_calc_stream': True, - 'root_id': dst, + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream, + 'root_id': gdst, }) -def all_gather(tensor_list, tensor, group=0): +def all_gather(tensor_list, tensor, group=None, use_calc_stream=True): """ Gather tensors from all participators and all get the result. @@ -320,7 +547,9 @@ def all_gather(tensor_list, tensor, group=0): should be float16, float32, float64, int32 or int64. tensor (Tensor): The Tensor to send. Its data type should be float16, float32, float64, int32 or int64. - group (int): The id of the process group to work on. + group (Group): The group instance return by new_group or None for global default group. + use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), + default to True. Returns: None. 
@@ -348,13 +577,19 @@ def all_gather(tensor_list, tensor, group=0): data2 = paddle.to_tensor(np_data2) paddle.distributed.all_gather(tensor_list, data2) """ + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + nranks = _get_global_group().nranks if group is None else group.nranks + op_type = 'c_allgather' helper = LayerHelper(op_type, **locals()) out = helper.create_variable_for_type_inference(dtype=tensor.dtype) - _default_group = _get_global_default_group() + if in_dygraph_mode(): - core.ops.c_allgather(tensor, out, 'use_calc_stream', True, 'ring_id', - group, 'nranks', _default_group.nranks) + core.ops.c_allgather(tensor, out, 'use_calc_stream', use_calc_stream, + 'ring_id', ring_id, 'nranks', nranks) else: if not isinstance(tensor_list, list): raise ValueError("The type of 'tensor_list' for all_gather " @@ -367,23 +602,20 @@ def all_gather(tensor_list, tensor, group=0): check_variable_and_dtype( tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], 'all_gather') - if not isinstance(group, int): - raise ValueError("The type of 'group' for all_gather " - "should be int.") helper.append_op( type=op_type, inputs={'X': [tensor]}, outputs={'Out': [out]}, attrs={ - 'ring_id': group, - 'use_calc_stream': True, - 'nranks': _default_group.nranks + 'ring_id': ring_id, + 'use_calc_stream': use_calc_stream, + 'nranks': nranks }) - tensor_list.extend(paddle.split(out, _default_group.nranks, 0)) + tensor_list.extend(paddle.split(out, nranks, 0)) -def scatter(tensor, tensor_list=None, src=0, group=0): +def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): """ Scatter a tensor to all participators. @@ -394,7 +626,9 @@ def scatter(tensor, tensor_list=None, src=0, group=0): tensor_list (list): A list of Tensors to scatter. Every element in the list must be a Tensor whose data type should be float16, float32, float64, int32 or int64. src (int): The source rank id. - group (int): The id of the process group to work on. + group (Group): The group instance return by new_group or None for global default group. + use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False), + default to True. Returns: None. 
@@ -422,45 +656,51 @@ def scatter(tensor, tensor_list=None, src=0, group=0): paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1) out = data1.numpy() """ + if group is not None and not group.is_member(): + return + + if not isinstance(src, int): + raise ValueError("src should be int.") + + ring_id = 0 if group is None else group.id + gsrc = src if group is None else group.get_group_rank(src) + rank = _get_global_group().rank if group is None else group.rank + nranks = _get_global_group().nranks if group is None else group.nranks + op_type = 'c_scatter' - _default_group = _get_global_default_group() - rank = _default_group.rank - nranks = _default_group.nranks - if rank != src: + + if rank != gsrc: tensor_list = [] for _ in range(nranks): tensor_list.append(tensor) temp = paddle.concat(tensor_list, axis=0) if in_dygraph_mode(): - return core.ops.c_scatter(temp, tensor, 'use_calc_stream', True, - 'ring_id', group, 'nranks', - _default_group.nranks, 'root', src) + return core.ops.c_scatter(temp, tensor, 'use_calc_stream', + use_calc_stream, 'ring_id', ring_id, 'nranks', + nranks, 'root', gsrc) check_variable_and_dtype( tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], 'scatter') - if not isinstance(group, int) or not isinstance(src, int): - raise ValueError("Both the type of 'src' and 'group' for scatter " - "should be int.") helper = LayerHelper(op_type, **locals()) helper.append_op( type=op_type, inputs={'X': [temp]}, outputs={'Out': [tensor]}, attrs={ - 'ring_id': group, - 'root': src, - 'use_calc_stream': True, + 'ring_id': ring_id, + 'root': gsrc, + 'use_calc_stream': use_calc_stream, 'nranks': nranks, }) -def barrier(group=0): +def barrier(group=None): """ Barrier among all participators in the group. Args: - group (int): The id of the process group to work on. + group (Group): The group instance return by new_group or None for global default group. Returns: None. 
@@ -475,18 +715,23 @@ def barrier(group=0): init_parallel_env() paddle.distributed.barrier() """ + if group is not None and not group.is_member(): + return + + ring_id = 0 if group is None else group.id + op_type = 'barrier' temp = fill_constant([1], dtype="int32", value="1") if in_dygraph_mode(): - return core.ops.barrier(temp, temp, 'ring_id', group) - if not isinstance(group, int): + return core.ops.barrier(temp, temp, 'ring_id', ring_id) + if not isinstance(ring_id, int): raise ValueError("The type of 'group' for barrier must be int.") helper = LayerHelper(op_type, **locals()) helper.append_op( type=op_type, inputs={'X': [temp]}, outputs={'Out': [temp]}, - attrs={'ring_id': group}) + attrs={'ring_id': ring_id}) def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr, @@ -515,10 +760,10 @@ def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr, if gather_out: if axis == 0: - paddle.distributed.all_reduce(linear_out, group=0) + paddle.distributed.all_reduce(linear_out) else: output = [] - paddle.distributed.all_gather(output, linear_out, group=0) + paddle.distributed.all_gather(output, linear_out) linear_out = paddle.concat(output, axis=len(linear_out.shape) - 1) return linear_out @@ -559,7 +804,7 @@ def _parallel_embedding(x, per_part_embeddings, origin_size, param_attr, main_block = paddle.static.default_main_program().global_block() startup_block.vars[embedding.weight.name].is_distributed = True main_block.vars[embedding.weight.name].is_distributed = True - paddle.distributed.all_reduce(emb_out, group=0) + paddle.distributed.all_reduce(emb_out, group=None) return emb_out @@ -584,7 +829,7 @@ def split(x, With parallel embedding, the weight is split into num_partitions partitions, each of which is a matrix with (N/num_partitions + 1) rows and M column where the last row as the padding idx. - + Suppose we split the NxM weight into two partitons on device_0 and device_1 respectively. Then, one each device, the final weight has (N/2 + 1) rows with the index range from 0 to N/2. 
On device_0, all values in the input within [0, N/2 -1] diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 0c292d355dd..0abb61d95aa 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -82,6 +82,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api) LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api) LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api) + LIST(REMOVE_ITEM TEST_OPS test_new_group_api) LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api) LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api) LIST(REMOVE_ITEM TEST_OPS test_collective_wait) @@ -177,6 +178,7 @@ endif() if ((NOT WITH_NCCL) AND (NOT WITH_RCCL)) list(REMOVE_ITEM TEST_OPS test_imperative_group) + LIST(REMOVE_ITEM TEST_OPS test_new_group_api) endif() if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) @@ -518,6 +520,7 @@ if(WITH_DISTRIBUTE) if(WITH_GPU OR WITH_ROCM) bash_test_modules(test_c_comm_init_op START_BASH test_c_comm_init_op.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) py_test_modules(test_launch_coverage MODULES test_launch_coverage) + bash_test_modules(test_new_group START_BASH test_new_group.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) endif() bash_test_modules(test_fleetrun START_BASH test_fleetrun.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) @@ -831,6 +834,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120) + set_tests_properties(test_new_group_api PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE) set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120) endif() @@ -853,6 +857,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) test_collective_barrier_api test_collective_reduce_api test_collective_allreduce_api + test_new_group_api test_collective_broadcast_api test_collective_allgather_api PROPERTIES LABELS "RUN_TYPE=DIST") diff --git a/python/paddle/fluid/tests/unittests/collective_allreduce_new_group_api.py b/python/paddle/fluid/tests/unittests/collective_allreduce_new_group_api.py new file mode 100644 index 00000000000..597765cfb98 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_allreduce_new_group_api.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main + +paddle.enable_static() + + +class TestCollectiveAllreduceNewGroupAPI(TestCollectiveAPIRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program, rank): + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + gp = paddle.distributed.new_group([0, 1]) + paddle.distributed.all_reduce( + tindata, group=gp, use_calc_stream=False) + return [tindata] + + +if __name__ == "__main__": + runtime_main(TestCollectiveAllreduceNewGroupAPI, "allreduce") diff --git a/python/paddle/fluid/tests/unittests/new_group.py b/python/paddle/fluid/tests/unittests/new_group.py new file mode 100644 index 00000000000..fb7beeee1df --- /dev/null +++ b/python/paddle/fluid/tests/unittests/new_group.py @@ -0,0 +1,83 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import os +import paddle + + +class TestNewGroupAPI(object): + def __init__(self): + paddle.distributed.init_parallel_env() + d1 = np.array([1, 2, 3]) + d2 = np.array([2, 3, 4]) + self.tensor1 = paddle.to_tensor(d1) + self.tensor2 = paddle.to_tensor(d2) + + def test_all(self): + gp = paddle.distributed.new_group([0, 1]) + print("test new group api ok") + + tmp = np.array([0, 0, 0]) + result = paddle.to_tensor(tmp) + paddle.distributed.scatter( + result, [self.tensor2, self.tensor1], + src=0, + group=gp, + use_calc_stream=True) + if gp.rank == 0: + assert np.array_equal(result, self.tensor2) + elif gp.rank == 1: + assert np.array_equal(result, self.tensor1) + print("test scatter api ok") + + paddle.distributed.broadcast( + result, src=1, group=gp, use_calc_stream=True) + assert np.array_equal(result, self.tensor1) + print("test broadcast api ok") + + paddle.distributed.reduce(result, dst=0, group=gp, use_calc_stream=True) + if gp.rank == 0: + assert np.array_equal(result, + paddle.add(self.tensor1, self.tensor1)) + elif gp.rank == 1: + assert np.array_equal(result, self.tensor1) + print("test reduce api ok") + + paddle.distributed.all_reduce(result, use_calc_stream=True) + assert np.array_equal( + result, + paddle.add(paddle.add(self.tensor1, self.tensor1), self.tensor1)) + print("test all_reduce api ok") + + paddle.distributed.wait(result, gp, use_calc_stream=True) + paddle.distributed.wait(result, gp, use_calc_stream=False) + print("test wait api ok") + + result = [] + paddle.distributed.all_gather( + result, self.tensor1, group=gp, use_calc_stream=True) + assert np.array_equal(result[0], self.tensor1) + assert np.array_equal(result[1], self.tensor1) + print("test all_gather api ok") + + paddle.distributed.barrier(group=gp) + print("test barrier api ok") + + return + + +if __name__ == "__main__": + gpt = TestNewGroupAPI() + gpt.test_all() diff --git a/python/paddle/fluid/tests/unittests/test_new_group.sh b/python/paddle/fluid/tests/unittests/test_new_group.sh new file mode 100755 index 00000000000..998ead8db32 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_new_group.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 new_group.py diff --git a/python/paddle/fluid/tests/unittests/test_new_group_api.py b/python/paddle/fluid/tests/unittests/test_new_group_api.py new file mode 100644 index 00000000000..b9b80d3b431 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_new_group_api.py @@ -0,0 +1,35 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle + +from test_collective_api_base import TestDistBase + +paddle.enable_static() + + +class TestCollectiveAllreduceAPI(TestDistBase): + def _setup_config(self): + pass + + def test_allreduce_nccl(self): + self.check_with_place("collective_allreduce_new_group_api.py", + "allreduce", "nccl") + + +if __name__ == '__main__': + unittest.main() -- GitLab
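For reference, a minimal dygraph usage sketch of the group API added by this patch, pieced together from the new_group/wait docstrings and the new_group.py test above. It assumes two GPUs launched with `python -m paddle.distributed.launch --gpus=0,1`, as in test_new_group.sh; the tensor values are illustrative only.

    import numpy as np
    import paddle

    paddle.distributed.init_parallel_env()

    # Create a communication group over global ranks 0 and 1
    # (only the nccl backend is supported for now).
    gp = paddle.distributed.new_group([0, 1])

    data = paddle.to_tensor(np.array([1, 2, 3], dtype='float32'))

    # Collectives now accept a Group instance and a use_calc_stream flag.
    paddle.distributed.all_reduce(data, group=gp, use_calc_stream=False)

    # Block until the group's communication stream has finished.
    paddle.distributed.wait(data, group=gp, use_calc_stream=False)

    print(data.numpy())  # [2., 4., 6.] on both ranks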