From 70bc746c201045b8baed245a826ef68c2e900518 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 16 Sep 2020 14:05:51 +0800 Subject: [PATCH] fix(mge): fix bug of identity() GitOrigin-RevId: 4bfd3cafb54af339f263111ef3fb0c27244e0ca5 --- imperative/python/megengine/functional/tensor.py | 2 +- imperative/python/test/unit/functional/test_tensor.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 8ba9596a..31053b78 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -189,7 +189,7 @@ def identity(inp: Tensor) -> Tensor: :return: output tensor. """ op = builtin.Identity() - (data,) = utils.convert_inputs(inp) + (data,) = convert_inputs(inp) (output,) = apply(op, data) return output diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index b7a320b9..c17cf310 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -367,6 +367,12 @@ def test_device(): np.testing.assert_almost_equal(y5.numpy(), y6.numpy()) +def test_identity(): + x = tensor(np.random.random((5, 10)).astype(np.float32)) + y = F.identity(x) + np.testing.assert_equal(y.numpy(), x) + + def copy_test(dst, src): data = np.random.random((2, 3)).astype(np.float32) x = tensor(data, device=src) -- GitLab