提交 67859f04 编写于 作者: M Megvii Engine Team

fix(imperative): add __array__ and __array_wrap__ for tensorwrapper

GitOrigin-RevId: 87d4ab6c8eaf934b0ef3770113627c61c92ada93
上级 b1ab3646
......@@ -211,7 +211,18 @@ def _expand_args(args):
class ArrayMethodMixin(abc.ABC):
__array_priority__ = 233333
# enable tensor to be converted to numpy array
__array_priority__ = 1001
def __array__(self, dtype=None):
if dtype == None:
return self.numpy()
return self.numpy().astype(dtype)
def __array_wrap__(self, array):
return TensorWrapper(
as_raw_tensor(array, dtype=array.dtype, device=self.device)
)
@abc.abstractmethod
def _reset(self, other):
......
......@@ -50,3 +50,13 @@ def test_set_subtensor():
np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6)
x[1:3] = [4, 5]
np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6)
def test_computing_with_numpy_array():
x = np.array([1, 2, 3], dtype=np.int32)
xx = TensorWrapper(x, device="cpu0")
y = np.array([1, 0, 3], dtype=np.int32)
assert np.add(xx, y).device == xx.device
np.testing.assert_equal(np.add(xx, y).numpy(), np.add(x, y))
np.testing.assert_equal(np.equal(xx, y).numpy(), np.equal(x, y))
np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册