提交 266263dc 编写于 作者: M Megvii Engine Team

fix(opr): fix the device problem of concat/stack

GitOrigin-RevId: 01c97a4339803db89417a203968a10024ee3bf61
上级 c850c7eb
......@@ -486,12 +486,13 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
[ 9., 10., 11.]], dtype=float32)
"""
if len(inps) == 1:
return inps[0]
# if we return inps[0] directly, then the grad manager capture nothing
return copy(inps[0], device)
if device is None:
device = get_default_device()
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inps)
device = get_device(inps)
device = as_device(device)
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
return result
......@@ -517,10 +518,16 @@ def stack(inps, axis=0, device=None):
[6., 7., 8.]], dtype=float32)
"""
if len(inps) == 1:
return expand_dims(inps[0], axis=axis)
ret = expand_dims(inps[0], axis=axis)
if device is None:
device = get_default_device()
(result,) = apply(builtin.Stack(axis=axis, comp_node=device), *inps)
return ret
else:
return copy(ret, device)
if device is None:
device = get_device(inps)
device = as_device(device)
(result,) = apply(builtin.Stack(axis=axis, comp_node=device.to_c()), *inps)
return result
......
......@@ -127,18 +127,28 @@ def test_condtake(is_varnode):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode):
def test_concat_stack_device(is_varnode):
if is_varnode:
network = Network()
else:
network = None
data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
data1 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu0")
data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
data3 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu0")
out = F.concat([data1, data2], device="cpu0")
for func in [F.concat, F.stack]:
out = F.concat([data1, data2], device="cpu1")
assert str(out.device).split(":")[0] == "cpu1"
out = F.concat([data1, data3])
assert str(out.device).split(":")[0] == "cpu0"
with pytest.raises(RuntimeError):
try:
out = F.concat([data1, data2])
except:
raise RuntimeError("inputs have different devices")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_stack(is_varnode):
......@@ -219,9 +229,11 @@ def test_split_basic(is_varnode):
def test_concat_and_stack():
import copy
from megengine.autodiff import GradManager
import torch
def generate_test_data(max_nr_inp, max_dim, max_dim_len, test_concat=True):
nr_inp = np.random.randint(1, max_nr_inp)
nr_inp = np.random.randint(1, max_nr_inp) if max_nr_inp > 1 else 1
dims = np.random.randint(1, max_dim)
cat_axis = (
np.random.randint(-dims, dims)
......@@ -245,13 +257,28 @@ def test_concat_and_stack():
max_nr_inp, max_dim, max_dim_len, test_concat
)
inp_mges = [Tensor(inp_np) for inp_np in inp_nps]
inp_torchs = [torch.tensor(inp_np, requires_grad=True) for inp_np in inp_nps]
if test_concat:
np_func, mge_func = np.concatenate, F.concat
np_func, mge_func, torch_func = np.concatenate, F.concat, torch.cat
else:
np_func, mge_func = np.stack, F.stack
np_func, mge_func, torch_func = np.stack, F.stack, torch.stack
res_np = np_func(inp_nps, axis=cat_axis)
grad_np = np.random.randn(*res_np.shape).astype(np.float32)
gm = GradManager().attach(inp_mges)
with gm:
res_mge = mge_func(inp_mges, axis=cat_axis)
np.testing.assert_allclose(res_mge.numpy(), res_np)
gm.backward(res_mge, Tensor(grad_np))
res_torch = torch_func(inp_torchs, dim=cat_axis)
res_torch.backward(torch.tensor(grad_np))
np.testing.assert_allclose(res_mge.numpy(), res_torch.detach().cpu().numpy())
for inp_mge, inp_torch in zip(inp_mges, inp_torchs):
np.testing.assert_allclose(
inp_mge.grad.numpy(), inp_torch.grad.detach().cpu().numpy()
)
def test_concat(max_nr_inp, max_dim, max_dim_len):
test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=True)
......@@ -259,6 +286,14 @@ def test_concat_and_stack():
def test_stack(max_nr_inp, max_dim, max_dim_len):
test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=False)
# test only one input
test_concat(1, 7, 16)
test_stack(1, 7, 16)
# test zero shape
test_concat(10, 7, 1)
test_stack(10, 7, 1)
for _ in range(3):
test_concat(10, 7, 16)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册