From ea8eb4cf721ebea6a4b35432eb5aff30b9e4cdeb Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 13 Jan 2021 17:27:28 +0800
Subject: [PATCH] feat(mge/distributed): scalar support for distributed
 functions

GitOrigin-RevId: 53f3575baf58d709d752618e90cdec6f93b631e5
---
 .../megengine/distributed/functional.py       | 21 +++++++++-
 .../functional/test_functional_distributed.py | 42 +++++++++----------
 2 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index 32a24ff14..02efd29c1 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -11,6 +11,7 @@ from typing import Optional, Tuple
 from ..core._imperative_rt.core2 import apply
 from ..core.autodiff.grad import _grad_manager_dict
 from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
+from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..tensor import Tensor
 from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
@@ -50,7 +51,18 @@ def collective_comm(inp, mode, group, device):
         backend=get_backend(),
         comp_node=device,
     )
-    return apply(op, inp)[0]
+    (result,) = apply(op, inp)
+    # assume all workers have homogeneous shape
+    if mode in (
+        CollectiveComm.Mode.REDUCE_SUM,
+        CollectiveComm.Mode.BROADCAST,
+        CollectiveComm.Mode.ALL_REDUCE_SUM,
+        CollectiveComm.Mode.ALL_REDUCE_MAX,
+        CollectiveComm.Mode.ALL_REDUCE_MIN,
+    ):
+        if isscalar(inp):
+            setscalar(result)
+    return result
 
 
 def reduce_sum(
@@ -289,6 +301,11 @@ def remote_recv(
         g.wrt(inp)
         g._refkeeper.append(inp)
 
+    _isscalar = False
+    if len(shape) == 0:
+        shape = (1,)
+        _isscalar = True
+
     op = RemoteRecv()
     op.key = key
     op.cn = device
@@ -298,4 +315,6 @@ def remote_recv(
     op.rank_from = src_rank
 
     (ret,) = apply(_RemoteRecv(op), inp)
+    if _isscalar:
+        setscalar(ret)
     return ret
diff --git a/imperative/python/test/unit/functional/test_functional_distributed.py b/imperative/python/test/unit/functional/test_functional_distributed.py
index 7342b1642..09125d562 100644
--- a/imperative/python/test/unit/functional/test_functional_distributed.py
+++ b/imperative/python/test/unit/functional/test_functional_distributed.py
@@ -13,7 +13,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
-from megengine import Parameter, Tensor, tensor
+from megengine import Parameter, tensor
 from megengine.core._imperative_rt.core2 import sync
 from megengine.device import get_default_device, set_default_device
 from megengine.distributed.helper import get_device_count_by_fork
@@ -53,14 +53,14 @@ def test_reduce_sum():
             assert np.allclose(output.numpy(), 0)
 
     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = x + y
         data = (x, y)
         expect = (z, None)
         worker(data, expect)
 
-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)
 
 
@@ -81,13 +81,13 @@ def test_broadcast():
         assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
         y = x + 1
         data = (x, y)
         expect = (x, x)
         worker(data, expect)
 
-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)
 
 
@@ -164,14 +164,14 @@ def test_all_reduce_sum():
         assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = x + y
         data = (x, y)
         expect = (z, z)
         worker(data, expect)
 
-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)
 
 
@@ -192,14 +192,14 @@ def test_all_reduce_max():
         assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = np.maximum(x, y)
         data = (x, y)
         expect = (z, z)
         worker(data, expect)
 
-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)
 
 
@@ -220,14 +220,14 @@ def test_all_reduce_min():
         assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = np.minimum(x, y)
         data = (x, y)
         expect = (z, z)
         worker(data, expect)
 
-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)
 
 
@@ -327,18 +327,18 @@ def test_all_to_all():
 @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_io_remote():
-    val = np.random.rand(4, 5).astype(np.float32)
-
     @dist.launcher(n_gpus=2)
-    def worker():
+    def worker(val, shape):
         rank = dist.get_rank()
         if rank == 0:  # remote send
-            x = Tensor(val, device="gpu0")
+            x = tensor(val, device="gpu0")
             remote_send(x, 1)
             sync()
         else:  # remote recv
-            y = remote_recv(0, val.shape, val.dtype)
+            y = remote_recv(0, shape, np.float32)
             assert y.device == "gpu1"
             np.testing.assert_almost_equal(val, y.numpy())
 
-    worker()
+    for shape in [(), (1,), (4, 5)]:
+        val = np.random.rand(*shape)
+        worker(val, shape)
-- 
GitLab
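A minimal usage sketch, not part of the patch above: it mirrors test_all_reduce_sum to show how the scalar support is expected to behave after this change, i.e. a 0-d (scalar) input to a collective keeps its scalar-ness because collective_comm() now calls setscalar() on the result. It assumes two GPUs; the import path for all_reduce_sum and the launcher usage are taken from the test module and are assumptions here, not verified against this revision.

import numpy as np

import megengine.distributed as dist
from megengine import tensor
from megengine.distributed.functional import all_reduce_sum  # assumed import path


@dist.launcher(n_gpus=2)
def worker(data, expect):
    rank = dist.get_rank()
    # data[rank] is a plain Python float (shape ()), so with this patch the
    # collective result is expected to stay scalar instead of becoming (1,).
    output = all_reduce_sum(tensor(data[rank]))
    assert np.allclose(output.numpy(), expect[rank])


x = np.random.rand()  # scalar input, analogous to the new shape () test case
y = np.random.rand()
worker((x, y), (x + y, x + y))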