Unverified commit fb55e00e authored by 李季 (Li Ji), committed by GitHub

Fix alltoall (#34064)

Fix alltoall (#34064)

* fix the bug that happened in alltoall in dygraph mode
Parent commit: db4bd24b
......@@ -1481,7 +1481,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_to_all([data1, data2], out_tensor_list)
paddle.distributed.alltoall([data1, data2], out_tensor_list)
# out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
# out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
"""
......@@ -1490,15 +1490,15 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
ring_id = 0 if group is None else group.id
temp = paddle.concat(in_tensor_list, axis=0)
nranks = len(in_tensor_list)
if in_dygraph_mode():
_C_ops.alltoall_(temp, 'use_calc_stream', use_calc_stream, 'ring_id',
ring_id)
out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id)
else:
op_type = 'alltoall'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(
dtype=in_tensor_list[0].dtype)
nranks = len(in_tensor_list)
if not isinstance(in_tensor_list, list):
raise ValueError("The type of 'in_tensor_list' for all_to_all "
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
    """Dygraph runner exercising ``paddle.distributed.alltoall``.

    The harness (``runtime_main``) instantiates this class on each rank,
    feeds it input data, and compares the per-rank outputs.
    """

    def __init__(self):
        # All ranks communicate over the single default ring (id 0).
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, indata=None):
        """Split `indata` into two halves, all-to-all them, return numpy arrays.

        ``alltoall`` fills the (initially empty) output list in place in
        dygraph mode; each received tensor is converted to numpy for the
        harness to compare across ranks.
        """
        with fluid.program_guard(main_prog, startup_program):
            halves = paddle.split(paddle.to_tensor(indata), 2, axis=0)
            received = []
            paddle.distributed.alltoall(halves, received)
            return [tensor.numpy() for tensor in received]
# Entry point: hand this runner class to the collective-test harness,
# which drives it under the "alltoall" collective-op name on each rank.
if __name__ == "__main__":
    runtime_main(TestCollectiveAllToAllAPI, "alltoall")
......@@ -29,6 +29,13 @@ class TestCollectiveAllToAllAPI(TestDistBase):
def test_alltoall_nccl(self):
    # Static-graph alltoall over the NCCL backend, driven by the
    # collective_alltoall_api.py runner script.
    self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl")
def test_alltoall_nccl_dygraph(self):
    # Dygraph (imperative) alltoall over NCCL; static_mode="0" makes the
    # spawned runner script execute eagerly instead of building a graph.
    self.check_with_place(
        "collective_alltoall_api_dygraph.py",
        "alltoall",
        "nccl",
        static_mode="0")
# Standard unittest entry point: discover and run the test methods above.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.