提交 2e9ba679 编写于 作者: M Megvii Engine Team

feat(mge/device): __repr__ method will show physical device

GitOrigin-RevId: 050c3864a7d99234a02114197dd0f499cb88f413
上级 2efba9a3
...@@ -22,11 +22,13 @@ class Device: ...@@ -22,11 +22,13 @@ class Device:
else: else:
self._cn = CompNode(device) self._cn = CompNode(device)
self.logical_name = self._cn.logical_name
def to_c(self): def to_c(self):
return self._cn return self._cn
def __repr__(self): def __repr__(self):
return "{}({})".format(type(self).__qualname__, self) return "{}({})".format(type(self).__qualname__, repr(self._cn))
def __str__(self): def __str__(self):
return str(self._cn) return str(self._cn)
......
...@@ -67,7 +67,7 @@ class Tensor(_Tensor): ...@@ -67,7 +67,7 @@ class Tensor(_Tensor):
state = { state = {
"data": self.numpy(), "data": self.numpy(),
"device": str(self.device), "device": self.device.logical_name,
"dtype": self.dtype, "dtype": self.dtype,
"qdict": self.q_dict, "qdict": self.q_dict,
} }
...@@ -75,13 +75,13 @@ class Tensor(_Tensor): ...@@ -75,13 +75,13 @@ class Tensor(_Tensor):
def __setstate__(self, state): def __setstate__(self, state):
data = state.pop("data") data = state.pop("data")
device = state.pop("device") logical_device = state.pop("device")
if self.dmap_callback is not None: if self.dmap_callback is not None:
assert isinstance(device, str) assert isinstance(logical_device, str)
device = self.dmap_callback(device) logical_device = self.dmap_callback(logical_device)
dtype = state.pop("dtype") dtype = state.pop("dtype")
self.q_dict = state.pop("qdict") self.q_dict = state.pop("qdict")
super().__init__(data, dtype=dtype, device=device) super().__init__(data, dtype=dtype, device=logical_device)
def detach(self): def detach(self):
r""" r"""
......
...@@ -55,10 +55,16 @@ void init_common(py::module m) { ...@@ -55,10 +55,16 @@ void init_common(py::module m) {
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
.def(py::init()) .def(py::init())
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
.def_property_readonly("logical_name", [](const CompNode& cn) {
return cn.to_string_logical();
})
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
.def("_set_default_device", &set_default_device) .def("_set_default_device", &set_default_device)
.def("_get_default_device", &get_default_device) .def("_get_default_device", &get_default_device)
.def("__str__", &CompNode::to_string_logical) .def("__str__", &CompNode::to_string_logical)
.def("__repr__", [](const CompNode& cn) {
return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\"");
})
.def_static("_sync_all", &CompNode::sync_all) .def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self) .def(py::self == py::self)
.def_static("_get_device_count", &CompNode::get_device_count, .def_static("_get_device_count", &CompNode::get_device_count,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册