Commit ea8eb4cf, authored by Megvii Engine Team

feat(mge/distributed): scalar support for distributed functions

GitOrigin-RevId: 53f3575baf58d709d752618e90cdec6f93b631e5
Parent: b83c77e1
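This change lets the collective helpers and remote send/recv keep the scalar (0-dim) property of their inputs. Below is a minimal sketch, not part of this diff, of the behavior it enables, assuming a host with 2 GPUs and the public megengine.distributed API:

```python
# Minimal sketch (not part of this commit): a 0-dim tensor stays 0-dim
# after an all-reduce, assuming 2 GPUs are available.
import numpy as np
import megengine as mge
import megengine.distributed as dist

@dist.launcher(n_gpus=2)
def worker():
    # each rank contributes its rank id as a scalar
    x = mge.tensor(float(dist.get_rank()), dtype="float32")
    y = dist.all_reduce_sum(x)
    # with this change the scalar flag is preserved, so y is expected
    # to report shape () rather than (1,)
    assert y.shape == ()
    np.testing.assert_allclose(y.numpy(), 1.0)  # 0 + 1 across the two ranks

worker()
```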
@@ -11,6 +11,7 @@ from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
@@ -50,7 +51,18 @@ def collective_comm(inp, mode, group, device):
backend=get_backend(),
comp_node=device,
)
return apply(op, inp)[0]
(result,) = apply(op, inp)
# assume all workers have homogeneous shape
if mode in (
CollectiveComm.Mode.REDUCE_SUM,
CollectiveComm.Mode.BROADCAST,
CollectiveComm.Mode.ALL_REDUCE_SUM,
CollectiveComm.Mode.ALL_REDUCE_MAX,
CollectiveComm.Mode.ALL_REDUCE_MIN,
):
if isscalar(inp):
setscalar(result)
return result
def reduce_sum(
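The scalar flag is only restored for modes whose result keeps the input's shape (reduce to the root, broadcast, and the all-reduce variants); gather-, scatter- and all-to-all-style modes change the output shape, so a scalar input cannot stay scalar there. A hypothetical restatement of that check as a standalone helper, for illustration only (isscalar/setscalar are MegEngine-internal utilities, imported at the top of this file):

```python
# Illustration only: the same mode check as above, factored into a helper.
from megengine.core.ops.builtin import CollectiveComm
from megengine.core.tensor.utils import isscalar, setscalar

_SHAPE_PRESERVING_MODES = (
    CollectiveComm.Mode.REDUCE_SUM,
    CollectiveComm.Mode.BROADCAST,
    CollectiveComm.Mode.ALL_REDUCE_SUM,
    CollectiveComm.Mode.ALL_REDUCE_MAX,
    CollectiveComm.Mode.ALL_REDUCE_MIN,
)

def _keep_scalar(result, inp, mode):
    # gather/scatter/all-to-all change the output shape, so only the
    # shape-preserving collectives carry the scalar flag over
    if mode in _SHAPE_PRESERVING_MODES and isscalar(inp):
        setscalar(result)
    return result
```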
@@ -289,6 +301,11 @@ def remote_recv(
g.wrt(inp)
g._refkeeper.append(inp)
_isscalar = False
if len(shape) == 0:
shape = (1,)
_isscalar = True
op = RemoteRecv()
op.key = key
op.cn = device
@@ -298,4 +315,6 @@
op.rank_from = src_rank
(ret,) = apply(_RemoteRecv(op), inp)
if _isscalar:
setscalar(ret)
return ret
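For remote transfers, a requested shape of () is transported internally as (1,) and the scalar flag is restored on the received tensor. A hedged usage sketch, mirroring the updated test_io_remote further below (assumes 2 GPUs; remote_send/remote_recv are the public helpers from megengine.distributed.functional):

```python
# Usage sketch (assumes 2 GPUs): sending and receiving a scalar tensor.
import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.core._imperative_rt.core2 import sync
from megengine.distributed.functional import remote_recv, remote_send

@dist.launcher(n_gpus=2)
def worker():
    if dist.get_rank() == 0:
        x = mge.tensor(3.14, dtype="float32", device="gpu0")
        remote_send(x, 1)
        sync()  # flush the send before the worker exits, as in the test below
    else:
        # shape () now requests a scalar; internally it travels as (1,)
        y = remote_recv(0, (), np.float32)
        np.testing.assert_almost_equal(y.numpy(), 3.14, decimal=5)

worker()
```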
@@ -13,7 +13,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine import Parameter, tensor
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
@@ -53,14 +53,14 @@ def test_reduce_sum():
assert np.allclose(output.numpy(), 0)
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = x + y
data = (x, y)
expect = (z, None)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
@@ -81,13 +81,13 @@ def test_broadcast():
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = x + 1
data = (x, y)
expect = (x, x)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
@@ -164,14 +164,14 @@ def test_all_reduce_sum():
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = x + y
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
@@ -192,14 +192,14 @@ def test_all_reduce_max():
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = np.maximum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
@@ -220,14 +220,14 @@ def test_all_reduce_min():
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = np.minimum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
@@ -327,18 +327,18 @@ def test_all_to_all():
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_io_remote():
val = np.random.rand(4, 5).astype(np.float32)
@dist.launcher(n_gpus=2)
def worker():
def worker(val, shape):
rank = dist.get_rank()
if rank == 0: # remote send
x = Tensor(val, device="gpu0")
x = tensor(val, device="gpu0")
remote_send(x, 1)
sync()
else: # remote recv
y = remote_recv(0, val.shape, val.dtype)
y = remote_recv(0, shape, np.float32)
assert y.device == "gpu1"
np.testing.assert_almost_equal(val, y.numpy())
worker()
for shape in [(), (1,), (4, 5)]:
val = np.random.rand(*shape)
worker(val, shape)