提交 08cc1032 编写于 作者: M Megvii Engine Team

fix(imperative): fix persistent_cache

GitOrigin-RevId: 8f7bb5899f91c9350cb7dd9a3a4dffac29784d5d
上级 998f71a8
...@@ -93,7 +93,6 @@ _persistent_cache_impl_ins.reg() ...@@ -93,7 +93,6 @@ _persistent_cache_impl_ins.reg()
atexit.register(_full_sync) atexit.register(_full_sync)
del _set_fork_exec_path_for_timed_func del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins
# subpackages # subpackages
import megengine.autodiff import megengine.autodiff
......
...@@ -366,7 +366,7 @@ namespace detail { ...@@ -366,7 +366,7 @@ namespace detail {
return true; return true;
} }
static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) { static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) {
return bytes((const char*)blob.ptr, blob.size); return bytes((const char*)blob.ptr, blob.size).release();
} }
}; };
......
...@@ -421,8 +421,10 @@ PyObject* TensorWrapper::numpy() { ...@@ -421,8 +421,10 @@ PyObject* TensorWrapper::numpy() {
} }
return np_val.release().ptr(); return np_val.release().ptr();
} }
auto&& hv = [&]() {
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); py::gil_scoped_release _;
return interpreter_for_py->get_value(m_tensor->m_handle.get());
}();
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) { if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid"); PyErr_SetString(PyExc_ValueError, "tensor invalid");
...@@ -492,7 +494,10 @@ PyObject* TensorWrapper::_dev_tensor(){ ...@@ -492,7 +494,10 @@ PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.recording && !skip_tracing) { if (m_tensor->m_trace_info.recording && !skip_tracing) {
PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr()); PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "data_read", py::cast(true).release().ptr());
} }
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); auto dev_tensor = [&](){
py::gil_scoped_release _;
return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
}();
return py::cast(dev_tensor).release().ptr(); return py::cast(dev_tensor).release().ptr();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册