提交 08cc1032 编写于 作者: M Megvii Engine Team

fix(imperative): fix persistent_cache

GitOrigin-RevId: 8f7bb5899f91c9350cb7dd9a3a4dffac29784d5d
上级 998f71a8
......@@ -93,7 +93,6 @@ _persistent_cache_impl_ins.reg()
atexit.register(_full_sync)
del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins
# subpackages
import megengine.autodiff
......
......@@ -366,7 +366,7 @@ namespace detail {
return true;
}
static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) {
return bytes((const char*)blob.ptr, blob.size);
return bytes((const char*)blob.ptr, blob.size).release();
}
};
......
......@@ -421,8 +421,10 @@ PyObject* TensorWrapper::numpy() {
}
return np_val.release().ptr();
}
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto&& hv = [&]() {
py::gil_scoped_release _;
return interpreter_for_py->get_value(m_tensor->m_handle.get());
}();
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
......@@ -492,7 +494,10 @@ PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
}
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
auto dev_tensor = [&](){
py::gil_scoped_release _;
return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
}();
return py::cast(dev_tensor).release().ptr();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册