diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 6dfd1b2a2aa67df15870dfe31b67b52eb31e36ce..cfba9f656b333a743fbf5890d5928a3178faede9 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -1481,7 +1481,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
             np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
             data1 = paddle.to_tensor(np_data1)
             data2 = paddle.to_tensor(np_data2)
-            paddle.distributed.all_to_all([data1, data2], out_tensor_list)
+            paddle.distributed.alltoall([data1, data2], out_tensor_list)
             # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
             # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
     """
@@ -1490,15 +1490,15 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
 
     ring_id = 0 if group is None else group.id
     temp = paddle.concat(in_tensor_list, axis=0)
+    nranks = len(in_tensor_list)
     if in_dygraph_mode():
-        _C_ops.alltoall_(temp, 'use_calc_stream', use_calc_stream, 'ring_id',
-                         ring_id)
+        out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
+                              'ring_id', ring_id)
     else:
         op_type = 'alltoall'
         helper = LayerHelper(op_type, **locals())
         out = helper.create_variable_for_type_inference(
             dtype=in_tensor_list[0].dtype)
-        nranks = len(in_tensor_list)
         if not isinstance(in_tensor_list, list):
             raise ValueError("The type of 'in_tensor_list' for all_to_all "
                              "should be list.")
diff --git a/python/paddle/fluid/tests/unittests/collective_alltoall_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective_alltoall_api_dygraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a59aef071f8239cac1374aff9ee9fe36faf12c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_alltoall_api_dygraph.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
+
+
+class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program, rank, indata=None):
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = paddle.to_tensor(indata)
+            tindata = paddle.split(tindata, 2, axis=0)
+            tout_data = []
+            paddle.distributed.alltoall(tindata, tout_data)
+            output_data = []
+            for data in tout_data:
+                output_data.append(data.numpy())
+            return output_data
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveAllToAllAPI, "alltoall")
diff --git a/python/paddle/fluid/tests/unittests/test_collective_alltoall_api.py b/python/paddle/fluid/tests/unittests/test_collective_alltoall_api.py
index fab975a9d6249f274952e52bb59fd5f61badc116..bb6a8c29bc508fe0bf8d647a9b75f33add7e762e 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_alltoall_api.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_alltoall_api.py
@@ -29,6 +29,13 @@ class TestCollectiveAllToAllAPI(TestDistBase):
     def test_alltoall_nccl(self):
         self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl")
 
+    def test_alltoall_nccl_dygraph(self):
+        self.check_with_place(
+            "collective_alltoall_api_dygraph.py",
+            "alltoall",
+            "nccl",
+            static_mode="0")
+
 
 if __name__ == '__main__':
     unittest.main()
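For reference, below is a minimal end-to-end sketch of the dygraph path this patch enables, adapted from the alltoall docstring example in the first hunk. It assumes two GPUs launched via `python -m paddle.distributed.launch --gpus=0,1 demo.py`; the script name and the helper `demo_alltoall` are illustrative, not part of the patch.

# demo.py -- minimal dygraph alltoall sketch (assumes 2 ranks).
import numpy as np
import paddle
from paddle.distributed import init_parallel_env


def demo_alltoall():
    # Set up the parallel environment; dygraph is Paddle's default mode.
    init_parallel_env()
    out_tensor_list = []
    if paddle.distributed.get_rank() == 0:
        np_data1 = np.array([[1, 2, 3], [4, 5, 6]])
        np_data2 = np.array([[7, 8, 9], [10, 11, 12]])
    else:
        np_data1 = np.array([[13, 14, 15], [16, 17, 18]])
        np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
    data1 = paddle.to_tensor(np_data1)
    data2 = paddle.to_tensor(np_data2)
    # With this patch, the dygraph branch runs _C_ops.alltoall and
    # populates out_tensor_list with the slices received from each rank.
    paddle.distributed.alltoall([data1, data2], out_tensor_list)
    # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
    # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]


if __name__ == "__main__":
    demo_alltoall()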