提交 2cc85487 编写于 作者: M Megvii Engine Team

fix(imperative): fix hardcode of default device

GitOrigin-RevId: 722c4debfaf3c4a27029ea9f207e65c35dd16f21
上级 403a1e7b
...@@ -17,8 +17,6 @@ __all__ = [ ...@@ -17,8 +17,6 @@ __all__ = [
"set_default_device", "set_default_device",
] ]
_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux")
def _valid_device(inp): def _valid_device(inp):
if isinstance(inp, str) and len(inp) == 4: if isinstance(inp, str) and len(inp) == 4:
...@@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"): ...@@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"):
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
""" """
global _default_device # pylint: disable=global-statement
assert _valid_device(device), "Invalid device name {}".format(device) assert _valid_device(device), "Invalid device name {}".format(device)
_default_device = device CompNode._set_default_device(device)
def get_default_device() -> str: def get_default_device() -> str:
...@@ -86,4 +83,7 @@ def get_default_device() -> str: ...@@ -86,4 +83,7 @@ def get_default_device() -> str:
It returns the value set by :func:`~.set_default_device`. It returns the value set by :func:`~.set_default_device`.
""" """
return _default_device return CompNode._get_default_device()
set_default_device(os.getenv("MGE_DEFAULT_DEVICE", "xpux"))
...@@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) { ...@@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) {
&XTensorND::template copy_from_fixlayout<HostTensorStorage>)); &XTensorND::template copy_from_fixlayout<HostTensorStorage>));
} }
std::string default_device = "xpux";
} // namespace } // namespace
void set_default_device(const std::string &device) {
default_device = device;
}
std::string get_default_device() {
return default_device;
}
void init_common(py::module m) { void init_common(py::module m) {
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
.def(py::init()) .def(py::init())
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
.def("_set_default_device", &set_default_device)
.def("_get_default_device", &get_default_device)
.def("__str__", &CompNode::to_string_logical) .def("__str__", &CompNode::to_string_logical)
.def_static("_sync_all", &CompNode::sync_all) .def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self) .def(py::self == py::self)
......
...@@ -14,3 +14,6 @@ ...@@ -14,3 +14,6 @@
#include "./helper.h" #include "./helper.h"
void init_common(pybind11::module m); void init_common(pybind11::module m);
void set_default_device(const std::string &device);
std::string get_default_device();
\ No newline at end of file
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "megbrain/imperative.h" #include "megbrain/imperative.h"
#include "./helper.h" #include "./helper.h"
#include "megbrain/plugin/profiler.h" #include "megbrain/plugin/profiler.h"
#include "./common.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -230,7 +231,7 @@ void init_graph_rt(py::module m) { ...@@ -230,7 +231,7 @@ void init_graph_rt(py::module m) {
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
if (!cn.valid()) { if (!cn.valid()) {
cn = CompNode::load("xpux"); cn = CompNode::load(get_default_device());
} }
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "megbrain/imperative/interpreter.h" #include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "./helper.h" #include "./helper.h"
#include "./common.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) { ...@@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) {
py::class_<Interpreter::Channel>(m, "Interpreter") py::class_<Interpreter::Channel>(m, "Interpreter")
.def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) { .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) {
if (!cn.valid()) { if (!cn.valid()) {
cn = CompNode::load("xpux"); cn = CompNode::load(get_default_device());
} }
constexpr int size_threshhold = TensorShape::MAX_NDIM; constexpr int size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) { if (data.size() > size_threshhold) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册