提交 44d0b5da 编写于 作者: M Megvii Engine Team

feat(imperative): enable to() to copy to device

GitOrigin-RevId: f9caf17d24fd984055b3eed48136dd429e71fb2a
上级 3e11d894
......@@ -70,7 +70,7 @@ def set_default_device(device: str = "xpux"):
multi-threading parallelism at the operator level. For example,
'multithread4' will compute with 4 threads. which implements
The default value is 'xpux' to specify any device available.
The default value is 'xpux' to specify any device available. The priority of using gpu is higher when both gpu and cpu are available.
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
"""
......
......@@ -11,6 +11,8 @@
import collections
from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
from .device import get_default_device
......@@ -30,6 +32,9 @@ class Tensor(_Tensor):
def reset_zero(self):
self *= 0
def to(self, cn):
return apply(Copy(comp_node=cn), self)[0]
def __getstate__(self):
r""" __getstate__ will be called for pickle serialization or deep copy
"""
......
......@@ -322,6 +322,8 @@ def copy_test(dst, src):
x = tensor(data, device=src)
y = F.copy(x, dst)
assert np.allclose(data, y.numpy())
z = x.to(dst)
assert np.allclose(data, z.numpy())
@pytest.mark.skipif(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册