diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 48dd6d8285699814b41ae4dde4ec2d6055a4e006..be3bfc0dc00290a8e20a94c5c254cd99e62bc58b 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -46,6 +46,7 @@ enum class CommType : std::uint8_t { SEND = 9, RECV = 10, BARRIER = 11, + ALLTOALL_SINGLE = 12, UNKNOWN = 100, }; @@ -143,6 +144,15 @@ class ProcessGroup { "ProcessGroup%s does not support AllToAll", GetBackendName())); } + virtual std::shared_ptr AllToAll_Single( + std::vector&, // NOLINT + std::vector&, // NOLINT + std::vector&, + std::vector&) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support AllToAll_Single", GetBackendName())); + } + virtual std::shared_ptr Reduce( std::vector&, // NOLINT std::vector&, // NOLINT @@ -159,6 +169,14 @@ class ProcessGroup { "ProcessGroup%s does not support Scatter", GetBackendName())); } + virtual std::shared_ptr _ReduceScatterBase( + phi::DenseTensor&, // NOLINT + phi::DenseTensor&, // NOLINT + const ReduceScatterOptions&) { // NOLINT + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support ReduceScatter", GetBackendName())); + } + protected: const int rank_; const int size_; diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index e6e69f0be3ae5488ae0117ca1a70c706325c2644..1beca8022e9f9833d9305841e4823116808bdd6d 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -85,6 +85,34 @@ bool ProcessGroupNCCL::NCCLTask::IsCompleted() { return true; } +void ProcessGroupNCCL::CheckSplitSizes(std::vector& split_sizes, + std::vector tensor_shape) { + int64_t len_size = split_sizes.size(); + if (len_size == 0) { + PADDLE_ENFORCE_EQ(tensor_shape[0] % size_ == 0, + true, + platform::errors::InvalidArgument( + "Tensor's dim[0] must be divisible by group size " + "when split_sizes not given.")); + split_sizes.insert(split_sizes.end(), + size_, + static_cast(tensor_shape[0] / size_)); + } else { + PADDLE_ENFORCE_EQ( + len_size == size_, + true, + platform::errors::InvalidArgument( + "The length of split_sizes must be equal to group size.")); + auto sum_size = std::accumulate( + split_sizes.begin(), split_sizes.end(), static_cast(0)); + PADDLE_ENFORCE_EQ( + sum_size == tensor_shape[0], + true, + platform::errors::InvalidArgument( + "The sum of split_sizes must be equal to tensor's dim[0].")); + } +} + // TODO(sheniang03): Add timeout for wait, now timeout unused bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { SynchronizeStreams(); @@ -637,7 +665,69 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( } PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); }, - CommType::ALLREDUCE); + CommType::ALLTOALL); +} + +std::shared_ptr ProcessGroupNCCL::AllToAll_Single( + std::vector& in_tensors, + std::vector& out_tensors, + std::vector& in_sizes, + std::vector& out_sizes) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(in_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(out_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + return Collective( + in_tensors, + out_tensors, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(), + true, + platform::errors::InvalidArgument( + "The dtypes of input and output must be equal.")); + + std::vector in_dims = phi::vectorize(input.dims()); + std::vector out_dims = phi::vectorize(output.dims()); + CheckSplitSizes(in_sizes, in_dims); + CheckSplitSizes(out_sizes, out_dims); + + size_t in_offset = 0, out_offset = 0; + size_t in_length = 0, out_length = 0; + size_t in_row_size = input.numel() / in_dims[0]; + size_t out_row_size = output.numel() / out_dims[0]; + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < size_; i++) { + in_length = in_sizes[i] * in_row_size; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + GetPointerByOffset(input.data(), in_offset, input.dtype()), + in_length, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + in_offset += in_length; + + out_length = out_sizes[i] * out_row_size; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + GetPointerByOffset(output.data(), out_offset, input.dtype()), + out_length, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + out_offset += out_length; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + }, + CommType::ALLTOALL_SINGLE); } std::shared_ptr ProcessGroupNCCL::Reduce( @@ -721,5 +811,57 @@ std::shared_ptr ProcessGroupNCCL::Scatter( CommType::SCATTER); } +std::shared_ptr ProcessGroupNCCL::_ReduceScatterBase( + phi::DenseTensor& out_tensor, + phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts) { + // auto tensor = out_tensors.back(); + PADDLE_ENFORCE_EQ( + out_tensor.dtype(), + in_tensor.dtype(), + platform::errors::InvalidArgument( + "Input tensor and output tensor should be same dtype.")); + + PADDLE_ENFORCE_EQ( + out_tensor.numel() * size_, + in_tensor.numel(), + platform::errors::InvalidArgument("input tensor must be the same size as " + "output tensor size times world_size")); + + auto inputs = std::vector{in_tensor}; + auto outputs = std::vector{out_tensor}; + + return Collective( + inputs, + outputs, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + if (FLAGS_use_stream_safe_cuda_allocator) { + platform::CUDADeviceGuard cuda_guard; + cuda_guard.SetDevice(output.place()); + memory::RecordStream(output.Holder(), stream); + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( + input.data(), + output.data(), + output.numel(), + platform::ToNCCLDataType(input.dtype()), + ToNCCLRedType(opts.reduce_op), + comm, + stream)); + }, + CommType::REDUCE_SCATTER); +} + +void ProcessGroupNCCL::GroupStart() { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); +} + +void ProcessGroupNCCL::GroupEnd() { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index a26f5947ce2b8c3bbf335ad4deca2813b35d8b09..a8adffe64e70d14cd20a37367f45de278dc44041 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -129,6 +129,12 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& in, std::vector& out) override; + std::shared_ptr AllToAll_Single( + std::vector& in, + std::vector& out, + std::vector& in_sizes, + std::vector& out_sizes) override; + std::shared_ptr Reduce( std::vector& tensors, std::vector& out_tensors, @@ -139,6 +145,15 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& out_tensors, const ScatterOptions&) override; + std::shared_ptr _ReduceScatterBase( + phi::DenseTensor&, // NOLINT + phi::DenseTensor&, // NOLINT + const ReduceScatterOptions&) override; + + static void GroupStart(); + + static void GroupEnd(); + protected: virtual std::shared_ptr CreateTask( std::vector places, @@ -162,8 +177,8 @@ class ProcessGroupNCCL : public ProcessGroup { std::set used_place_ids_; private: - void BcastNCCLId(std::vector& nccl_ids, - int root, // NOLINT + void BcastNCCLId(std::vector& nccl_ids, // NOLINT + int root, // NOLINT int server_fd); void BroadcastUniqueNCCLID(std::vector& nccl_ids); // NOLINT @@ -190,6 +205,9 @@ class ProcessGroupNCCL : public ProcessGroup { void CreateNCCLManagerCache(const std::string& places_key, const std::vector& places); + + void CheckSplitSizes(std::vector& split_sizes, + std::vector tensor_shape); }; } // namespace distributed diff --git a/paddle/fluid/distributed/collective/Types.h b/paddle/fluid/distributed/collective/Types.h index 973f7c643542757c0bce68f8ccdefeadc97f15d4..0ce92111f6a13579cf61dbce90565333affeb723 100644 --- a/paddle/fluid/distributed/collective/Types.h +++ b/paddle/fluid/distributed/collective/Types.h @@ -45,5 +45,9 @@ struct ScatterOptions { int root_rank = 0; }; +struct ReduceScatterOptions { + ReduceOp reduce_op = ReduceOp::SUM; +}; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index bdaebf13f8d2a5d1c7032bad49b479ed03e3fec1..b8d5a0de820e7529b016b335e470ea80ca15ba08 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -225,6 +225,30 @@ void BindDistributed(py::module *m) { py::arg("out"), py::call_guard()) + .def( + "alltoall_single", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + std::vector in_sizes, + std::vector out_sizes) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + return self.AllToAll_Single( + in_tensors, out_tensors, in_sizes, out_sizes); + }, + py::arg("in"), + py::arg("out"), + py::arg("in_sizes"), + py::arg("out_sizes"), + py::call_guard()) + .def( "reduce", [](distributed::ProcessGroup &self, @@ -244,7 +268,6 @@ void BindDistributed(py::module *m) { py::arg("dst"), py::arg("op") = distributed::ReduceOp::SUM, py::call_guard()) - .def( "scatter", [](distributed::ProcessGroup &self, @@ -266,23 +289,50 @@ void BindDistributed(py::module *m) { py::arg("in"), py::arg("out"), py::arg("src"), + py::call_guard()) + .def( + "_reduce_scatter_base", + [](distributed::ProcessGroup &self, + py::handle py_out_tensor, + py::handle py_in_tensor, + distributed::ReduceOp op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + distributed::ReduceScatterOptions opts; + opts.reduce_op = op; + auto dense_out = std::dynamic_pointer_cast( + out_tensor.impl()); + auto dense_in = std::dynamic_pointer_cast( + in_tensor.impl()); + return self._ReduceScatterBase(*dense_out, *dense_in, opts); + }, + py::arg("out_tensor"), + py::arg("in_tensor"), + py::arg("op") = distributed::ReduceOp::SUM, py::call_guard()); #if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) - py::class_>( - *m, "ProcessGroupNCCL", ProcessGroup) - .def(py::init &, - int, - int, - const platform::CUDAPlace &, - int>(), - py::arg("store"), - py::arg("rank"), - py::arg("world_size"), - py::arg("place"), - py::arg("group_id") = 0, - py::call_guard()); + auto processGroupNCCL = + py::class_>( + *m, "ProcessGroupNCCL", ProcessGroup) + .def(py::init &, + int, + int, + const platform::CUDAPlace &, + int>(), + py::arg("store"), + py::arg("rank"), + py::arg("world_size"), + py::arg("place"), + py::arg("group_id") = 0, + py::call_guard()); + + processGroupNCCL.def_static( + "group_start", []() { distributed::ProcessGroupNCCL::GroupStart(); }); + processGroupNCCL.def_static( + "group_end", []() { distributed::ProcessGroupNCCL::GroupEnd(); }); + #endif #if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \ diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 003a14799c53e9565a1d28e953d9338742b76e71..ab83e2929e4bc3027084405bf423a79b227a2483 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -41,6 +41,14 @@ from .collective import recv # noqa: F401 from .collective import get_group # noqa: F401 from .collective import send # noqa: F401 from .collective import wait # noqa: F401 +from .collective import is_initialized # noqa: F401 +from .collective import destroy_process_group # noqa: F401 +from .collective import alltoall_single # noqa: F401 +from .collective import isend # noqa: F401 +from .collective import irecv # noqa: F401 +from .collective import batch_isend_irecv # noqa: F401 +from .collective import P2POp # noqa: F401 +from .collective import reduce_scatter # noqa: F401 from .auto_parallel import shard_op # noqa: F401 from .auto_parallel import shard_tensor # noqa: F401 @@ -59,33 +67,11 @@ from . import utils # noqa: F401 from .sharding import * # noqa: F401 __all__ = [ # noqa - "spawn", - "launch", - "scatter", - "broadcast", - "ParallelEnv", - "new_group", - "init_parallel_env", - "gloo_init_parallel_env", - "gloo_barrier", - "gloo_release", - "QueueDataset", - "split", - "CountFilterEntry", - "ShowClickEntry", - "get_world_size", - "get_group", - "all_gather", - "InMemoryDataset", - "barrier", - "all_reduce", - "alltoall", - "send", - "reduce", - "recv", - "ReduceOp", - "wait", - "get_rank", - "ProbabilityEntry", - "ParallelMode", + "spawn", "launch", "scatter", "broadcast", "ParallelEnv", "new_group", + "init_parallel_env", "gloo_init_parallel_env", "gloo_barrier", + "gloo_release", "QueueDataset", "split", "CountFilterEntry", + "ShowClickEntry", "get_world_size", "get_group", "all_gather", + "InMemoryDataset", "barrier", "all_reduce", "alltoall", "send", "reduce", + "recv", "ReduceOp", "wait", "get_rank", "ProbabilityEntry", "ParallelMode", + "is_initialized", "isend", "irecv", "reduce_scatter" ] diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index cb634a4b6ac1a78bc09967ec1736afbe4d15c576..2506c3073941afa6a86b6739abc169e870e69a08 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -36,6 +36,7 @@ import paddle.fluid as fluid import paddle.fluid.core as core from paddle import _C_ops import paddle.fluid.dygraph_utils as dygraph_utils +import contextlib __all__ = [] @@ -136,6 +137,10 @@ _group_map = {} # Dict[name, Group] _group_map_by_name = {} +# backend map by group : the map of all backend from their groups +# Dict[group, backend] +_group_map_backend = {} + # Name of the default group for init_parallel_env _default_group_name = "_default_pg" @@ -175,9 +180,8 @@ def _get_group_map_by_name(): def _get_default_group(): global _group_map_by_name - assert _default_group_name in _group_map_by_name, ( - "Call paddle.distributed.init_parallel_env first " - "to initialize the distributed environment.") + assert is_initialized(), ("Call paddle.distributed.init_parallel_env first " + "to initialize the distributed environment.") return _get_group_map_by_name()[_default_group_name] @@ -193,10 +197,29 @@ def _set_group_map_by_name(name, group): _group_map_by_name[name] = group +def _set_group_map_backend(group, backend): + global _group_map_backend + assert group not in _group_map_backend + _group_map_backend[group] = backend + + def _new_ring_id(): return len(_get_group_map()) + max(_get_global_env().nrings, 9) +def _get_reduce_op(reduce_op, func_name): + if reduce_op == ReduceOp.SUM: + return core.ReduceOp.SUM + elif reduce_op == ReduceOp.MAX: + return core.ReduceOp.MAX + elif reduce_op == ReduceOp.MIN: + return core.ReduceOp.MIN + elif reduce_op == ReduceOp.PROD: + return core.ReduceOp.PRODUCT + else: + raise ValueError("Unknown reduce_op type for {}.".format(func_name)) + + def get_group(id=0): """ @@ -400,6 +423,7 @@ def new_group(ranks=None, backend=None): group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name) _group_map_by_name[group_name] = group _group_map[gid] = group + _group_map_backend[group] = backend # TODO(shenliang03): This is a temporary solution to solve the problem of # hang caused by tcp @@ -462,6 +486,75 @@ def new_group(ranks=None, backend=None): return gp +def is_initialized(): + """ + + Check whether the distributed environment has been initialized + + Returns (bool): `True` if distributed environment has been initialized, otherwise `False`. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + + print(paddle.distributed.is_initialized()) + # False + + paddle.distributed.init_parallel_env() + print(paddle.distributed.is_initialized()) + # True + + """ + global _group_map_by_name + return _default_group_name in _group_map_by_name + + +def destroy_process_group(group=None): + """ + Destroy a given group for communication + + Args: + group (ProcessGroup, optional): The group to be destroyed. All of process groups, including + the default group, will be destroyed and the distributed + environment will be deinitialized. + + Returns : None + + Examples: + .. code-block:: python + + # required: distributed + import paddle + + paddle.distributed.init_parallel_env() + group = paddle.distributed.new_group([0, 1]) + + paddle.distributed.destroy_process_group(group) + print(paddle.distributed.is_initialized()) + # True + paddle.distributed.destroy_process_group() + print(paddle.distributed.is_initialized()) + # False + + """ + global _group_map + global _group_map_by_name + + pg = _get_default_group() if group is None else group + assert _group_map.get(pg.id, None) is not None, "Invalid group." + + if group is None: + _group_map.clear() + _group_map_by_name.clear() + _group_map_backend.clear() + else: + del _group_map[pg.id] + del _group_map_by_name[pg.name] + del _group_map_backend[pg] + + def wait(tensor, group=None, use_calc_stream=True): """ @@ -663,16 +756,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True): return if in_dygraph_mode(): - if op == ReduceOp.SUM: - op_type = core.ReduceOp.SUM - elif op == ReduceOp.MAX: - op_type = core.ReduceOp.MAX - elif op == ReduceOp.MIN: - op_type = core.ReduceOp.MIN - elif op == ReduceOp.PROD: - op_type = core.ReduceOp.PRODUCT - else: - raise ValueError("Unknown reduce_op type for allreduce.") + op_type = _get_reduce_op(op, "all_reduce") group = _get_default_group() if group is None else group task = group.process_group.allreduce(tensor, op_type) if use_calc_stream: @@ -768,16 +852,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True): return if in_dygraph_mode(): - if op == ReduceOp.SUM: - op_type = core.ReduceOp.SUM - elif op == ReduceOp.MAX: - op_type = core.ReduceOp.MAX - elif op == ReduceOp.MIN: - op_type = core.ReduceOp.MIN - elif op == ReduceOp.PROD: - op_type = core.ReduceOp.PRODUCT - else: - raise ValueError("Unknown reduce_op type for reduce.") + op_type = _get_reduce_op(op, "reduce") group = _get_default_group() if group is None else group gdst = group.get_group_rank(dst) assert gdst >= 0, ("dst rank out of group, need global rank") @@ -1781,10 +1856,10 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): Args: in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type should be float16, float32, float64, int32 or int64. - out_tensor_list (Tensor): A list of output Tensors. The data type of its elements should be the same as the + out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the data type of the input Tensors. group (Group, optional): The group instance return by new_group or None for global default group. Default: None. - use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True. + use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True. Returns: None. @@ -1867,6 +1942,94 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): out_tensor_list.extend(paddle.split(out, nranks, 0)) +def alltoall_single(in_tensor, + out_tensor, + in_split_sizes=None, + out_split_sizes=None, + group=None, + use_calc_stream=True): + """ + Scatter a single input tensor to all participators and gather the received tensors in out_tensor. + + .. note:: + ``alltoall_single`` is only supported in eager mode. + + Args: + in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32 or int64. + out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor. + in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor`` + must be divisible by group size and ``in_tensor`` will be scattered averagely to all participators. Default: None. + out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor`` + must be divisible by group size and ``out_tensor`` will be gathered averagely from all participators. Default: None. + group (Group, optional): The group instance return by ``new_group`` or None for global default group. Default: None. + use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True. + + Returns: + None, if ``use_calc_stream`` is set to ``True``; ``Task`` of ``group``, if ``use_calc_stream`` is set to ``False``. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + rank = dist.get_rank() + size = dist.get_world_size() + + # case 1 + input = paddle.arange(2, dtype='int64') + rank * 2 + # input for rank 0: [0, 1] + # input for rank 1: [2, 3] + + output = paddle.empty([2], dtype='int64') + dist.alltoall_single(input, output) + # output for rank 0: [0, 2] + # output for rank 1: [1, 3] + + # case 2 + in_split_sizes = [i + 1 for i in range(size)] + # in_split_sizes for rank 0: [1, 2] and for rank 1: [1, 2] + out_split_sizes = [rank + 1 for i in range(size)] + # out_split_sizes for rank 0: [1, 1] and for rank 1: [2, 2] + + input = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank + # input for rank 0: [[0., 0.], [0., 0.], [0., 0.]] + # input for rank 1: [[1., 1.], [1., 1.], [1., 1.]] + output = paddle.empty([(rank + 1) * size, size], dtype='float32') + + group = dist.new_group([0, 1]) + task = dist.alltoall_single(input, + output, + in_split_sizes, + out_split_sizes, + use_calc_stream=False, + group=group) + task.wait() + # output for rank 0: [[0., 0.], [1., 1.]] + # output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]] + + """ + if group is not None and not group.is_member(): + return + + assert in_dygraph_mode(), "Only suppport alltoall_single in eager mode." + # _check_single_tensor + + group = _get_default_group() if group is None else group + in_split_sizes = [] if in_split_sizes is None else in_split_sizes + out_split_sizes = [] if out_split_sizes is None else out_split_sizes + + task = group.process_group.alltoall_single(in_tensor, out_tensor, + in_split_sizes, out_split_sizes) + if use_calc_stream: + task.wait() + return + else: + return task + + def send(tensor, dst=0, group=None, use_calc_stream=True): """ Send a tensor to the receiver. @@ -1902,7 +2065,8 @@ def send(tensor, dst=0, group=None, use_calc_stream=True): if in_dygraph_mode(): group = _get_default_group() if group is None else group - task = group.process_group.send(tensor, dst) + group_dst_rank = group.get_group_rank(dst) + task = group.process_group.send(tensor, group_dst_rank) if use_calc_stream: task.wait() return None @@ -1964,7 +2128,8 @@ def recv(tensor, src=0, group=None, use_calc_stream=True): if in_dygraph_mode(): group = _get_default_group() if group is None else group - task = group.process_group.recv(tensor, src) + group_src_rank = group.get_group_rank(src) + task = group.process_group.recv(tensor, group_src_rank) if use_calc_stream: task.wait() return None @@ -1991,3 +2156,390 @@ def recv(tensor, src=0, group=None, use_calc_stream=True): 'dtype': tensor.dtype, 'use_calc_stream': use_calc_stream, }) + + +def _check_single_tensor(tensor, tensor_name): + if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)): + raise RuntimeError("Invalid function argument. Expected parameter {}" + "to be of type paddle.Tensor, but it's {}".format( + tensor_name, type(tensor))) + + +def _check_tensor_list(tensor_list, tensor_name): + if not isinstance(tensor_list, list) or \ + not all(isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list): + raise RuntimeError("Invalid function argument. Expected parameter {}" + "to be of type paddle.Tensor".format(tensor_name)) + + +def isend(tensor, dst, group=None): + """ + Sends a tensor asynchronously + + Args: + tensor (Tensor): The Tensor to send. Its data type + should be float16, float32, float64, int32 or int64. + dst (int): The destination rank. + group (Group, optional): The group instance return by new_group or None for global default group. Default: None. + + Returns: + A distributed task object. + + Warning: + This API only supports the dygraph mode. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + data = paddle.to_tensor([7, 8, 9]) + task = paddle.distributed.isend(data, dst=1) + else: + data = paddle.to_tensor([1, 2, 3]) + task = paddle.distributed.irecv(data, src=0) + + task.wait() + + print(data) + # paddle.tensor([7, 8, 9]) # Rank-0 + # paddle.tensor([7, 8, 9]) # Rank-1 + + """ + _check_single_tensor(tensor, "tensor") + if group is not None and not group.is_member(): + return + + if in_dygraph_mode(): + group = _get_default_group() if group is None else group + group_dst_rank = group.get_group_rank(dst) + assert group_dst_rank >= 0, ("dst rank out of group, need global rank") + return group.process_group.send(tensor, group_dst_rank) + else: + raise RuntimeError("Don't support static graph mode currently.") + + +def irecv(tensor, src=None, group=None): + """ + Receive a tensor to the sender. + + Args: + tensor (Tensor): The Tensor to receive. Its data type + should be float16, float32, float64, int32 or int64. + src (int): The source rank id. + group (Group, optional): The group instance return by new_group or None for global default group. Default: None. + + Returns: + A distributed task object. + + Warning: + This API only supports the dygraph mode. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + data = paddle.to_tensor([7, 8, 9]) + task = paddle.distributed.isend(data, dst=1) + else: + data = paddle.to_tensor([1, 2, 3]) + task = paddle.distributed.irecv(data, src=0) + + task.wait() + + print(data) + # paddle.tensor([7, 8, 9]) # Rank-0 + # paddle.tensor([7, 8, 9]) # Rank-1 + """ + _check_single_tensor(tensor, "tensor") + if group is not None and not group.is_member(): + return + + if in_dygraph_mode(): + group = _get_default_group() if group is None else group + group_src_rank = group.get_group_rank(src) + assert group_src_rank >= 0, ("src rank out of group, need global rank") + return group.process_group.recv(tensor, group_src_rank) + else: + raise RuntimeError("Don't support static graph mode currently.") + + +class P2POp(object): + """ + A class that makes point-to-point operations for "batch_isend_irecv". + + This class creates the type of P2P operation, communication buffer, peer rank, + Group. Instances of this class will be passed to + ``paddle.distributed.batch_isend_irecv`` for point-to-point communication. + + Args: + op (callable): A function to send data to or receive data from a peer process. + The type of ``op`` is either ``paddle.distributed.isend`` or ``paddle.distributed.irecv``. + tensor (Tensor): Tensor to send or receive. + peer (int): The destination or source rank. + group (Group, optional): The group instance return by new_group or None for global + default group. Default: None. + + """ + + def __init__(self, op, tensor, peer, group=None): + if op not in [isend, irecv]: + raise RuntimeError("Invalid ``op`` function. Expected ``op`` " + "to be of type ``paddle.distributed.isend`` or " + "``paddle.distributed.irecv``.") + _check_single_tensor(tensor, "tensor") + + self.op = op + self.tensor = tensor + self.peer = peer + self.group = _get_default_group() if group is None else group + + +@contextlib.contextmanager +def _with_batch_p2p_guard(backend): + if backend == "nccl": + core.ProcessGroupNCCL.group_start() + try: + yield + finally: + if backend == "nccl": + core.ProcessGroupNCCL.group_end() + + +def _check_p2p_op_list(p2p_op_list): + """ + Helper to check that the ``p2p_op_list`` is a list of P2POp instances and + all ops use the same backend. + """ + if not isinstance(p2p_op_list, list) or not all( + isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list): + raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to " + "to be of type ``paddle.distributed.P2POp``.") + + backend = _group_map_backend[p2p_op_list[0].group] + if not all(backend == _group_map_backend[p2p_op.group] + for p2p_op in p2p_op_list): + raise RuntimeError("All groups need to use the same backend.") + + +def batch_isend_irecv(p2p_op_list): + """ + Send or Receive a batch of tensors asynchronously and return a list of requests. + + Process each of the point-to-point operations in ``p2p_op_list`` and return the + corresponding tasks. NCCL are currently supported. + + Args: + p2p_op_list: A list of point-to-point operations(type of each operator is + ``paddle.distributed.P2POp``). The order of the isend/irecv in the list + matters and it needs to match with corresponding isend/irecv on the + remote end. + + Returns: + A list of distributed tasks returned by calling the corresponding + op in the op_list. + + Warning: + This API only supports the dygraph mode. + + Examples: + .. code-block:: python + + # required: distributed + + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + rank = dist.get_rank() + world_size = dist.get_world_size() + + send_t = paddle.arange(2) + rank + # paddle.tensor([0, 1]) # Rank-0 + # paddle.tensor([1, 2]) # Rank-1 + + recv_t = paddle.empty(shape=[2], dtype=send_t.dtype) + + send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size) + recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size) + + tasks = dist.batch_isend_irecv([send_op, recv_op]) + + for task in tasks: + task.wait() + + print(recv_t) + # paddle.tensor([1, 2]) # Rank-0 + # paddle.tensor([0, 1]) # Rank-1 + """ + _check_p2p_op_list(p2p_op_list) + group = p2p_op_list[0].group + if group is not None and not group.is_member(): + return + + if in_dygraph_mode(): + group = _get_default_group() if group is None else group + backend = _group_map_backend[group] + tasks = [] + with _with_batch_p2p_guard(backend): + for p2p_op in p2p_op_list: + op = p2p_op.op + tensor = p2p_op.tensor + peer = p2p_op.peer + comm_group = p2p_op.group + task = op(tensor, peer, comm_group) + if task is not None: + tasks.append(task) + return tasks + else: + raise RuntimeError("Don't support static graph mode currently.") + + +def reduce_scatter(tensor, + tensor_list, + op=ReduceOp.SUM, + group=None, + use_calc_stream=True): + """ + Reduces, then scatters a list of tensors to all processes in a group + + Args: + tensor (Tensor): Output tensor. + tensor_list (list[Tensor]): List of tensors to reduce and scatter. + op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM. + group (Group, optional): The group instance return by new_group or None for global + default group. Default: None. + use_calc_stream (bool, optional): Whether this op should be an async op. + + Returns: + Async task handle, if use_calc_stream is set to False. + None, if use_calc_stream or if not part of the group. + + Warning: + This API only supports the dygraph mode. + + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + t1 = paddle.to_tensor([0, 1]) + t2 = paddle.to_tensor([2, 3]) + else: + t1 = paddle.to_tensor([4, 5]) + t2 = paddle.to_tensor([6, 7]) + + tensor_list = [t1, t2] + + output = paddle.empty(shape=[2], dtype=tensor_list[0].dtype) + dist.reduce_scatter(output, tensor_list) + + print(output) + # [4, 6] # Rank-0 + # [8, 10] # Rank-1 + + """ + _check_single_tensor(tensor, "tensor") + _check_tensor_list(tensor_list, "tensor_list") + + if group is not None and not group.is_member(): + return + + if in_dygraph_mode(): + op_type = _get_reduce_op(op, "reduce_scatter") + group = _get_default_group() if group is None else group + + temp = paddle.concat(tensor_list, axis=0) + task = group.process_group._reduce_scatter_base(tensor, temp, op_type) + if use_calc_stream: + task.wait() + return None + else: + return task + else: + raise RuntimeError("Don't support static graph mode currently.") + + +def _reduce_scatter_base(output, + input, + op=ReduceOp.SUM, + group=None, + use_calc_stream=True): + """ + Reduces, then scatters a flattened tensor to all processes in a group. + + Args: + output (Tensor): Output tensor. + input (Tensor): Input tensor that is of size output tensor size times world size + op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream (False). + Default to True. + Returns: + Async task handle, if use_calc_stream is set to False. + None, if use_calc_stream or if not part of the group. + + Examples: + .. code-block:: python + + # required: distributed + + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + rank = dist.get_rank() + world_size = dist.get_world_size() + + input = paddle.arange(4) + rank + # [0, 1, 2, 3] # Rank-0 + # [1, 2, 3, 4] # Rank-1 + + output = paddle.empty(shape=[2], dtype=input.dtype) + paddle.distributed.collective._reduce_scatter_base(output, input) + print(output) + # [1, 3] # Rank-0 + # [5, 7] # Rank-1 + + """ + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + + if group is not None and not group.is_member(): + return + + if in_dygraph_mode(): + op_type = _get_reduce_op(op, "_reduce_scatter_base") + group = _get_default_group() if group is None else group + task = group.process_group._reduce_scatter_base(output, input, op_type) + if use_calc_stream: + task.wait() + return None + else: + return task + else: + raise RuntimeError("Don't support static graph mode currently.") diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 52d19ae52b2bafca056ac4906d5849bb319842fb..e95b771fe6f6a428d9c5a5b4f005e20792c5fa14 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -42,6 +42,7 @@ from paddle.distributed.collective import _set_default_backend from paddle.distributed.collective import _set_default_store from paddle.distributed.collective import _new_process_group_impl from paddle.distributed.collective import Group +from paddle.distributed.collective import _set_group_map_backend __all__ = [] @@ -257,6 +258,7 @@ def init_parallel_env(): name=_default_group_name) _set_group_map_by_name(_default_group_name, group) _set_group_map(0, group) + _set_group_map_backend(group, backend) parallel_helper._set_parallel_ctx(True) paddle.distributed.barrier(group=group) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 06bec07d7acaf778367f2082170b89f230ac5d1c..606f39c5e3b424d291773f492a17d872fd8104b0 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -72,7 +72,10 @@ list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard) list(APPEND DIST_TEST_OPS test_auto_parallel_save_load) list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert) list(APPEND DIST_TEST_OPS test_collective_process_group) +list(APPEND DIST_TEST_OPS test_collective_alltoall_single) list(APPEND DIST_TEST_OPS test_eager_dist_api) +list(APPEND DIST_TEST_OPS test_collective_batch_isend_irecv) +list(APPEND DIST_TEST_OPS test_collective_reduce_scatter) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) @@ -334,7 +337,11 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_auto_parallel_save_load) list(REMOVE_ITEM TEST_OPS test_auto_parallel_autoconvert) list(REMOVE_ITEM TEST_OPS test_collective_process_group) + list(REMOVE_ITEM TEST_OPS test_collective_alltoall_single) list(REMOVE_ITEM TEST_OPS test_eager_dist_api) + list(REMOVE_ITEM TEST_OPS test_collective_batch_isend_irecv) + list(REMOVE_ITEM TEST_OPS test_collective_reduce_scatter) + elseif(WITH_GPU) if(${CUDNN_VERSION} VERSION_LESS 7100) list(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) @@ -1569,8 +1576,10 @@ if(WITH_DISTRIBUTE set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120) set_tests_properties(test_collective_process_group PROPERTIES TIMEOUT 120) + set_tests_properties(test_collective_alltoall_single PROPERTIES TIMEOUT 60) set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 100) - + set_tests_properties(test_collective_batch_isend_irecv PROPERTIES TIMEOUT 100) + set_tests_properties(test_collective_reduce_scatter PROPERTIES TIMEOUT 100) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 200) diff --git a/python/paddle/fluid/tests/unittests/collective_alltoall_single.py b/python/paddle/fluid/tests/unittests/collective_alltoall_single.py new file mode 100644 index 0000000000000000000000000000000000000000..cb6777d20bc25bbf50fcdd0c1820fd55a18ffb4b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_alltoall_single.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest + +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle import framework + + +class TestCollectiveAllToAllSingle(unittest.TestCase): + + def setUp(self): + assert not paddle.distributed.is_initialized(), \ + "The distributed environment has not been initialized." + dist.init_parallel_env() + assert paddle.distributed.is_initialized(), \ + "The distributed environment has been initialized." + + paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + def test_collective_alltoall_single(self): + rank = dist.get_rank() + size = dist.get_world_size() + + # case 1 + input = paddle.ones([size, size], dtype='int64') * rank + output = paddle.empty([size, size], dtype='int64') + expected_output = paddle.concat( + [paddle.ones([1, size], dtype='int64') * i for i in range(size)]) + + group = dist.new_group([0, 1]) + dist.alltoall_single(input, output, group=group) + + np.testing.assert_allclose(output.numpy(), expected_output.numpy()) + dist.destroy_process_group(group) + + # case 2 + in_split_sizes = [i + 1 for i in range(size)] + out_split_sizes = [rank + 1 for i in range(size)] + + input = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank + output = paddle.empty([(rank + 1) * size, size], dtype='float32') + expected_output = paddle.concat([ + paddle.ones([rank + 1, size], dtype='float32') * i + for i in range(size) + ]) + + group = dist.new_group([0, 1]) + task = dist.alltoall_single(input, + output, + in_split_sizes, + out_split_sizes, + use_calc_stream=False, + group=group) + task.wait() + + np.testing.assert_allclose(output.numpy(), expected_output.numpy()) + dist.destroy_process_group(group) + + def tearDown(self): + dist.destroy_process_group() + assert not paddle.distributed.is_initialized(), \ + "The distributed environment has been deinitialized." + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective_batch_isend_irecv.py b/python/paddle/fluid/tests/unittests/collective_batch_isend_irecv.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa309a2bbe5dc2f28c0c9afc69f2b820939f8f5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_batch_isend_irecv.py @@ -0,0 +1,57 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest + +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle import framework + + +class TestCollectiveBatchIsendIrecv(unittest.TestCase): + + def setUp(self): + dist.init_parallel_env() + paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + def test_collective_batch_isend_irecv(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + send_t = paddle.arange(2) + rank + # paddle.tensor([0, 1]) # Rank-0 + # paddle.tensor([1, 2]) # Rank-1 + recv_t = paddle.empty(shape=[2], dtype=send_t.dtype) + send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size) + recv_op = dist.P2POp(dist.irecv, recv_t, + (rank - 1 + world_size) % world_size) + tasks = dist.batch_isend_irecv([send_op, recv_op]) + + for task in tasks: + task.wait() + + if rank == 0: + np.testing.assert_allclose(recv_t.numpy(), [1, 2]) + elif rank == 1: + np.testing.assert_allclose(recv_t.numpy(), [0, 1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective_reduce_scatter.py b/python/paddle/fluid/tests/unittests/collective_reduce_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..0e36296e4089cf17200e02f074dd2500bcc67044 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_reduce_scatter.py @@ -0,0 +1,98 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest + +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +from paddle import framework + + +class TestCollectiveReduceScatter(unittest.TestCase): + + def setUp(self): + dist.init_parallel_env() + paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) + + def test_collective_reduce_scatter_sum(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + t1 = paddle.to_tensor([0, 1]) + t2 = paddle.to_tensor([2, 3]) + else: + t1 = paddle.to_tensor([4, 5]) + t2 = paddle.to_tensor([6, 7]) + + input_list = [t1, t2] + + output = paddle.empty(shape=[2], dtype=input_list[0].dtype) + dist.reduce_scatter(output, input_list) + + if rank == 0: + np.testing.assert_allclose(output.numpy(), [4, 6]) + elif rank == 1: + np.testing.assert_allclose(output.numpy(), [8, 10]) + + def test_collective_reduce_scatter_max(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + t1 = paddle.to_tensor([0, 1], dtype="float16") + t2 = paddle.to_tensor([2, 3], dtype="float16") + else: + t1 = paddle.to_tensor([4, 5], dtype="float16") + t2 = paddle.to_tensor([6, 7], dtype="float16") + + input_list = [t1, t2] + + output = paddle.empty(shape=[2], dtype=input_list[0].dtype) + dist.reduce_scatter(output, input_list, op=dist.ReduceOp.MAX) + + if rank == 0: + np.testing.assert_allclose(output.numpy(), [4, 5]) + elif rank == 1: + np.testing.assert_allclose(output.numpy(), [6, 7]) + + def test_collective_reduce_scatter_base(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + + input = paddle.arange(4) + rank + # [0, 1, 2, 3] # Rank-0 + # [1, 2, 3, 4] # Rank-1 + + output = paddle.empty(shape=[2], dtype=input.dtype) + task = paddle.distributed.collective._reduce_scatter_base( + output, input, use_calc_stream=False) + + task.wait() + + if rank == 0: + np.testing.assert_allclose(output.numpy(), [1, 3]) + elif rank == 1: + np.testing.assert_allclose(output.numpy(), [5, 7]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_alltoall_single.py b/python/paddle/fluid/tests/unittests/test_collective_alltoall_single.py new file mode 100644 index 0000000000000000000000000000000000000000..e848404850d9ec1b7123620b08e072db27f1bd7d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_alltoall_single.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestCollectiveAllToAllSingle(TestMultipleGpus): + + def test_collective_alltoall_single(self): + self.run_mnist_2gpu('collective_alltoall_single.py', eager_mode=True) + + +if __name__ == "__main__": + os.environ["FLAGS_enable_eager_mode"] = "1" + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_batch_isend_irecv.py b/python/paddle/fluid/tests/unittests/test_collective_batch_isend_irecv.py new file mode 100644 index 0000000000000000000000000000000000000000..a93c417b99c65ba8342b456c1b2def99912bbf4d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_batch_isend_irecv.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestCollectiveBatchIsendIrecv(TestMultipleGpus): + + def test_collective_batch_isend_irecv(self): + self.run_mnist_2gpu('collective_batch_isend_irecv.py', eager_mode=True) + + +if __name__ == "__main__": + os.environ["FLAGS_enable_eager_mode"] = "1" + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_reduce_scatter.py b/python/paddle/fluid/tests/unittests/test_collective_reduce_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..93d181243b1fac8ed2eb9362425974b843445a0f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_reduce_scatter.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestCollectiveReduceScatter(TestMultipleGpus): + + def test_collective_reduce_scatter(self): + self.run_mnist_2gpu('collective_reduce_scatter.py', eager_mode=True) + + +if __name__ == "__main__": + os.environ["FLAGS_enable_eager_mode"] = "1" + unittest.main()