Commit e6e29748 authored by Megvii Engine Team

chore(mge/imperative): fix Graph.make_const

GitOrigin-RevId: 0f4c62aebf2975f2d2fc11c029c273a45527eb8b
Parent ac3408bf
......@@ -11,6 +11,8 @@ import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
import numpy as np
from .. import _imperative_rt
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
......@@ -32,6 +34,8 @@ class Graph(_imperative_rt.ComputingGraph):
wrapper, cache = VarNode, self._var_cache
elif type(obj) is _imperative_rt.OperatorNode:
wrapper, cache = OpNode, self._op_cache
else:
raise TypeError(type(obj))
if obj not in cache:
cache[obj] = wrapper(obj)
return cache[obj]
......@@ -62,6 +66,11 @@ class Graph(_imperative_rt.ComputingGraph):
assert dtype is None and device is None
return self._wrap(_imperative_rt.make_shared(self, data))
else:
data = np.asarray(data, dtype=dtype)
if data.dtype == np.float64:
data = data.astype(np.float32)
elif data.dtype == np.int64:
data = data.astype(np.int32)
device = as_device(device).to_c()
return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
......
......@@ -181,10 +181,10 @@ void init_graph_rt(py::module m) {
// Bind `make_const`: build an ImmutableTensor node in `graph` from a numpy
// array, placed on comp node `cn` with dtype `dtype`, and return the VarNode.
// NOTE(review): the scraped diff fused added/removed lines — an assignment
// after `throw` (unreachable) and a duplicated `ImmutableTensor::make` call
// whose result was discarded. Both dead lines are removed here.
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
    if (!cn.valid()) {
        // A device is mandatory for a const node; surface a Python TypeError.
        throw py::type_error("device must not be None");
    }
    // Borrow the numpy buffer (no copy) into a host tensor on `cn`.
    auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
    // Return the graph VarNode of the newly created immutable tensor.
    return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
});
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment.