提交 70bc746c 编写于 作者: M Megvii Engine Team

fix(mge): fix bug of identity()

GitOrigin-RevId: 4bfd3cafb54af339f263111ef3fb0c27244e0ca5
上级 a8c75ee5
......@@ -189,7 +189,7 @@ def identity(inp: Tensor) -> Tensor:
:return: output tensor.
"""
op = builtin.Identity()
(data,) = utils.convert_inputs(inp)
(data,) = convert_inputs(inp)
(output,) = apply(op, data)
return output
......
......@@ -367,6 +367,12 @@ def test_device():
np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
def test_identity():
x = tensor(np.random.random((5, 10)).astype(np.float32))
y = F.identity(x)
np.testing.assert_equal(y.numpy(), x)
def copy_test(dst, src):
data = np.random.random((2, 3)).astype(np.float32)
x = tensor(data, device=src)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册