diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index e7bd2c6ee4f9ee583464e82f9b47c1766d3815e2..ada53eadb61dc995345a5acad70bc085cc28e7f6 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -42,7 +42,6 @@ __all__ = [ "full", "full_like", "gather", - "identity", "linspace", "ones", "ones_like", @@ -178,18 +177,6 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: return full(inp.shape, value, dtype=inp.dtype, device=inp.device) -def identity(inp: Tensor) -> Tensor: - """Applies an identity transformation to input tensor. - - :param inp: input tensor. - :return: output tensor. - """ - op = builtin.Identity() - (data,) = convert_inputs(inp) - (output,) = apply(op, data) - return output - - def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: """ Broadcasts a tensor to given shape. diff --git a/imperative/python/megengine/functional/utils.py b/imperative/python/megengine/functional/utils.py index 8446f964827260890abb23890c7901d97fb4018a..d518b69f21b8d6cfddaa10d521bbc89b980e9987 100644 --- a/imperative/python/megengine/functional/utils.py +++ b/imperative/python/megengine/functional/utils.py @@ -11,7 +11,8 @@ from typing import Iterable, Union import numpy as np -from ..core.ops.builtin import Copy +from ..core._wrap import device as as_device +from ..core.ops.builtin import Copy, Identity from ..core.tensor import Tensor from ..core.tensor.core import apply from .math import topk as _topk @@ -63,12 +64,12 @@ def accuracy( return accs -def copy(inp, cn): +def copy(inp, device=None): r""" Copies tensor to another device. :param inp: input tensor. - :param cn: destination device. + :param device: destination device. Examples: @@ -88,4 +89,6 @@ def copy(inp, cn): [1 2 3] """ - return apply(Copy(comp_node=cn), inp)[0] + if device is None: + return apply(Identity(), inp)[0] + return apply(Copy(comp_node=as_device(device).to_c()), inp)[0] diff --git a/imperative/python/megengine/module/identity.py b/imperative/python/megengine/module/identity.py index 51b31e505370020a14744e39054979da5c197027..a948d256a7f48741c3e44da4db8501394ab56336 100644 --- a/imperative/python/megengine/module/identity.py +++ b/imperative/python/megengine/module/identity.py @@ -6,7 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..functional import identity +from ..functional import copy from .module import Module @@ -14,4 +14,4 @@ class Identity(Module): r"""A placeholder identity operator that will ignore any argument.""" def forward(self, x): - return identity(x) + return copy(x) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 732eea4801dd501cc8c167b0509304cf82398f02..bde49584fe22987d87af3810527f6a49f6c0f00f 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -314,7 +314,7 @@ def test_device(): def test_identity(): x = tensor(np.random.random((5, 10)).astype(np.float32)) - y = F.identity(x) + y = F.copy(x) np.testing.assert_equal(y.numpy(), x)