提交 f78f529a 编写于 作者: M Megvii Engine Team

fix(opr): fix the device problem of concat/stack

GitOrigin-RevId: 01c97a4339803db89417a203968a10024ee3bf61
上级 0e83bce4
...@@ -486,12 +486,13 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: ...@@ -486,12 +486,13 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
[ 9., 10., 11.]], dtype=float32) [ 9., 10., 11.]], dtype=float32)
""" """
if len(inps) == 1: if len(inps) == 1:
return inps[0] # if we return inps[0] directly, then the grad manager capture nothing
return copy(inps[0], device)
if device is None: if device is None:
device = get_default_device() device = get_device(inps)
device = as_device(device)
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inps) (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
return result return result
...@@ -517,10 +518,16 @@ def stack(inps, axis=0, device=None): ...@@ -517,10 +518,16 @@ def stack(inps, axis=0, device=None):
[6., 7., 8.]], dtype=float32) [6., 7., 8.]], dtype=float32)
""" """
if len(inps) == 1: if len(inps) == 1:
return expand_dims(inps[0], axis=axis) ret = expand_dims(inps[0], axis=axis)
if device is None: if device is None:
device = get_default_device() return ret
(result,) = apply(builtin.Stack(axis=axis, comp_node=device), *inps) else:
return copy(ret, device)
if device is None:
device = get_device(inps)
device = as_device(device)
(result,) = apply(builtin.Stack(axis=axis, comp_node=device.to_c()), *inps)
return result return result
......
...@@ -127,18 +127,28 @@ def test_condtake(is_varnode): ...@@ -127,18 +127,28 @@ def test_condtake(is_varnode):
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode): def test_concat_stack_device(is_varnode):
if is_varnode: if is_varnode:
network = Network() network = Network()
else: else:
network = None network = None
data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0") data1 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu0")
data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1") data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
data3 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu0")
out = F.concat([data1, data2], device="cpu0") for func in [F.concat, F.stack]:
out = F.concat([data1, data2], device="cpu1")
assert str(out.device).split(":")[0] == "cpu1"
out = F.concat([data1, data3])
assert str(out.device).split(":")[0] == "cpu0" assert str(out.device).split(":")[0] == "cpu0"
with pytest.raises(RuntimeError):
try:
out = F.concat([data1, data2])
except:
raise RuntimeError("inputs have different devices")
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_stack(is_varnode): def test_stack(is_varnode):
...@@ -219,9 +229,11 @@ def test_split_basic(is_varnode): ...@@ -219,9 +229,11 @@ def test_split_basic(is_varnode):
def test_concat_and_stack(): def test_concat_and_stack():
import copy import copy
from megengine.autodiff import GradManager
import torch
def generate_test_data(max_nr_inp, max_dim, max_dim_len, test_concat=True): def generate_test_data(max_nr_inp, max_dim, max_dim_len, test_concat=True):
nr_inp = np.random.randint(1, max_nr_inp) nr_inp = np.random.randint(1, max_nr_inp) if max_nr_inp > 1 else 1
dims = np.random.randint(1, max_dim) dims = np.random.randint(1, max_dim)
cat_axis = ( cat_axis = (
np.random.randint(-dims, dims) np.random.randint(-dims, dims)
...@@ -245,13 +257,28 @@ def test_concat_and_stack(): ...@@ -245,13 +257,28 @@ def test_concat_and_stack():
max_nr_inp, max_dim, max_dim_len, test_concat max_nr_inp, max_dim, max_dim_len, test_concat
) )
inp_mges = [Tensor(inp_np) for inp_np in inp_nps] inp_mges = [Tensor(inp_np) for inp_np in inp_nps]
inp_torchs = [torch.tensor(inp_np, requires_grad=True) for inp_np in inp_nps]
if test_concat: if test_concat:
np_func, mge_func = np.concatenate, F.concat np_func, mge_func, torch_func = np.concatenate, F.concat, torch.cat
else: else:
np_func, mge_func = np.stack, F.stack np_func, mge_func, torch_func = np.stack, F.stack, torch.stack
res_np = np_func(inp_nps, axis=cat_axis) res_np = np_func(inp_nps, axis=cat_axis)
grad_np = np.random.randn(*res_np.shape).astype(np.float32)
gm = GradManager().attach(inp_mges)
with gm:
res_mge = mge_func(inp_mges, axis=cat_axis) res_mge = mge_func(inp_mges, axis=cat_axis)
np.testing.assert_allclose(res_mge.numpy(), res_np) gm.backward(res_mge, Tensor(grad_np))
res_torch = torch_func(inp_torchs, dim=cat_axis)
res_torch.backward(torch.tensor(grad_np))
np.testing.assert_allclose(res_mge.numpy(), res_torch.detach().cpu().numpy())
for inp_mge, inp_torch in zip(inp_mges, inp_torchs):
np.testing.assert_allclose(
inp_mge.grad.numpy(), inp_torch.grad.detach().cpu().numpy()
)
def test_concat(max_nr_inp, max_dim, max_dim_len): def test_concat(max_nr_inp, max_dim, max_dim_len):
test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=True) test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=True)
...@@ -259,6 +286,14 @@ def test_concat_and_stack(): ...@@ -259,6 +286,14 @@ def test_concat_and_stack():
def test_stack(max_nr_inp, max_dim, max_dim_len): def test_stack(max_nr_inp, max_dim, max_dim_len):
test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=False) test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=False)
# test only one input
test_concat(1, 7, 16)
test_stack(1, 7, 16)
# test zero shape
test_concat(10, 7, 1)
test_stack(10, 7, 1)
for _ in range(3): for _ in range(3):
test_concat(10, 7, 16) test_concat(10, 7, 16)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册