feat(mge): remove F.identity

GitOrigin-RevId: 858be627acf028a46c57bde55161a85ff47d157f
上级 09241a1f
......@@ -42,7 +42,6 @@ __all__ = [
"full",
"full_like",
"gather",
"identity",
"linspace",
"ones",
"ones_like",
......@@ -178,18 +177,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
return full(inp.shape, value, dtype=inp.dtype, device=inp.device)
def identity(inp: Tensor) -> Tensor:
"""Applies an identity transformation to input tensor.
:param inp: input tensor.
:return: output tensor.
"""
op = builtin.Identity()
(data,) = convert_inputs(inp)
(output,) = apply(op, data)
return output
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
"""
Broadcasts a tensor to given shape.
......
......@@ -11,7 +11,8 @@ from typing import Iterable, Union
import numpy as np
from ..core.ops.builtin import Copy
from ..core._wrap import device as as_device
from ..core.ops.builtin import Copy, Identity
from ..core.tensor import Tensor
from ..core.tensor.core import apply
from .math import topk as _topk
......@@ -63,12 +64,12 @@ def accuracy(
return accs
def copy(inp, cn):
def copy(inp, device=None):
r"""
Copies tensor to another device.
:param inp: input tensor.
:param cn: destination device.
:param device: destination device.
Examples:
......@@ -88,4 +89,6 @@ def copy(inp, cn):
[1 2 3]
"""
return apply(Copy(comp_node=cn), inp)[0]
if device is None:
return apply(Identity(), inp)[0]
return apply(Copy(comp_node=as_device(device).to_c()), inp)[0]
......@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..functional import identity
from ..functional import copy
from .module import Module
......@@ -14,4 +14,4 @@ class Identity(Module):
r"""A placeholder identity operator that will ignore any argument."""
def forward(self, x):
return identity(x)
return copy(x)
......@@ -314,7 +314,7 @@ def test_device():
def test_identity():
x = tensor(np.random.random((5, 10)).astype(np.float32))
y = F.identity(x)
y = F.copy(x)
np.testing.assert_equal(y.numpy(), x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部