提交 aa626726 编写于 作者: M Megvii Engine Team

feat(mge): make F.eye numpy compatible

GitOrigin-RevId: fee32537b4cb0b3a514ac9cf0e407b7a463aca4e
上级 7589bce4
......@@ -57,7 +57,7 @@ __all__ = [
]
def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
:param shape: expected shape of output tensor.
......@@ -72,8 +72,7 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
import numpy as np
import megengine.functional as F
data_shape = (4, 6)
out = F.eye(data_shape, dtype=np.float32)
out = F.eye(4, 6, dtype=np.float32)
print(out.numpy())
Outputs:
......@@ -86,8 +85,17 @@ def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
[0. 0. 0. 1. 0. 0.]]
"""
if M is not None:
if isinstance(N, Tensor) or isinstance(M, Tensor):
shape = astensor1d((N, M))
else:
shape = Tensor([N, M], dtype="int32", device=device)
elif isinstance(N, Tensor):
shape = N
else:
shape = Tensor(N, dtype="int32", device=device)
op = builtin.Eye(k=0, dtype=dtype, comp_node=device)
(result,) = apply(op, Tensor(shape, dtype="int32", device=device))
(result,) = apply(op, shape)
return result
......
......@@ -22,12 +22,20 @@ from megengine.distributed.helper import get_device_count_by_fork
def test_eye():
dtype = np.float32
cases = [{"input": [10, 20]}, {"input": [20, 30]}]
cases = [{"input": [10, 20]}, {"input": [30]}]
for case in cases:
np.testing.assert_allclose(
F.eye(case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(*case["input"], dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
np.testing.assert_allclose(
F.eye(tensor(case["input"]), dtype=dtype).numpy(),
np.eye(*case["input"]).astype(dtype),
)
def test_concat():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册