未验证 提交 fb55e00e 编写于 作者: 李季 提交者: GitHub

Fix alltoall (#34064)

* Fix the bug that occurred in alltoall in dygraph mode
上级 db4bd24b
...@@ -1481,7 +1481,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): ...@@ -1481,7 +1481,7 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
np_data2 = np.array([[19, 20, 21], [22, 23, 24]]) np_data2 = np.array([[19, 20, 21], [22, 23, 24]])
data1 = paddle.to_tensor(np_data1) data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2) data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_to_all([data1, data2], out_tensor_list) paddle.distributed.alltoall([data1, data2], out_tensor_list)
# out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] # out for rank 0: [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]]
# out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] # out for rank 1: [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]]
""" """
...@@ -1490,15 +1490,15 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True): ...@@ -1490,15 +1490,15 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
ring_id = 0 if group is None else group.id ring_id = 0 if group is None else group.id
temp = paddle.concat(in_tensor_list, axis=0) temp = paddle.concat(in_tensor_list, axis=0)
nranks = len(in_tensor_list)
if in_dygraph_mode(): if in_dygraph_mode():
_C_ops.alltoall_(temp, 'use_calc_stream', use_calc_stream, 'ring_id', out = _C_ops.alltoall(temp, 'use_calc_stream', use_calc_stream,
ring_id) 'ring_id', ring_id)
else: else:
op_type = 'alltoall' op_type = 'alltoall'
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=in_tensor_list[0].dtype) dtype=in_tensor_list[0].dtype)
nranks = len(in_tensor_list)
if not isinstance(in_tensor_list, list): if not isinstance(in_tensor_list, list):
raise ValueError("The type of 'in_tensor_list' for all_to_all " raise ValueError("The type of 'in_tensor_list' for all_to_all "
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
    """Dygraph runner exercising ``paddle.distributed.alltoall``.

    Launched per-rank by the collective test harness; the harness feeds
    each rank its input through ``get_model`` and compares the outputs.
    """

    def __init__(self):
        # Ring id 0 is the default communication ring for the collective op.
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program, rank, indata=None):
        """Split *indata* in two along axis 0, run alltoall, and return
        the received tensors converted to numpy arrays."""
        with fluid.program_guard(main_prog, startup_program):
            in_tensor = paddle.to_tensor(indata)
            shards = paddle.split(in_tensor, 2, axis=0)
            received = []
            # alltoall fills `received` with one tensor per rank.
            paddle.distributed.alltoall(shards, received)
            return [t.numpy() for t in received]
if __name__ == "__main__":
    # Entry point used by the collective test harness; "alltoall" names
    # the collective operation under test.
    runtime_main(TestCollectiveAllToAllAPI, "alltoall")
...@@ -29,6 +29,13 @@ class TestCollectiveAllToAllAPI(TestDistBase): ...@@ -29,6 +29,13 @@ class TestCollectiveAllToAllAPI(TestDistBase):
def test_alltoall_nccl(self): def test_alltoall_nccl(self):
self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl") self.check_with_place("collective_alltoall_api.py", "alltoall", "nccl")
def test_alltoall_nccl_dygraph(self):
self.check_with_place(
"collective_alltoall_api_dygraph.py",
"alltoall",
"nccl",
static_mode="0")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册