Commit aa4e8476 authored by Megvii Engine Team, committed by huangxinda

fix(interpreter): release gil when interpreter blocking

GitOrigin-RevId: c48e9efa5bb70fa39cfbe9ba1182603d1daf112f
Parent: bd62a0a6
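
The change set below leans on two standard pybind11 idioms for dropping the GIL around blocking C++ calls. A minimal, self-contained sketch of both, with generic names (blocking_work, wait_all, partly_blocking are illustrative, not MegEngine code):

#include <pybind11/pybind11.h>
namespace py = pybind11;

void blocking_work();  // hypothetical call that waits on another thread

PYBIND11_MODULE(example, m) {
    // Idiom 1: release the GIL for the entire binding via a call guard,
    // so Python threads keep running while the C++ side blocks.
    m.def("wait_all", []() {
        blocking_work();
    }, py::call_guard<py::gil_scoped_release>());

    // Idiom 2: release the GIL only around the blocking region.
    m.def("partly_blocking", []() {
        // ... work that still needs the GIL ...
        {
            py::gil_scoped_release _;  // GIL reacquired at scope exit
            blocking_work();
        }
    });
}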
@@ -12,7 +12,7 @@ import os
 import queue
 from .. import _exit
-from ..core._imperative_rt.core2 import sync
+from ..core._imperative_rt.core2 import full_sync
 from ..logger import get_logger
 from .group import group_barrier, init_process_group
 from .helper import get_device_count_by_fork
@@ -51,7 +51,7 @@ def _run_wrapped(
         group_barrier()
     ret = func(*args, **kwargs)
     queue.put((dev, ret))
-    sync()
+    full_sync()
    if is_multimachine:
        group_barrier()
    _exit(0)
...
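The launcher change above swaps sync() for full_sync() so that each worker drains all devices before _exit(0). For reference, the difference between the two entry points as bound later in this commit (condensed from the tensor.cpp hunk below):

// "sync": wait for the interpreter queue and pending Python tasks.
interpreter_for_py->sync();
sync_py_task_q();

// "full_sync": additionally wait for every compute node to finish.
interpreter_for_py->sync();
CompNode::sync_all();
sync_py_task_q();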
@@ -390,6 +390,7 @@ PyObject* TensorWrapper::shape() {
         }
         shape = *tshp;
     } else {
+        py::gil_scoped_release _;
         shape = m_tensor->shape();
     }
@@ -899,7 +900,6 @@ void init_tensor(py::module m) {
     }
     static constexpr auto sync_py_task_q = []{
-        py::gil_scoped_release _;
         py_task_q.wait_all_task_finish();
     };
@@ -933,7 +933,7 @@ void init_tensor(py::module m) {
              imperative::Profiler::load_options(std::move(options));
              imperative::Profiler::start_profile();
              interpreter_for_py->start_profile();
-          });
+          }, py::call_guard<py::gil_scoped_release>());
    m.def("stop_profile",
          []() -> std::function<void(std::string, std::string)> {
              interpreter_for_py->stop_profile();
@@ -944,23 +944,23 @@ void init_tensor(py::module m) {
              return [results=std::move(results), options=std::move(options)](std::string basename, std::string format){
                  imperative::Profiler::dump_profile(basename, format, results, options);
              };
-          });
+          }, py::call_guard<py::gil_scoped_release>());
    m.def("sync",
          []() {
              interpreter_for_py->sync();
              sync_py_task_q();
-          });
+          }, py::call_guard<py::gil_scoped_release>());
    m.def("full_sync",
          []() {
              interpreter_for_py->sync();
              CompNode::sync_all();
              sync_py_task_q();
-          });
+          }, py::call_guard<py::gil_scoped_release>());
    m.def("close",
          []() {
              interpreter_for_py->close();
              sync_py_task_q();
-          });
+          }, py::call_guard<py::gil_scoped_release>());
    py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
        .def<&GradKeyWrapper::attach>("attach")
...
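A plausible reading of the sync_py_task_q change above: once every m.def carries py::call_guard<py::gil_scoped_release>, the GIL is already released by the time the helper runs, and constructing a second py::gil_scoped_release without holding the GIL is an error in pybind11. A hedged sketch of the hazard, with hypothetical names (helper, wait_for_queue, api):

void helper() {
    py::gil_scoped_release _;  // wrong if the caller has already released the GIL
    wait_for_queue();          // hypothetical blocking call
}

m.def("api", []() {
    helper();                  // GIL already released by the call guard below
}, py::call_guard<py::gil_scoped_release>());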
@@ -112,6 +112,7 @@ Interpreter& Interpreter::inst() {
 }

 Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     state.scopes.push("Put");
@@ -126,13 +127,14 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     info->h_value = value;
     m_buffer.enqueue(Put{info, value, no_cache});
     if (m_async_level == 0) {
-        sync();
+        sync_impl();
         info->desc.comp_node.sync();
     }
     return info;
 }

 Handle ChannelImpl::put(const DeviceTensorND& data) {
+    MGB_LOCK_GUARD(m_spin);
     auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
     state.scopes.push("Put");
@@ -148,9 +150,14 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
 }

 void ChannelImpl::del(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     if (!check_available()){
         return;
     }
+    del_impl(handle);
+}
+
+void ChannelImpl::del_impl(Handle handle) {
     mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
     auto* info = reinterpret_cast<TensorInfo*>(handle);
     m_valid_handle.erase(handle);
@@ -158,6 +165,7 @@ void ChannelImpl::del(Handle handle) {
 }

 void ChannelImpl::swap_in(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     if (state.options.enable_swap) {
@@ -169,6 +177,7 @@ void ChannelImpl::swap_in(Handle handle) {
 }

 void ChannelImpl::swap_out(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     if (state.options.enable_swap) {
@@ -180,6 +189,7 @@ void ChannelImpl::swap_out(Handle handle) {
 }

 void ChannelImpl::drop(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     if (state.options.enable_drop) {
@@ -305,9 +315,9 @@ void ChannelImpl::dispatch_kernel(
     RECORD_EVENT(OpDispatchEvent, cmd.id, cmd.op->trait()->name, op_info_getter, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
     m_buffer.enqueue(std::move(cmd));
     if (!validated && options.async_level == 1) {
-        sync();
+        sync_impl();
     } else if (options.async_level == 0) {
-        sync();
+        sync_impl();
         // check device error
         for (auto&& oup : *outputs) {
             auto info = reinterpret_cast<TensorInfo*>(oup);
@@ -320,6 +330,7 @@ void ChannelImpl::dispatch_kernel(
 SmallVector<Handle> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<Handle>& inputs) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     for (auto i : inputs) {
@@ -358,6 +369,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
 }

 HostTensorND ChannelImpl::get_value(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -368,6 +380,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
 }

 TensorShape ChannelImpl::get_shape(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -381,6 +394,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
 }

 DType ChannelImpl::get_dtype(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -392,6 +406,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
 }

 CompNode ChannelImpl::get_device(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -403,6 +418,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
 }

 DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -411,7 +427,12 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
 }

 void ChannelImpl::sync() {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
+    sync_impl();
+}
+
+void ChannelImpl::sync_impl() {
     m_buffer.flush();
     m_worker.wait_all_task_finish();
     MGB_LOCK_GUARD(m_mutex);
@@ -419,26 +440,29 @@ void ChannelImpl::sync() {
 }

 void ChannelImpl::close() {
+    MGB_LOCK_GUARD(m_spin);
     if (!check_available()) {
         return;
     }
     std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
     for (auto* handle: valid_handles) {
-        del(handle);
+        del_impl(handle);
     }
     mgb_assert(m_valid_handle.empty());
     mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
-    sync();
+    sync_impl();
     m_closed = true;
 }

 size_t ChannelImpl::get_option(std::string name) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     return state.options.get_option(name);
 }

 void ChannelImpl::set_option(std::string name, size_t value) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     state.options.set_option(name, value);
@@ -1096,6 +1120,7 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
 }

 void ChannelImpl::start_profile() {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto capture_tensors = collect_valid_tensors();
     if (capture_tensors.size() > 0) {
@@ -1104,6 +1129,7 @@ void ChannelImpl::start_profile() {
 }

 void ChannelImpl::stop_profile() {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     m_buffer.flush();
     auto escape_tensors = collect_valid_tensors();
@@ -1113,6 +1139,7 @@ void ChannelImpl::stop_profile() {
 }

 void ChannelImpl::push_scope(std::string name) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     state.scopes.push(name);
@@ -1121,6 +1148,7 @@ void ChannelImpl::push_scope(std::string name) {
 }

 void ChannelImpl::pop_scope(std::string name) {
+    MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     state.scopes.pop(name);
...
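The interpreter changes above follow one pattern throughout: every public ChannelImpl method takes the new m_spin guard, and internal callers are rerouted to *_impl variants that assume the lock is already held, so a non-recursive spinlock is never acquired twice on one thread. A generic sketch of the split (Channel, work, work_impl are illustrative names):

void Channel::work() {          // public entry: takes the lock
    MGB_LOCK_GUARD(m_spin);
    work_impl();
}

void Channel::work_impl() {     // private: caller must hold m_spin
    // ... actual work ...
}

void Channel::close() {
    MGB_LOCK_GUARD(m_spin);
    work_impl();                // calling work() here would self-deadlock
}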
@@ -84,6 +84,8 @@ private:
     void detach_users(TensorInfo*);
     TensorInfo* put_impl(const HostTensorND& value, bool no_cache);
+    void del_impl(Handle);
+    void sync_impl();
     TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
     void notify_tensor_unsafe(TensorInfo* info);
@@ -130,6 +132,7 @@ private:
     std::unordered_set<TensorInfo*> collect_valid_tensors();
     std::mutex m_mutex;
+    Spinlock m_spin;
     std::condition_variable m_cv;
     MemPool<TensorInfo> m_pool;
     std::unordered_set<Handle> m_valid_handle;
...
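The header adds a Spinlock member alongside the existing std::mutex, presumably because the new critical sections (handle checks, option reads, queue enqueues) are short enough that spinning beats sleeping. MegEngine's actual Spinlock implementation is not shown in this diff; a textbook equivalent, for illustration only:

#include <atomic>

struct SpinlockSketch {
    std::atomic_flag flag = ATOMIC_FLAG_INIT;
    void lock() {
        while (flag.test_and_set(std::memory_order_acquire)) {
            // busy-wait; acceptable only for very short critical sections
        }
    }
    void unlock() { flag.clear(std::memory_order_release); }
};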