diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index 8a738718c35469c0e74fba4e669651c543822aeb..e0b28fc788c9c75aa30f3218647dfcc09cc08503 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -185,7 +185,7 @@ def reduce_sum(
         output = reduce_sum(input)
         # Rank 0 # output: Tensor([1])
         # Rank 1 # output: None
-
+
         input = Tensor([rank])
         group = Group([1, 0]) # first rank is root
         output = reduce_sum(input, group)
@@ -248,7 +248,7 @@ def broadcast(
         output = broadcast(input)
         # Rank 0 # output: Tensor([0])
         # Rank 1 # output: Tensor([0])
-
+
         input = Tensor([rank])
         group = Group([1, 0]) # first rank is root
         output = broadcast(input, group)
@@ -276,7 +276,7 @@ def _bcast_param(


 def all_gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
 ) -> Tensor:
     r"""
     Gather tensors across the specified group and concat them at first dimension.
@@ -290,6 +290,8 @@
             None default device means the device of inp will be used.
             Specify "gpu0:1" to execute this operator on diffrent cuda stream,
             1 is stream id, and default stream id is 0.
+        axis: The concat axis of the collective_comm result.
+            The default axis is 0.

     Returns:
         Result tensor.
@@ -304,7 +306,7 @@
         output = all_gather(input)
         # Rank 0 # output: Tensor([0 1])
         # Rank 1 # output: Tensor([0 1])
-
+
         input = Tensor([rank])
         group = Group([1, 0])
         output = all_gather(input, group)
@@ -313,11 +315,28 @@
         # Rank 1 # output: Tensor([1 0])
     """
     mode = CollectiveComm.Mode.ALL_GATHER
-    return collective_comm(inp, mode, group, device)
+    out = collective_comm(inp, mode, group, device)
+    if axis == 0:
+        return out
+    else:
+        group_size = group.size if group is not None else 1
+        transformed_shape = list(inp._tuple_shape)
+        transformed_shape[axis] *= group_size
+        n, *shp = out._tuple_shape
+        index = (
+            [_ for _ in range(1, axis)]
+            + [axis, 0]
+            + [_ for _ in range(axis + 1, out.ndim + 1)]
+        )
+        return (
+            out.reshape(group_size, n // group_size, *shp)
+            .transpose(index)
+            .reshape(transformed_shape)
+        )


 def reduce_scatter_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
 ) -> Tensor:
     r"""
     Reduce tensors across the specified group by sum and split them at first dimension.
@@ -331,6 +350,8 @@
             None default device means the device of inp will be used.
             Specify "gpu0:1" to execute this operator on diffrent cuda stream,
             1 is stream id, and default stream id is 0.
+        axis: The split axis of the collective_comm result.
+            The default axis is 0, i.e. the data is split along the first dimension.

     Returns:
         Split tensor.
@@ -345,7 +366,7 @@ def reduce_scatter_sum(
         output = reduce_scatter_sum(input)
         # Rank 0 # output: Tensor([0])
         # Rank 1 # output: Tensor([2])
-
+
         input = Tensor([0 1])
         group = Group([1, 0])
         output = reduce_scatter_sum(input, group)
@@ -353,6 +374,23 @@
         # Rank 1 # output: Tensor([0])
     """
+    group_size = group.size if group is not None else 1
+    assert (
+        list(inp._tuple_shape)[axis] % group_size == 0
+    ), "current axis: {} can't be divided by group size".format(axis)
+    if axis != 0:
+        k_new_shape = list(inp._tuple_shape)
+        k_new_shape[axis] //= group_size
+        k_new_shape[0] *= group_size
+        new_shape = list(inp._tuple_shape)
+        new_shape[axis] //= group_size
+        new_shape.insert(axis, group_size)
+        index = (
+            [axis]
+            + [_ for _ in range(0, axis)]
+            + [_ for _ in range(axis + 1, inp.ndim + 1)]
+        )
+        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
     mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
     return collective_comm(inp, mode, group, device)
@@ -480,7 +518,7 @@ class _Gather(Function):


 def gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
 ) -> Tensor:
     r"""
     Gather tensors across the specified group.
@@ -495,7 +533,8 @@
             None default device means the device of inp will be used.
             Specify "gpu0:1" to execute this operator on diffrent cuda stream,
             1 is stream id, and default stream id is 0.
-
+        axis: The concat axis of the collective_comm result.
+            The default axis is 0.

     Returns:
         Result tensor if in root process, None if in other process
@@ -509,7 +548,7 @@
         output = gather(input)
         # Rank 0 # output: Tensor([0 1])
         # Rank 1 # output: None
-
+
         input = Tensor([rank])
         group = Group([1, 0]) # first rank is root
         output = gather(input, group)
@@ -517,12 +556,33 @@
         # Rank 1 # output: Tensor([1 0])
     """
+    assert (
+        axis < inp.ndim
+    ), "the given axis exceeds the dim of the tensor, the tensor shape is {}".format(
+        inp.shape
+    )
     op = _Gather(group, device)
     (out,) = apply(op, inp)

     if group.rank == 0:
-        return out
+        if axis == 0:
+            return out
+        else:
+            group_size = group.size
+            transformed_shape = list(inp._tuple_shape)
+            transformed_shape[axis] *= group_size
+            n, *shp = out._tuple_shape
+            index = (
+                [_ for _ in range(1, axis)]
+                + [axis, 0]
+                + [_ for _ in range(axis + 1, out.ndim + 1)]
+            )
+            return (
+                out.reshape(group_size, n // group_size, *shp)
+                .transpose(index)
+                .reshape(transformed_shape)
+            )
     else:
         _save_output_for_autodiff(inp, out)
@@ -545,7 +605,7 @@ class _Scatter(Function):


 def scatter(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
 ) -> Tensor:
     r"""
     Split tensor in root process at first dimension.
@@ -559,6 +619,8 @@
             None default device means the device of inp will be used.
             Specify "gpu0:1" to execute this operator on diffrent cuda stream,
             1 is stream id, and default stream id is 0.
+        axis: The split axis of the collective_comm result.
+            The default axis is 0.

     Returns:
         Split tensor.
@@ -573,7 +635,7 @@ def scatter(
         output = scatter(input)
         # Rank 0 # output: Tensor([0])
         # Rank 1 # output: Tensor([1])
-
+
         input = Tensor([0 1]) + rank*2
         group = Group([1, 0]) # first rank is root
         output = scatter(input, group)
@@ -588,13 +650,35 @@
     _bcast_tracer_state(group, inp)

+    assert (
+        list(inp._tuple_shape)[axis] % group.size == 0
+    ), "current axis: {} can't be divided by group size".format(axis)
+
+    if axis != 0:
+        group_size = group.size
+        k_new_shape = list(inp._tuple_shape)
+        k_new_shape[axis] //= group_size
+        k_new_shape[0] *= group_size
+        new_shape = list(inp._tuple_shape)
+        new_shape[axis] //= group_size
+        new_shape.insert(axis, group_size)
+        index = (
+            [axis]
+            + [_ for _ in range(0, axis)]
+            + [_ for _ in range(axis + 1, inp.ndim + 1)]
+        )
+        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
     op = _Scatter(group, device)
     (out,) = apply(op, inp)
     return out


 def all_to_all(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor,
+    group: Optional[Group] = WORLD,
+    device: Optional[str] = None,
+    split_axis: int = 0,
+    concat_axis: int = 0,
 ) -> Tensor:
     r"""
     Each process scatter input tensor to all processes and return gathered tensor.
@@ -608,6 +692,10 @@
             None default device means the device of inp will be used.
             Specify "gpu0:1" to execute this operator on diffrent cuda stream,
             1 is stream id, and default stream id is 0.
+        split_axis: The axis along which collective_comm will split the data.
+            The default axis is 0.
+        concat_axis: The axis along which collective_comm will concat the data.
+            The default axis is 0.

     Returns:
         Result tensor.
@@ -622,7 +710,7 @@
         output = all_to_all(input)
         # Rank 0 # output: Tensor([0 2])
         # Rank 1 # output: Tensor([1 3])
-
+
         input = Tensor([0 1]) + rank*2
         group = Group([1, 0])
         output = all_to_all(input, group)
@@ -630,8 +718,46 @@
         # Rank 1 # output: Tensor([2 1])
     """
+    group_size = group.size if group is not None else 1
+    assert (
+        list(inp._tuple_shape)[split_axis] % group_size == 0
+    ), "current axis: {} can't be divided by group size".format(split_axis)
+    origin_shape = inp._tuple_shape
+    if split_axis != 0:
+        k_new_shape = list(inp._tuple_shape)
+        k_new_shape[split_axis] //= group_size
+        k_new_shape[0] *= group_size
+        new_shape = list(inp._tuple_shape)
+        new_shape[split_axis] //= group_size
+        new_shape.insert(split_axis, group_size)
+        index = (
+            [split_axis]
+            + [_ for _ in range(0, split_axis)]
+            + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
+        )
+        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
+
     mode = CollectiveComm.Mode.ALL_TO_ALL
-    return collective_comm(inp, mode, group, device)
+    out = collective_comm(inp, mode, group, device)
+
+    if concat_axis == 0:
+        return out
+
+    transformed_shape = list(origin_shape)
+    transformed_shape[concat_axis] *= group_size
+    transformed_shape[split_axis] //= group_size
+
+    n, *shp = out._tuple_shape
+    index = (
+        [_ for _ in range(1, concat_axis)]
+        + [concat_axis, 0]
+        + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
+    )
+    return (
+        out.reshape(group_size, n // group_size, *shp)
+        .transpose(index)
+        .reshape(transformed_shape)
+    )


 class _SendRecvGroup:
diff --git a/imperative/python/test/unit/functional/test_functional_distributed_axis.py b/imperative/python/test/unit/functional/test_functional_distributed_axis.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9a8e093da0cd90c81cfac2995112811f7da72f7
--- /dev/null
+++ b/imperative/python/test/unit/functional/test_functional_distributed_axis.py
@@ -0,0 +1,142 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine.distributed as dist
+from megengine import tensor
+from megengine.distributed.functional import (
+    all_gather,
+    all_to_all,
+    gather,
+    reduce_scatter_sum,
+    scatter,
+)
+from megengine.jit import trace
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.parametrize("shape", [(2, 3), (8, 10), (99, 77), (2, 2, 2, 2)], ids=str)
+@pytest.mark.parametrize("symbolic", [False, True], ids=str)
+@pytest.mark.parametrize("axis", [0, 1], ids=str)
+@pytest.mark.isolated_distributed
+def test_all_gather(shape, symbolic, axis):
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
+
+        def func():
+            output = all_gather(inp, axis=axis)
+            return output
+
+        func = trace(symbolic=symbolic)(func)
+        output = func()
+        assert np.allclose(output.numpy(), expect[rank])
+
+    x = np.random.random_sample(shape).astype("float32")
+    y = np.random.random_sample(shape).astype("float32")
+    z = np.concatenate((x, y), axis=axis)
+    data = (x, y)
+    expect = (z, z)
+    worker(data, expect)
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.parametrize(
+    "shape,symbolic", [((2, 4, 6, 8), False), ((2, 4, 6, 8), True)], ids=str
+)
+@pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
+@pytest.mark.isolated_distributed
+def test_reduce_scatter_sum(shape, symbolic, axis):
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
+
+        def func():
+            output = reduce_scatter_sum(inp, axis=axis)
+            return output
+
+        func = trace(symbolic=symbolic)(func)
+        output = func()
+        assert np.allclose(output.numpy(), expect[rank])
+
+    x = np.random.random_sample(shape).astype("float32")
+    y = np.random.random_sample(shape).astype("float32")
+    z = x + y
+    data = (x, y)
+    z = np.split(z, 2, axis=axis)
+    z = np.concatenate(z, axis=0)
+    expect = (z[: z.shape[0] // 2], z[z.shape[0] // 2 :])
+    worker(data, expect)
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.parametrize(
+    "shape,symbolic", [((2, 4, 6, 8), True), ((2, 4, 6, 8), False)], ids=str
+)
+@pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
+@pytest.mark.isolated_distributed
+def test_scatter(shape, symbolic, axis):
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
+
+        def func():
+            output = scatter(inp, axis=axis)
+            return output
+
+        func = trace(symbolic=symbolic)(func)
+        output = func()
+        assert np.allclose(output.numpy(), expect[rank])
+
+    x = np.random.random_sample(shape).astype("float32")
+    y = x + 1
+    data = (x, y)
+    _x = np.split(x, 2, axis=axis)
+    _x = np.concatenate(_x, axis=0)
+    expect = (_x[: _x.shape[0] // 2], _x[_x.shape[0] // 2 :])
+    worker(data, expect)
+
+
+@pytest.mark.require_ngpu(2)
+@pytest.mark.parametrize("shape", [(2, 4, 6, 8)], ids=str)
+@pytest.mark.parametrize("symbolic", [False, True], ids=str)
+@pytest.mark.parametrize(
+    "split_axis,concat_axis", [(0, 1), (1, 0), (2, 0), (0, 2), (2, 3)], ids=str
+)
+@pytest.mark.isolated_distributed
+def test_all_to_all(shape, symbolic, split_axis, concat_axis):
+    @dist.launcher(n_gpus=2)
+    def worker(data):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
+
+        def func():
+            all_to_all_output = all_to_all(
+                inp, split_axis=split_axis, concat_axis=concat_axis
+            )
+            gather_C = gather(inp, axis=concat_axis)
+            gather_B = gather(all_to_all_output, axis=split_axis)
+            if rank == 0:
+                return gather_B, gather_C
+            return all_to_all_output
+
+        func = trace(symbolic=symbolic)(func)
+        ret = func()
+        if rank == 0:
+            assert np.allclose(ret[0], ret[1])
+
+    x = np.random.random_sample(shape).astype("float32")
+    y = np.random.random_sample(shape).astype("float32")
+    data = (x, y)
+    worker(data)
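
The axis support added above rests on one layout trick: the collective primitives still split or concatenate only along axis 0, and a reshape/transpose/reshape pair moves the data between axis 0 and the requested axis. Below is a self-contained NumPy sketch of the concat side (not part of the patch; the group size, shapes, and variable names are chosen purely for illustration, and it assumes a non-zero axis, since axis 0 is returned unchanged).

# NumPy sketch of the reshape/transpose/reshape step that turns the
# collective's axis-0 concatenation into a concatenation along `axis`.
import numpy as np

group_size, axis = 2, 1
parts = [np.random.rand(3, 4, 5).astype("float32") for _ in range(group_size)]

# What ALL_GATHER hands back: per-rank tensors stacked along axis 0.
out = np.concatenate(parts, axis=0)

n, *shp = out.shape
# Permutation that moves the leading group dim so it sits just before the
# per-rank dim corresponding to `axis` (reshaped dim k + 1 <-> per-rank axis k).
index = list(range(1, axis)) + [axis, 0] + list(range(axis + 1, out.ndim + 1))

transformed_shape = list(parts[0].shape)
transformed_shape[axis] *= group_size  # final shape: concatenated along `axis`

res = (
    out.reshape(group_size, n // group_size, *shp)
    .transpose(index)
    .reshape(transformed_shape)
)
assert np.array_equal(res, np.concatenate(parts, axis=axis))

On the split side, reduce_scatter_sum, scatter, and all_to_all apply the inverse permutation to the input before the collective runs, which is what the new_shape/k_new_shape blocks in the patch compute.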