Commit 8dc23e0f authored by: Megvii Engine Team

fix(mge/functional): fix indexing_one_hot and remote_recv

GitOrigin-RevId: 00bdfb502bc6ed3c564bd924bed85e66cee83f84
Parent: 3bbfef30
@@ -20,6 +20,7 @@ from ..core.autodiff.grad import (
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.core import apply
 from ..core.tensor.tensor import Tensor
+from ..device import get_default_device
 from ..distributed.group import (
     WORLD,
     Group,
@@ -270,16 +271,19 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0"
+    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
 ) -> Tensor:
     """Receive a Tensor from a remote process

     :param src_rank: source process rank
     :param shape: the shape of the tensor to receive
     :param dtype: the data type of the tensor to receive
-    :param cn: the comp node to place the received tensor
+    :param device: the device to place the received tensor,
+        if None, use default device
     """
     key = "{}->{}".format(src_rank, get_rank())

+    if device is None:
+        device = get_default_device()
     # dummy input
     inp = tensor([0])
@@ -290,7 +294,7 @@ def remote_recv(
     op = RemoteRecv()
     op.key = key
-    op.cn = cn
+    op.cn = device
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
...
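For context, a minimal sketch of how the changed remote_recv is meant to be called (the worker scaffolding, tensor values, and the final assert are illustrative assumptions, not code from this commit):

    import numpy as np
    import megengine.distributed as dist
    from megengine import tensor
    from megengine.device import get_default_device
    from megengine.functional.distributed import remote_recv, remote_send

    def worker(rank, world_size, port):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        if rank == 0:  # sender
            x = tensor(np.random.rand(2, 3).astype("float32"))
            remote_send(x, dest_rank=1)
        else:  # receiver
            # device is now optional: with device=None the received tensor
            # is placed on get_default_device() instead of the previously
            # hard-coded "gpu0"
            y = remote_recv(0, (2, 3), np.float32)
            assert y.device == get_default_device()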
@@ -1447,7 +1447,7 @@ def indexing_one_hot(
         src, (TensorWrapperBase, TensorBase)
     ), "src must be of Tensor type"
     op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, (src,), dtype="int32")
+    index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = remove_axis(result, axis)
...
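The indexing_one_hot change pins the converted index tensor to src.device, so the index can no longer silently land on a different device than src. A minimal usage sketch (shapes and values are illustrative):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    src = tensor(np.arange(6, dtype="float32").reshape(2, 3))
    index = tensor(np.array([0, 2], dtype="int32"))

    # Selects src[0, 0] and src[1, 2] along axis 1; convert_single_value
    # now materializes the index on src.device rather than the default device.
    out = F.indexing_one_hot(src, index, axis=1)  # -> [0., 5.]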
@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine import Parameter, Tensor, tensor
+from megengine.device import get_default_device, set_default_device
 from megengine.functional.distributed import (
     all_gather,
     all_reduce_max,
@@ -449,7 +450,8 @@ def test_io_remote():
         assert y.numpy()[0] == 0
     else:  # remote recv
         dist.init_process_group("localhost", port, world_size, rank, rank)
-        y = remote_recv(0, val.shape, val.dtype, cn="gpu1")
+        y = remote_recv(0, val.shape, val.dtype)
+        assert y.device == "gpu1"
         np.testing.assert_almost_equal(val, y.numpy())

     procs = []
...