Commit aa4e8476 authored by Megvii Engine Team, committed by huangxinda

fix(interpreter): release gil when interpreter blocking

GitOrigin-RevId: c48e9efa5bb70fa39cfbe9ba1182603d1daf112f
Parent: bd62a0a6
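This change releases the Python GIL around every interpreter call that can block (sync, full_sync, close, profiling control, shape queries), so a worker thread that needs the GIL to run Python callbacks can no longer deadlock against a Python thread stuck inside a C++ wait. A minimal, self-contained sketch of the pattern follows; the module and `blocking_wait` are illustrative names, not part of this diff.

```cpp
// Minimal sketch of the pattern this commit applies (illustrative names,
// not MegEngine code): drop the GIL before a wait that can block on a
// worker thread, since that worker may itself need the GIL.
#include <chrono>
#include <thread>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Stand-in for a blocking operation such as draining a worker task queue.
static void blocking_wait() {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

PYBIND11_MODULE(example, m) {
    m.def("wait", [] {
        py::gil_scoped_release release;  // GIL re-acquired when `release` is destroyed
        blocking_wait();
    });
}
```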
@@ -12,7 +12,7 @@ import os
import queue
from .. import _exit
-from ..core._imperative_rt.core2 import sync
+from ..core._imperative_rt.core2 import full_sync
from ..logger import get_logger
from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork
@@ -51,7 +51,7 @@ def _run_wrapped(
        group_barrier()
    ret = func(*args, **kwargs)
    queue.put((dev, ret))
-    sync()
+    full_sync()
    if is_multimachine:
        group_barrier()
    _exit(0)
......
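In the launcher above, the exit path switches from sync() to full_sync(). Judging from the bindings added in tensor.cpp below, full_sync performs the same interpreter drain as sync but additionally calls CompNode::sync_all(), a device-wide barrier, which is the safer thing to wait on right before _exit(0). A hedged reconstruction with stub types follows; only the call sequence mirrors the diff.

```cpp
// Hedged reconstruction of the two bindings defined in tensor.cpp below.
// Every type here is a stub; only the call sequence mirrors the diff.
struct StubInterpreter {
    void sync() { /* wait until the dispatch queue is drained */ }
};
struct StubCompNode {
    static void sync_all() { /* wait until every device finishes */ }
};
static StubInterpreter interpreter;
static void sync_py_task_q() { /* run pending Python-side tasks */ }

void sync_() {                 // ~ what core2.sync() maps to
    interpreter.sync();
    sync_py_task_q();
}

void full_sync_() {            // ~ what core2.full_sync() maps to
    interpreter.sync();
    StubCompNode::sync_all();  // the extra device-wide barrier
    sync_py_task_q();
}

int main() {
    full_sync_();              // a safe point to _exit() after this returns
}
```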
@@ -390,6 +390,7 @@ PyObject* TensorWrapper::shape() {
}
shape = *tshp;
} else {
+py::gil_scoped_release _;
shape = m_tensor->shape();
}
@@ -899,7 +900,6 @@ void init_tensor(py::module m) {
}
static constexpr auto sync_py_task_q = []{
-py::gil_scoped_release _;
py_task_q.wait_all_task_finish();
};
@@ -933,7 +933,7 @@ void init_tensor(py::module m) {
imperative::Profiler::load_options(std::move(options));
imperative::Profiler::start_profile();
interpreter_for_py->start_profile();
-});
+}, py::call_guard<py::gil_scoped_release>());
m.def("stop_profile",
[]() -> std::function<void(std::string, std::string)> {
interpreter_for_py->stop_profile();
@@ -944,23 +944,23 @@ void init_tensor(py::module m) {
return [results=std::move(results), options=std::move(options)](std::string basename, std::string format){
imperative::Profiler::dump_profile(basename, format, results, options);
};
-});
+}, py::call_guard<py::gil_scoped_release>());
m.def("sync",
[]() {
interpreter_for_py->sync();
sync_py_task_q();
-});
+}, py::call_guard<py::gil_scoped_release>());
m.def("full_sync",
[]() {
interpreter_for_py->sync();
CompNode::sync_all();
sync_py_task_q();
-});
+}, py::call_guard<py::gil_scoped_release>());
m.def("close",
[]() {
interpreter_for_py->close();
sync_py_task_q();
-});
+}, py::call_guard<py::gil_scoped_release>());
py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
......
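The bindings above replace an explicit py::gil_scoped_release at the top of each lambda body with py::call_guard<py::gil_scoped_release>(), which releases the GIL for the duration of the whole call. Both forms are sketched below in an illustrative module that is not part of the diff.

```cpp
// Sketch of the two equivalent ways to release the GIL in a pybind11
// binding (illustrative module, not from the diff). py::call_guard
// constructs the guard before the bound callable runs and destroys it after.
#include <pybind11/pybind11.h>

namespace py = pybind11;

static void blocking_work() { /* e.g. wait for a worker queue to drain */ }

PYBIND11_MODULE(example, m) {
    // Variant 1: explicit scope guard inside the body.
    m.def("f1", [] {
        py::gil_scoped_release _;
        blocking_work();
    });
    // Variant 2: declarative guard; the whole body runs without the GIL.
    m.def("f2", [] {
        blocking_work();
    }, py::call_guard<py::gil_scoped_release>());
}
```

This is also why the explicit release inside sync_py_task_q is deleted above: with the call guard in place, the GIL is already released by the time that helper runs, and releasing an already-released GIL is an error in pybind11.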
@@ -112,6 +112,7 @@ Interpreter& Interpreter::inst() {
}
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
state.scopes.push("Put");
@@ -126,13 +127,14 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
info->h_value = value;
m_buffer.enqueue(Put{info, value, no_cache});
if (m_async_level == 0) {
-sync();
+sync_impl();
info->desc.comp_node.sync();
}
return info;
}
Handle ChannelImpl::put(const DeviceTensorND& data) {
+MGB_LOCK_GUARD(m_spin);
auto& state = get_channel_state();
+mgb_assert(check_available(), "Channel already closed");
state.scopes.push("Put");
@@ -148,9 +150,14 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
}
void ChannelImpl::del(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+if (!check_available()){
+return;
+}
+del_impl(handle);
}
+void ChannelImpl::del_impl(Handle handle) {
mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
m_valid_handle.erase(handle);
@@ -158,6 +165,7 @@ void ChannelImpl::del(Handle handle) {
}
void ChannelImpl::swap_in(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
if (state.options.enable_swap) {
@@ -169,6 +177,7 @@ void ChannelImpl::swap_in(Handle handle) {
}
void ChannelImpl::swap_out(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
if (state.options.enable_swap) {
@@ -180,6 +189,7 @@ void ChannelImpl::swap_out(Handle handle) {
}
void ChannelImpl::drop(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
if (state.options.enable_drop) {
@@ -305,9 +315,9 @@ void ChannelImpl::dispatch_kernel(
RECORD_EVENT(OpDispatchEvent, cmd.id, cmd.op->trait()->name, op_info_getter, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
m_buffer.enqueue(std::move(cmd));
if (!validated && options.async_level == 1) {
-sync();
+sync_impl();
} else if (options.async_level == 0) {
-sync();
+sync_impl();
// check device error
for (auto&& oup : *outputs) {
auto info = reinterpret_cast<TensorInfo*>(oup);
@@ -320,6 +330,7 @@ void ChannelImpl::dispatch_kernel(
SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<Handle>& inputs) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
for (auto i : inputs) {
@@ -358,6 +369,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
}
HostTensorND ChannelImpl::get_value(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -368,6 +380,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
}
TensorShape ChannelImpl::get_shape(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -381,6 +394,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
}
DType ChannelImpl::get_dtype(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -392,6 +406,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
}
CompNode ChannelImpl::get_device(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -403,6 +418,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
}
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -411,7 +427,12 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
}
void ChannelImpl::sync() {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
+sync_impl();
+}
+void ChannelImpl::sync_impl() {
m_buffer.flush();
m_worker.wait_all_task_finish();
MGB_LOCK_GUARD(m_mutex);
@@ -419,26 +440,29 @@ void ChannelImpl::sync() {
}
void ChannelImpl::close() {
+MGB_LOCK_GUARD(m_spin);
+if (!check_available()) {
+return;
+}
std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
for (auto* handle: valid_handles) {
-del(handle);
+del_impl(handle);
}
mgb_assert(m_valid_handle.empty());
mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
-sync();
+sync_impl();
m_closed = true;
}
size_t ChannelImpl::get_option(std::string name) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
return state.options.get_option(name);
}
void ChannelImpl::set_option(std::string name, size_t value) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
state.options.set_option(name, value);
@@ -1096,6 +1120,7 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
}
void ChannelImpl::start_profile() {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto capture_tensors = collect_valid_tensors();
if (capture_tensors.size() > 0) {
@@ -1104,6 +1129,7 @@ void ChannelImpl::start_profile() {
}
void ChannelImpl::stop_profile() {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
m_buffer.flush();
auto escape_tensors = collect_valid_tensors();
@@ -1113,6 +1139,7 @@ void ChannelImpl::stop_profile() {
}
void ChannelImpl::push_scope(std::string name) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
state.scopes.push(name);
@@ -1121,6 +1148,7 @@ void ChannelImpl::push_scope(std::string name) {
}
void ChannelImpl::pop_scope(std::string name) {
+MGB_LOCK_GUARD(m_spin);
+mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state();
state.scopes.pop(name);
......
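Every public ChannelImpl entry point above now takes MGB_LOCK_GUARD(m_spin) and asserts check_available(), while the blocking or reentrant work moves into *_impl variants that assume the lock is already held; that is what lets close() call del_impl() and sync_impl() without self-deadlocking on a non-recursive lock. A simplified sketch of the split, using stub types rather than the real ChannelImpl:

```cpp
// Simplified sketch (not the real ChannelImpl) of the wrapper/impl split:
// the public method takes the lock and validates state once; internal
// callers that already hold the lock call the *_impl variant directly.
#include <cassert>
#include <mutex>
#include <unordered_set>

class ChannelSketch {
public:
    void del(void* handle) {
        std::lock_guard<std::mutex> lock(m_spin);
        if (m_closed) return;          // deleting after close is a no-op
        del_impl(handle);
    }
    void close() {
        std::lock_guard<std::mutex> lock(m_spin);
        if (m_closed) return;
        auto handles = m_valid;        // copy: del_impl mutates m_valid
        for (auto* h : handles)
            del_impl(h);               // lock already held, call impl directly
        m_closed = true;
    }
private:
    void del_impl(void* handle) {      // precondition: m_spin is held
        assert(m_valid.count(handle));
        m_valid.erase(handle);
    }
    std::mutex m_spin;                 // stands in for MegEngine's Spinlock
    bool m_closed = false;
    std::unordered_set<void*> m_valid;
};
```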
@@ -84,6 +84,8 @@ private:
void detach_users(TensorInfo*);
TensorInfo* put_impl(const HostTensorND& value, bool no_cache);
+void del_impl(Handle);
+void sync_impl();
TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
void notify_tensor_unsafe(TensorInfo* info);
@@ -130,6 +132,7 @@ private:
std::unordered_set<TensorInfo*> collect_valid_tensors();
std::mutex m_mutex;
+Spinlock m_spin;
std::condition_variable m_cv;
MemPool<TensorInfo> m_pool;
std::unordered_set<Handle> m_valid_handle;
......
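The header above adds a Spinlock m_spin next to the existing m_mutex/m_cv pair. A plausible division of labor, sketched here as an assumption rather than taken from the implementation: the spinlock serializes the short, non-blocking prologue of each API call, while the mutex and condition variable back the genuinely blocking waits such as wait_tensor.

```cpp
// Assumption-labeled sketch (not the real implementation) of why two locks
// coexist: a cheap spinlock for the short API prologue, and a mutex +
// condition_variable pair for waits that actually block.
#include <condition_variable>
#include <mutex>

class WaiterSketch {
public:
    void api_entry() {
        std::lock_guard<std::mutex> g(m_spin);  // short critical section only
        // validate channel state, enqueue a command, return quickly
    }
    void wait_ready() {                          // called with the GIL released
        std::unique_lock<std::mutex> lk(m_mutex);
        m_cv.wait(lk, [this] { return m_ready; });
    }
    void notify_ready() {                        // called by the worker thread
        {
            std::lock_guard<std::mutex> g(m_mutex);
            m_ready = true;
        }
        m_cv.notify_all();
    }
private:
    std::mutex m_spin;   // stands in for MegEngine's Spinlock
    std::mutex m_mutex;
    std::condition_variable m_cv;
    bool m_ready = false;
};
```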