From 67859f04e143bb05e11016d146794aac06ff0257 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Sep 2020 14:08:37 +0800 Subject: [PATCH] fix(imperative): add __array__ and __array_wrap__ for tensorwrapper GitOrigin-RevId: 87d4ab6c8eaf934b0ef3770113627c61c92ada93 --- .../python/megengine/core/tensor/tensor_wrapper.py | 13 ++++++++++++- imperative/python/test/unit/test_tensor_wrapper.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index fd2e6bea..722d5dc9 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -211,7 +211,18 @@ def _expand_args(args): class ArrayMethodMixin(abc.ABC): - __array_priority__ = 233333 + # enable tensor to be converted to numpy array + __array_priority__ = 1001 + + def __array__(self, dtype=None): + if dtype == None: + return self.numpy() + return self.numpy().astype(dtype) + + def __array_wrap__(self, array): + return TensorWrapper( + as_raw_tensor(array, dtype=array.dtype, device=self.device) + ) @abc.abstractmethod def _reset(self, other): diff --git a/imperative/python/test/unit/test_tensor_wrapper.py b/imperative/python/test/unit/test_tensor_wrapper.py index c2f8def6..26bc9c9c 100644 --- a/imperative/python/test/unit/test_tensor_wrapper.py +++ b/imperative/python/test/unit/test_tensor_wrapper.py @@ -50,3 +50,13 @@ def test_set_subtensor(): np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6) x[1:3] = [4, 5] np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6) + + +def test_computing_with_numpy_array(): + x = np.array([1, 2, 3], dtype=np.int32) + xx = TensorWrapper(x, device="cpu0") + y = np.array([1, 0, 3], dtype=np.int32) + assert np.add(xx, y).device == xx.device + np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y)) + np.testing.assert_equal(np.equal(xx, y).numpy(), np.equal(x, y)) + np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x)) -- GitLab