From f78f529a86a00a5132ab16fabb17f1d64efcb94a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 14 Dec 2022 14:19:43 +0800 Subject: [PATCH] fix(opr): fix the device problem of concat/stack GitOrigin-RevId: 01c97a4339803db89417a203968a10024ee3bf61 --- .../python/megengine/functional/tensor.py | 21 +++++--- .../test/unit/functional/test_tensor.py | 53 +++++++++++++++---- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index a0a5b3d8a..bf48d961b 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -486,12 +486,13 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: [ 9., 10., 11.]], dtype=float32) """ if len(inps) == 1: - return inps[0] + # if we return inps[0] directly, then the grad manager capture nothing + return copy(inps[0], device) if device is None: - device = get_default_device() - - (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inps) + device = get_device(inps) + device = as_device(device) + (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) return result @@ -517,10 +518,16 @@ def stack(inps, axis=0, device=None): [6., 7., 8.]], dtype=float32) """ if len(inps) == 1: - return expand_dims(inps[0], axis=axis) + ret = expand_dims(inps[0], axis=axis) + if device is None: + return ret + else: + return copy(ret, device) + if device is None: - device = get_default_device() - (result,) = apply(builtin.Stack(axis=axis, comp_node=device), *inps) + device = get_device(inps) + device = as_device(device) + (result,) = apply(builtin.Stack(axis=axis, comp_node=device.to_c()), *inps) return result diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index d3ecc6ec5..9d82fdafa 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -127,17 +127,27 @@ def test_condtake(is_varnode): @pytest.mark.parametrize("is_varnode", [True, False]) -def test_concat_device(is_varnode): +def test_concat_stack_device(is_varnode): if is_varnode: network = Network() else: network = None - data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0") + data1 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu0") data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1") + data3 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu0") - out = F.concat([data1, data2], device="cpu0") - assert str(out.device).split(":")[0] == "cpu0" + for func in [F.concat, F.stack]: + out = F.concat([data1, data2], device="cpu1") + assert str(out.device).split(":")[0] == "cpu1" + out = F.concat([data1, data3]) + assert str(out.device).split(":")[0] == "cpu0" + + with pytest.raises(RuntimeError): + try: + out = F.concat([data1, data2]) + except: + raise RuntimeError("inputs have different devices") @pytest.mark.parametrize("is_varnode", [True, False]) @@ -219,9 +229,11 @@ def test_split_basic(is_varnode): def test_concat_and_stack(): import copy + from megengine.autodiff import GradManager + import torch def generate_test_data(max_nr_inp, max_dim, max_dim_len, test_concat=True): - nr_inp = np.random.randint(1, max_nr_inp) + nr_inp = np.random.randint(1, max_nr_inp) if max_nr_inp > 1 else 1 dims = np.random.randint(1, max_dim) cat_axis = ( np.random.randint(-dims, dims) @@ -245,13 +257,28 @@ def test_concat_and_stack(): max_nr_inp, max_dim, max_dim_len, test_concat ) inp_mges = [Tensor(inp_np) for inp_np in inp_nps] + inp_torchs = [torch.tensor(inp_np, requires_grad=True) for inp_np in inp_nps] if test_concat: - np_func, mge_func = np.concatenate, F.concat + np_func, mge_func, torch_func = np.concatenate, F.concat, torch.cat else: - np_func, mge_func = np.stack, F.stack + np_func, mge_func, torch_func = np.stack, F.stack, torch.stack + res_np = np_func(inp_nps, axis=cat_axis) - res_mge = mge_func(inp_mges, axis=cat_axis) - np.testing.assert_allclose(res_mge.numpy(), res_np) + grad_np = np.random.randn(*res_np.shape).astype(np.float32) + + gm = GradManager().attach(inp_mges) + with gm: + res_mge = mge_func(inp_mges, axis=cat_axis) + gm.backward(res_mge, Tensor(grad_np)) + + res_torch = torch_func(inp_torchs, dim=cat_axis) + res_torch.backward(torch.tensor(grad_np)) + + np.testing.assert_allclose(res_mge.numpy(), res_torch.detach().cpu().numpy()) + for inp_mge, inp_torch in zip(inp_mges, inp_torchs): + np.testing.assert_allclose( + inp_mge.grad.numpy(), inp_torch.grad.detach().cpu().numpy() + ) def test_concat(max_nr_inp, max_dim, max_dim_len): test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=True) @@ -259,6 +286,14 @@ def test_concat_and_stack(): def test_stack(max_nr_inp, max_dim, max_dim_len): test_impl(max_nr_inp, max_dim, max_dim_len, test_concat=False) + # test only one input + test_concat(1, 7, 16) + test_stack(1, 7, 16) + + # test zero shape + test_concat(10, 7, 1) + test_stack(10, 7, 1) + for _ in range(3): test_concat(10, 7, 16) -- GitLab