diff --git a/imperative/python/megengine/functional/distributed.py b/imperative/python/megengine/functional/distributed.py
index 71fafd5c029d4b2b50c7d83f3628bc0ac5141a11..e10cef1167c7b7de9f95c2fea24a3e1c706a774d 100644
--- a/imperative/python/megengine/functional/distributed.py
+++ b/imperative/python/megengine/functional/distributed.py
@@ -20,6 +20,7 @@ from ..core.autodiff.grad import (
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.core import apply
 from ..core.tensor.tensor import Tensor
+from ..device import get_default_device
 from ..distributed.group import (
     WORLD,
     Group,
@@ -270,16 +271,19 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 
 
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0"
+    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
 ) -> Tensor:
     """Receive a Tensor from a remote process
 
     :param src_rank: source process rank
     :param shape: the shape of the tensor to receive
     :param dtype: the data type of the tensor to receive
-    :param cn: the comp node to place the received tensor
+    :param device: the device to place the received tensor,
+        if None, use default device
     """
     key = "{}->{}".format(src_rank, get_rank())
+    if device is None:
+        device = get_default_device()
 
     # dummpy input
     inp = tensor([0])
@@ -290,7 +294,7 @@ def remote_recv(
 
     op = RemoteRecv()
     op.key = key
-    op.cn = cn
+    op.cn = device
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 527f77213abd337304933045d29d53c7adc4b2c1..a934a0774f197d86c55f3c4a3ab6f15d91993d52 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -1447,7 +1447,7 @@ def indexing_one_hot(
         src, (TensorWrapperBase, TensorBase)
     ), "src must be of Tensor type"
     op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, (src,), dtype="int32")
+    index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = remove_axis(result, axis)
diff --git a/imperative/python/test/unit/functional/test_distributed.py b/imperative/python/test/unit/functional/test_distributed.py
index 70b30fb28ae20258363c729446f4fc3592922ee7..fa0f22b79aa99a0ed7066ff8d0f7e0ab018f346c 100644
--- a/imperative/python/test/unit/functional/test_distributed.py
+++ b/imperative/python/test/unit/functional/test_distributed.py
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine import Parameter, Tensor, tensor
+from megengine.device import get_default_device, set_default_device
 from megengine.functional.distributed import (
     all_gather,
     all_reduce_max,
@@ -449,7 +450,8 @@ def test_io_remote():
             assert y.numpy()[0] == 0
         else:  # remote recv
             dist.init_process_group("localhost", port, world_size, rank, rank)
-            y = remote_recv(0, val.shape, val.dtype, cn="gpu1")
+            y = remote_recv(0, val.shape, val.dtype)
+            assert y.device == "gpu1"
             np.testing.assert_almost_equal(val, y.numpy())
 
     procs = []
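
A minimal usage sketch of the changed remote_recv signature, modeled on the test in this diff. The explicit port value and the worker helper are illustrative assumptions, not part of the patch; the sketch also assumes, as the new test assertion implies, that init_process_group(..., rank, rank) leaves the process's default device at gpu<rank>, so the received tensor lands on gpu1 for rank 1 without any device argument.

    import numpy as np

    import megengine.distributed as dist
    from megengine import tensor
    from megengine.functional.distributed import remote_recv, remote_send

    port = 2333  # hypothetical free port, for illustration only
    world_size = 2
    val = np.random.rand(4, 5).astype(np.float32)

    def worker(rank):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        if rank == 0:
            # sender: the payload lives on gpu0
            x = tensor(val, device="gpu0")
            remote_send(x, 1)
        else:
            # receiver: no cn/device argument any more; the result is
            # placed on this process's default device (gpu1 for rank 1)
            y = remote_recv(0, val.shape, val.dtype)
            assert y.device == "gpu1"
            np.testing.assert_almost_equal(val, y.numpy())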
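
Similarly, a small sketch of the indexing_one_hot path touched in nn.py; the shapes and values are made up for illustration, and the import goes through megengine.functional.nn, where the diff shows the function defined. The fix makes the converted index tensor follow src.device rather than the global default device, which matters when src lives off the default device; the call itself is unchanged.

    import numpy as np

    from megengine import tensor
    from megengine.functional.nn import indexing_one_hot

    src = tensor(np.arange(12, dtype="float32").reshape(3, 4))
    index = tensor(np.array([0, 2, 3], dtype="int32"))

    # picks src[i, index[i]] along axis 1 -> [0., 6., 11.]
    out = indexing_one_hot(src, index, axis=1)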