From 44d0b5daf527515ebbbaf4047ff9bf5b1b66ae2e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 8 Sep 2020 15:05:13 +0800 Subject: [PATCH] feat(imperative): enable to() to copy to device GitOrigin-RevId: f9caf17d24fd984055b3eed48136dd429e71fb2a --- imperative/python/megengine/device.py | 2 +- imperative/python/megengine/tensor.py | 5 +++++ imperative/python/test/unit/functional/test_tensor.py | 2 ++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/device.py b/imperative/python/megengine/device.py index 731b8985..631e49d7 100644 --- a/imperative/python/megengine/device.py +++ b/imperative/python/megengine/device.py @@ -70,7 +70,7 @@ def set_default_device(device: str = "xpux"): multi-threading parallelism at the operator level. For example, 'multithread4' will compute with 4 threads. which implements - The default value is 'xpux' to specify any device available. + The default value is 'xpux' to specify any device available. The priority of using gpu is higher when both gpu and cpu are available. It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. """ diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 1848463c..89436323 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -11,6 +11,8 @@ import collections from .core import Tensor as _Tensor +from .core.ops.builtin import Copy +from .core.tensor.core import apply from .device import get_default_device @@ -30,6 +32,9 @@ class Tensor(_Tensor): def reset_zero(self): self *= 0 + def to(self, cn): + return apply(Copy(comp_node=cn), self)[0] + def __getstate__(self): r""" __getstate__ will be called for pickle serialization or deep copy """ diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 72e1fb73..2b86933d 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -322,6 +322,8 @@ def copy_test(dst, src): x = tensor(data, device=src) y = F.copy(x, dst) assert np.allclose(data, y.numpy()) + z = x.to(dst) + assert np.allclose(data, z.numpy()) @pytest.mark.skipif( -- GitLab