From 8dc23e0fdf7b6f03bcf072701ee0d50c87c84bd5 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 4 Sep 2020 16:34:49 +0800
Subject: [PATCH] fix(mge/functional): fix indexing_one_hot and remote_recv

GitOrigin-RevId: 00bdfb502bc6ed3c564bd924bed85e66cee83f84
---
 imperative/python/megengine/functional/distributed.py | 10 +++++++---
 imperative/python/megengine/functional/nn.py          |  2 +-
 .../python/test/unit/functional/test_distributed.py   |  4 +++-
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/imperative/python/megengine/functional/distributed.py b/imperative/python/megengine/functional/distributed.py
index 71fafd5c..e10cef11 100644
--- a/imperative/python/megengine/functional/distributed.py
+++ b/imperative/python/megengine/functional/distributed.py
@@ -20,6 +20,7 @@ from ..core.autodiff.grad import (
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.core import apply
 from ..core.tensor.tensor import Tensor
+from ..device import get_default_device
 from ..distributed.group import (
     WORLD,
     Group,
@@ -270,16 +271,19 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 
 
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0"
+    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
 ) -> Tensor:
     """Receive a Tensor from a remote process
 
     :param src_rank: source process rank
     :param shape: the shape of the tensor to receive
     :param dtype: the data type of the tensor to receive
-    :param cn: the comp node to place the received tensor
+    :param device: the device to place the received tensor,
+        if None, use default device
     """
     key = "{}->{}".format(src_rank, get_rank())
+    if device is None:
+        device = get_default_device()
     # dummpy input
     inp = tensor([0])
@@ -290,7 +294,7 @@ def remote_recv(
 
     op = RemoteRecv()
     op.key = key
-    op.cn = cn
+    op.cn = device
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 527f7721..a934a077 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -1447,7 +1447,7 @@ def indexing_one_hot(
         src, (TensorWrapperBase, TensorBase)
     ), "src must be of Tensor type"
     op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, (src,), dtype="int32")
+    index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = remove_axis(result, axis)
diff --git a/imperative/python/test/unit/functional/test_distributed.py b/imperative/python/test/unit/functional/test_distributed.py
index 70b30fb2..fa0f22b7 100644
--- a/imperative/python/test/unit/functional/test_distributed.py
+++ b/imperative/python/test/unit/functional/test_distributed.py
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine import Parameter, Tensor, tensor
+from megengine.device import get_default_device, set_default_device
 from megengine.functional.distributed import (
     all_gather,
     all_reduce_max,
@@ -449,7 +450,8 @@ def test_io_remote():
             assert y.numpy()[0] == 0
         else:  # remote recv
             dist.init_process_group("localhost", port, world_size, rank, rank)
-            y = remote_recv(0, val.shape, val.dtype, cn="gpu1")
+            y = remote_recv(0, val.shape, val.dtype)
+            assert y.device == "gpu1"
             np.testing.assert_almost_equal(val, y.numpy())
 
     procs = []
-- 
GitLab
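
Reviewer note (not part of the patch): a minimal sketch of how the two fixes
behave, assuming a machine with at least two GPUs ("gpu0"/"gpu1"). The port
number and the worker() helper below are illustrative, mirroring the unit
test above rather than quoting it.

    import numpy as np
    import megengine.distributed as dist
    from megengine import tensor
    from megengine.device import set_default_device
    from megengine.functional import indexing_one_hot
    from megengine.functional.distributed import remote_recv, remote_send

    # indexing_one_hot fix: the index tensor is now converted on src.device,
    # so the gather works even when src does not live on the default device.
    src = tensor(np.arange(6, dtype="float32").reshape(2, 3))
    idx = tensor(np.array([0, 2], dtype="int32"))
    out = indexing_one_hot(src, idx, axis=1)  # out[i] = src[i, idx[i]] -> [0., 5.]

    # remote_recv fix: with device left as None, the received tensor is placed
    # on the caller's default device instead of the previously hard-coded "gpu0".
    # Run one instance of worker() per process (e.g. via multiprocessing), as
    # the unit test does; port 2333 is an arbitrary choice.
    def worker(rank, port=2333):
        dist.init_process_group("localhost", port, 2, rank, rank)
        if rank == 0:
            remote_send(tensor(np.ones((4, 5), dtype="float32")), 1)
        else:
            set_default_device("gpu1")
            y = remote_recv(0, (4, 5), np.float32)
            assert y.device == "gpu1"  # landed on the default device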