From 81b6a7338207a9183897416316627a19763c2cc7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 14 Sep 2020 21:54:07 +0800 Subject: [PATCH] fix(mge): fix bug of tensor.T GitOrigin-RevId: 9fe9347b006dc894aaad8e22ba60d3f00c2216aa --- imperative/python/megengine/core/tensor/tensor_wrapper.py | 2 +- imperative/python/test/unit/core/test_tensor_wrapper.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 05f2d7091..840dfcab9 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -371,7 +371,7 @@ class ArrayMethodMixin(abc.ABC): def transpose(self, *args): if not args: - args = reversed(range(self.ndim)) + args = range(self.ndim)[::-1] return _transpose(self, _expand_args(args)) def flatten(self): diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py index 26bc9c9c3..a90b109a4 100644 --- a/imperative/python/test/unit/core/test_tensor_wrapper.py +++ b/imperative/python/test/unit/core/test_tensor_wrapper.py @@ -60,3 +60,9 @@ def test_computing_with_numpy_array(): np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y)) np.testing.assert_equal(np.equal(xx, y).numpy(), np.equal(x, y)) np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x)) + + +def test_transpose(): + x = np.random.rand(2, 5).astype("float32") + xx = TensorWrapper(x) + np.testing.assert_almost_equal(xx.T.numpy(), x.T) -- GitLab