From 7bd02d247728a11f39eb87251fbe84b7ee080e3e Mon Sep 17 00:00:00 2001
From: Wen Sun <35923278+HermitSun@users.noreply.github.com>
Date: Mon, 29 Aug 2022 19:45:40 +0800
Subject: [PATCH] Completes basic dtypes for all_reduce api in eager mode
 (#45440)

---
 python/paddle/distributed/collective.py  | 32 ++++++++---------
 .../collective_allreduce_api_dygraph.py   | 36 +++++++++++++++++++
 .../test_collective_allreduce_api.py      | 25 +++++++++++++
 3 files changed, 75 insertions(+), 18 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py

diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 0879467b72e..e2dae09b48a 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -775,8 +775,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
     """

     Reduce a tensor over all ranks so that all get the result.
-    As shown below, 4 GPUs each start 4 processes and the data on each GPU is represnted
-    by the GPU number. The reduce operator is sum. Through all_reduce operator,
+    As shown below, one process is started with a GPU and the data of this process is represented
+    by its group rank. The reduce operator is sum. Through all_reduce operator,
     each GPU will have the sum of the data from all GPUs.

     .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
@@ -786,8 +786,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):

     Args:
         tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
-            should be float16, float32, float64, int32 or int64.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
+            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
+        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
         group (Group): The group instance return by new_group or None for global default group.
         use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
             Default to True.
@@ -799,21 +799,16 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
         .. code-block:: python

            # required: distributed
-            import numpy as np
            import paddle
-            from paddle.distributed import ReduceOp
            from paddle.distributed import init_parallel_env

            paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
            init_parallel_env()
            if paddle.distributed.ParallelEnv().local_rank == 0:
-                np_data = np.array([[4, 5, 6], [4, 5, 6]])
+                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
-                np_data = np.array([[1, 2, 3], [1, 2, 3]])
-                data = paddle.to_tensor(np_data)
+                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            paddle.distributed.all_reduce(data)
-            out = data.numpy()
-            # [[5, 7, 9], [5, 7, 9]]
     """
     if group is not None and not group.is_member():
         return
@@ -849,9 +844,10 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
         else:
             raise ValueError("Unknown parameter: {}.".format(op))

-    check_variable_and_dtype(
-        tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
-        'all_reduce')
+    check_variable_and_dtype(tensor, 'tensor', [
+        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
+        'bool'
+    ], 'all_reduce')
     if op == ReduceOp.SUM:
         op_type = 'c_allreduce_sum'
     elif op == ReduceOp.MAX:
@@ -888,7 +884,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
         tensor (Tensor): The output Tensor for the destination and the input Tensor otherwise. Its data type
             should be float16, float32, float64, int32 or int64.
         dst (int): The destination rank id.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
+        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default value is ReduceOp.SUM.
         group (Group): The group instance return by new_group or None for global default group.
         use_calc_stream (bool): Wether to use calculation stream (True) or communication stream (False).
             Default to True.
@@ -984,7 +980,7 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
     """

     Gather tensors from all participators and all get the result. As shown
-    below, 4 GPUs each start 4 processes and the data on each GPU is represnted
+    below, 4 GPUs each starts 4 processes and the data on each GPU is represented
     by the GPU number. Through the all_gather operator, each GPU will have data
     from all GPUs.

@@ -2581,7 +2577,7 @@ def reduce_scatter(tensor,
     Args:
         tensor (Tensor): Output tensor.
         tensor_list (list[Tensor]): List of tensors to reduce and scatter.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
+        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
         group (Group, optional): The group instance return by new_group or None for global
             default group. Default: None.
         use_calc_stream (bool, optional): Whether this op should be an async op.
@@ -2654,7 +2650,7 @@ def _reduce_scatter_base(output,
     Args:
         output (Tensor): Output tensor.
         input (Tensor): Input tensor that is of size output tensor size times world size
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
+        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream (False).
diff --git a/python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py
new file mode 100644
index 00000000000..83588d450a7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/collective_allreduce_api_dygraph.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import paddle
+import paddle.fluid as fluid
+import unittest
+import test_collective_api_base as test_base
+
+
+class TestCollectiveAllreduceAPI(test_base.TestCollectiveAPIRunnerBase):
+
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program, rank, indata=None):
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = paddle.to_tensor(indata)
+            paddle.distributed.all_reduce(tindata)
+            return [tindata.numpy()]
+
+
+if __name__ == "__main__":
+    test_base.runtime_main(TestCollectiveAllreduceAPI, "allreduce")
diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py
index 5ec08aa72e2..2598606fc9c 100644
--- a/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py
+++ b/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py
@@ -41,6 +41,31 @@ class TestCollectiveAllreduceAPI(TestDistBase):
         self.check_with_place("collective_allreduce_api.py", "allreduce",
                               "gloo", "2")

+    def test_allreduce_nccl_dygraph(self):
+        dtypes_to_test = [
+            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
+            'bool'
+        ]
+        for dtype in dtypes_to_test:
+            self.check_with_place("collective_allreduce_api_dygraph.py",
+                                  "allreduce",
+                                  "nccl",
+                                  static_mode="0",
+                                  dtype=dtype)
+
+    def test_allreduce_gloo_dygraph(self):
+        dtypes_to_test = [
+            'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
+            'bool'
+        ]
+        for dtype in dtypes_to_test:
+            self.check_with_place("collective_allreduce_api_dygraph.py",
+                                  "allreduce",
+                                  "gloo",
+                                  "2",
+                                  static_mode="0",
+                                  dtype=dtype)
+

 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
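For reference, a minimal eager-mode sketch of what the extended dtype support allows, assuming a two-GPU NCCL setup; the script name demo_allreduce_int8.py and the launch command are illustrative, not part of the patch:

# demo_allreduce_int8.py -- illustrative only; assumes two GPUs and NCCL.
# Assumed launch: python -m paddle.distributed.launch --gpus=0,1 demo_allreduce_int8.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()  # join the default process group

# Each rank builds a tensor in one of the dtypes the patch adds (int8 here).
if dist.get_rank() == 0:
    data = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='int8')
else:
    data = paddle.to_tensor([[7, 8, 9], [10, 11, 12]], dtype='int8')

# In-place elementwise sum across all ranks; int8/uint8/bool now pass the dtype check.
dist.all_reduce(data, op=dist.ReduceOp.SUM)
print(data)  # every rank ends up with the summed tensor

The new collective_allreduce_api_dygraph.py runner exercises the same call path through test_collective_api_base, once per dtype in dtypes_to_test, for both the NCCL and Gloo backends.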