提交 2cc85487 编写于 作者: M Megvii Engine Team

fix(imperative): fix hardcode of default device

GitOrigin-RevId: 722c4debfaf3c4a27029ea9f207e65c35dd16f21
上级 403a1e7b
......@@ -17,8 +17,6 @@ __all__ = [
"set_default_device",
]
_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux")
def _valid_device(inp):
if isinstance(inp, str) and len(inp) == 4:
......@@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"):
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
"""
global _default_device # pylint: disable=global-statement
assert _valid_device(device), "Invalid device name {}".format(device)
_default_device = device
CompNode._set_default_device(device)
def get_default_device() -> str:
......@@ -86,4 +83,7 @@ def get_default_device() -> str:
It returns the value set by :func:`~.set_default_device`.
"""
return _default_device
return CompNode._get_default_device()
set_default_device(os.getenv("MGE_DEFAULT_DEVICE", "xpux"))
......@@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) {
&XTensorND::template copy_from_fixlayout<HostTensorStorage>));
}
std::string default_device = "xpux";
} // namespace
void set_default_device(const std::string &device) {
default_device = device;
}
std::string get_default_device() {
return default_device;
}
void init_common(py::module m) {
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
.def(py::init())
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
.def("_set_default_device", &set_default_device)
.def("_get_default_device", &get_default_device)
.def("__str__", &CompNode::to_string_logical)
.def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self)
......
......@@ -14,3 +14,6 @@
#include "./helper.h"
void init_common(pybind11::module m);
void set_default_device(const std::string &device);
std::string get_default_device();
\ No newline at end of file
......@@ -19,6 +19,7 @@
#include "megbrain/imperative.h"
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
namespace py = pybind11;
......@@ -230,7 +231,7 @@ void init_graph_rt(py::module m) {
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
if (!cn.valid()) {
cn = CompNode::load("xpux");
cn = CompNode::load(get_default_device());
}
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
......
......@@ -21,6 +21,7 @@
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "./helper.h"
#include "./common.h"
namespace py = pybind11;
......@@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) {
py::class_<Interpreter::Channel>(m, "Interpreter")
.def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) {
if (!cn.valid()) {
cn = CompNode::load("xpux");
cn = CompNode::load(get_default_device());
}
constexpr int size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册