提交 8dc23e0f 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix indexing_one_hot and remote_recv

GitOrigin-RevId: 00bdfb502bc6ed3c564bd924bed85e66cee83f84
上级 3bbfef30
......@@ -20,6 +20,7 @@ from ..core.autodiff.grad import (
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor
from ..device import get_default_device
from ..distributed.group import (
WORLD,
Group,
......@@ -270,16 +271,19 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
def remote_recv(
src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0"
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
"""Receive a Tensor from a remote process
:param src_rank: source process rank
:param shape: the shape of the tensor to receive
:param dtype: the data type of the tensor to receive
:param cn: the comp node to place the received tensor
:param device: the device to place the received tensor,
if None, use default device
"""
key = "{}->{}".format(src_rank, get_rank())
if device is None:
device = get_default_device()
# dummpy input
inp = tensor([0])
......@@ -290,7 +294,7 @@ def remote_recv(
op = RemoteRecv()
op.key = key
op.cn = cn
op.cn = device
op.shape = shape
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
......
......@@ -1447,7 +1447,7 @@ def indexing_one_hot(
src, (TensorWrapperBase, TensorBase)
), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, (src,), dtype="int32")
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = remove_axis(result, axis)
......
......@@ -15,6 +15,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine.device import get_default_device, set_default_device
from megengine.functional.distributed import (
all_gather,
all_reduce_max,
......@@ -449,7 +450,8 @@ def test_io_remote():
assert y.numpy()[0] == 0
else: # remote recv
dist.init_process_group("localhost", port, world_size, rank, rank)
y = remote_recv(0, val.shape, val.dtype, cn="gpu1")
y = remote_recv(0, val.shape, val.dtype)
assert y.device == "gpu1"
np.testing.assert_almost_equal(val, y.numpy())
procs = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册