提交 d225cbcd 编写于 作者: M Megvii Engine Team

feat(mge): add device name check

GitOrigin-RevId: d9910b6275c5b7eaf89c56584c1cbe39fe24becf
上级 be236642
......@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import re
from .core._imperative_rt.common import CompNode, DeviceType
from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config
......@@ -22,10 +23,8 @@ __all__ = [
def _valid_device(inp):
if isinstance(inp, str) and len(inp) == 4:
if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu":
if inp[3] == "x" or inp[3].isdigit():
return True
if isinstance(inp, str) and re.match("^[cxg]pu(\d+|\d+:\d+|x)$", inp):
return True
return False
......
......@@ -14,7 +14,7 @@ from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
from .core.tensor.raw_tensor import as_device
from .device import get_default_device
from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated
......@@ -37,6 +37,12 @@ class Tensor(_Tensor):
self *= 0
def to(self, device):
if isinstance(device, str) and not _valid_device(device):
raise ValueError(
"invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
device
)
)
cn = as_device(device).to_c()
return apply(Copy(comp_node=cn), self)[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册