From aa62672629f50c7a07f50833feeb53f4bfbdae8b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 10 Oct 2020 21:09:13 +0800 Subject: [PATCH] feat(mge): make F.eye numpy compatible GitOrigin-RevId: fee32537b4cb0b3a514ac9cf0e407b7a463aca4e --- imperative/python/megengine/functional/tensor.py | 16 ++++++++++++---- .../python/test/unit/functional/test_tensor.py | 10 +++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 0b30d7982..99f57438c 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -57,7 +57,7 @@ __all__ = [ ] -def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: +def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: """Returns a 2D tensor with ones on the diagonal and zeros elsewhere. :param shape: expected shape of output tensor. @@ -72,8 +72,7 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: import numpy as np import megengine.functional as F - data_shape = (4, 6) - out = F.eye(data_shape, dtype=np.float32) + out = F.eye(4, 6, dtype=np.float32) print(out.numpy()) Outputs: @@ -86,8 +85,17 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: [0. 0. 0. 1. 0. 0.]] """ + if M is not None: + if isinstance(N, Tensor) or isinstance(M, Tensor): + shape = astensor1d((N, M)) + else: + shape = Tensor([N, M], dtype="int32", device=device) + elif isinstance(N, Tensor): + shape = N + else: + shape = Tensor(N, dtype="int32", device=device) op = builtin.Eye(k=0, dtype=dtype, comp_node=device) - (result,) = apply(op, Tensor(shape, dtype="int32", device=device)) + (result,) = apply(op, shape) return result diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index bde49584f..b58668aac 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -22,12 +22,20 @@ from megengine.distributed.helper import get_device_count_by_fork def test_eye(): dtype = np.float32 - cases = [{"input": [10, 20]}, {"input": [20, 30]}] + cases = [{"input": [10, 20]}, {"input": [30]}] for case in cases: np.testing.assert_allclose( F.eye(case["input"], dtype=dtype).numpy(), np.eye(*case["input"]).astype(dtype), ) + np.testing.assert_allclose( + F.eye(*case["input"], dtype=dtype).numpy(), + np.eye(*case["input"]).astype(dtype), + ) + np.testing.assert_allclose( + F.eye(tensor(case["input"]), dtype=dtype).numpy(), + np.eye(*case["input"]).astype(dtype), + ) def test_concat(): -- GitLab