diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 0a6871a04ebf590ab4f4cb40fde5a77f5d7d59b7..2530438b247bb7ffcd971338979458efd9172d1b 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -162,6 +162,12 @@ def _check_device_initialized(device_type: str, rank: int): raise RuntimeError(errmsg) +def _check_interpreter_status(): + from ..core._imperative_rt.core2 import get_option + + _ = get_option("async_level") + + get_device_count_by_fork = deprecated_func( "1.5", "megengine.device", "get_device_count", False ) diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index 279d0359f6d5cd6994fd4c344030efc73d1ffa2a..4a8625431afc8f7e8faa3dae131d721d5ca0660c 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -9,7 +9,7 @@ from ..core._imperative_rt.core2 import full_sync from ..device import get_device_count from ..logger import get_logger from .group import _set_machine_ranks, group_barrier, init_process_group -from .helper import _check_device_initialized +from .helper import _check_device_initialized, _check_interpreter_status from .server import Client, Server WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = ( @@ -33,6 +33,7 @@ def _run_wrapped( machine_ranks: list, ): r"""Init distributed process group and run wrapped function.""" + _check_interpreter_status() _check_device_initialized(device_type, dev) init_process_group( master_ip=master_ip, diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index e71187f63ece08ec016e5fc0d71bef8c9404c718..0f3df552967df34d72ed993f78b8a2f319db9152 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -115,7 +115,18 @@ void 
ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() { #define m_worker_state std::unique_ptr InterpreterImpl::create_channel() { - return std::make_unique(); + auto ret = std::make_unique(); +#if !(defined(_WIN32) || defined(_WIN64)) + auto disable_channels = [](void) -> void { + for (ChannelImpl* channel : ChannelImpl::m_all_active_channels) { + if (channel->worker_started()) { + channel->update_status_to_forked(); + } + } + }; + pthread_atfork(nullptr, nullptr, static_cast(disable_channels)); +#endif + return ret; } Interpreter& Interpreter::inst() { @@ -125,7 +136,7 @@ Interpreter& Interpreter::inst() { Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); std::optional guard; if (Profiler::is_profiling()) { auto& state = get_channel_state(); @@ -158,7 +169,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { Put{info, value, no_cache}, }); } - if (m_async_level == 0) { + + if (get_channel_state().options.async_level == 0) { sync_impl(); info->desc.comp_node.sync(); auto err = info->desc.comp_node.check_async_error(); @@ -169,7 +181,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); return reinterpret_cast(put_impl(data, hvalue)); } TensorInfo* ChannelImpl::put_impl( @@ -221,7 +233,7 @@ void ChannelImpl::del_impl(Handle handle) { void ChannelImpl::drop(Handle handle) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto& state = get_channel_state(); if (state.options.enable_drop) { mgb_assert( @@ -404,7 +416,7 @@ void ChannelImpl::dispatch_kernel( SmallVector ChannelImpl::apply_op( std::shared_ptr op, const SmallVector& inputs) { 
MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto* input = reinterpret_cast(inputs[0]); if (op->same_type() && input->shape_valid()) { size_t ndim = input->desc.layout.ndim; @@ -460,7 +472,7 @@ SmallVector ChannelImpl::apply_op_impl( HostTensorND ChannelImpl::get_value(Handle handle) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); mgb_assert( m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -472,7 +484,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) { TensorShape ChannelImpl::get_shape(Handle handle) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); mgb_assert( m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -487,7 +499,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) { DType ChannelImpl::get_dtype(Handle handle) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); mgb_assert( m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -500,7 +512,7 @@ DType ChannelImpl::get_dtype(Handle handle) { CompNode ChannelImpl::get_device(Handle handle) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); mgb_assert( m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -513,7 +525,7 @@ CompNode ChannelImpl::get_device(Handle handle) { DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); mgb_assert( m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -523,7 +535,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { void ChannelImpl::sync() { MGB_LOCK_GUARD(m_spin); - 
mgb_assert(check_available(), "Channel already closed"); + assert_available(); sync_impl(); } @@ -545,19 +557,19 @@ void ChannelImpl::close() { mgb_assert(m_valid_handle.empty()); mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size()); sync_impl(); - m_closed = true; + m_status = ChannelRunningStatus::CLOSED; } size_t ChannelImpl::get_option(std::string name) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto& state = get_channel_state(); return state.options.get_option(name); } void ChannelImpl::set_option(std::string name, size_t value) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto& state = get_channel_state(); state.options.set_option(name, value); // FIXME @@ -583,7 +595,7 @@ void ChannelImpl::set_option(std::string name, size_t value) { void ChannelImpl::clear_candidates() { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); m_dtr.candidates.clear(); } @@ -681,10 +693,18 @@ void ChannelImpl::real_free(TensorInfo* ptr) { m_pool.free(ptr); } -ChannelImpl::ChannelImpl() : m_worker(this) {} +std::unordered_set ChannelImpl::m_all_active_channels{}; +MGB_MUTEX ChannelImpl::m_all_active_channels_mutex{}; + +ChannelImpl::ChannelImpl() : m_worker(this) { + MGB_LOCK_GUARD(m_all_active_channels_mutex); + m_all_active_channels.emplace(this); +} ChannelImpl::~ChannelImpl() { close(); + MGB_LOCK_GUARD(m_all_active_channels_mutex); + m_all_active_channels.erase(this); } void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { @@ -992,7 +1012,7 @@ void ChannelImpl::detach_users(TensorInfo* dest) { } bool ChannelImpl::check_available() { - return !m_closed; + return m_status == ChannelRunningStatus::RUNING; } TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { @@ -1352,7 +1372,7 @@ void ChannelImpl::check_worker_exc_unsafe() { void 
ChannelImpl::start_profile() { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto capture_tensors = collect_valid_tensors(); if (capture_tensors.size() > 0) { if (Profiler::is_profiling()) { @@ -1370,7 +1390,7 @@ void ChannelImpl::stop_profile() { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto escape_tensors = collect_valid_tensors(); if (escape_tensors.size() > 0) { if (Profiler::is_profiling()) { @@ -1388,7 +1408,7 @@ void ChannelImpl::push_scope(std::string name) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto& state = get_channel_state(); state.stack_manager.enter(name); MGB_RECORD_EVENT(ScopeEvent, name); @@ -1406,7 +1426,7 @@ void ChannelImpl::pop_scope(std::string name) { MGB_LOCK_GUARD(m_spin); - mgb_assert(check_available(), "Channel already closed"); + assert_available(); auto& state = get_channel_state(); state.stack_manager.exit(name); MGB_RECORD_EVENT(ScopeFinishEvent, name); @@ -1422,6 +1442,31 @@ void ChannelImpl::pop_scope(std::string name) { } } +bool ChannelImpl::worker_started() const { + return m_worker.worker_started(); +} + +void ChannelImpl::update_status_to_forked(void) { + MGB_LOCK_GUARD(m_spin); + m_status = ChannelRunningStatus::FORKED; +} + +void ChannelImpl::assert_available() const { + if (m_status == ChannelRunningStatus::RUNING) { + return; + } else if (m_status == ChannelRunningStatus::CLOSED) { + mgb_assert(false, "Channel already closed"); + } else if (m_status == ChannelRunningStatus::FORKED) { + mgb_assert( + false, + "your program is forked and megengine is disabled in subprocess, if " + "you want to use megengine in subprocess, please DO NOT set up and use " + "megengine before fork"); + } else { + mgb_assert(false, 
"impossible, Channel status is undefined"); + } +} + void ChannelImpl::assert_in_channel() { mgb_assert( get_worker_tid() != std::this_thread::get_id(), diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index e5ae45bd40495c6449eaacb45520e3da5137686e..d068a0a1afcab0fb0eab5f1ff47049e0d6c83540 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -27,7 +27,7 @@ struct InterpreterImpl : Interpreter { std::unique_ptr create_channel() override; }; -struct ChannelImpl : Interpreter::Channel { +struct ChannelImpl : Interpreter::Channel, NonCopyableObj, NonMoveableObj { ChannelImpl(); ~ChannelImpl() override; @@ -61,6 +61,13 @@ struct ChannelImpl : Interpreter::Channel { void push_scope(std::string) override; void pop_scope(std::string) override; + bool worker_started() const; + void update_status_to_forked(void); + void assert_available() const; + + static std::unordered_set m_all_active_channels; + static MGB_MUTEX m_all_active_channels_mutex; + private: struct WorkQueue; struct State; @@ -130,7 +137,9 @@ private: // TODO: use explicit struct std::stack> m_apply_stack; bool m_applying = false; - bool m_closed = false; + + enum class ChannelRunningStatus { RUNING, CLOSED, FORKED }; + ChannelRunningStatus m_status = ChannelRunningStatus::RUNING; struct WorkQueue : AsyncQueueSC { // set max_spin=0 to prevent Queue fetch task in busy wait manner. @@ -159,12 +168,6 @@ private: ChannelImpl* m_owner; } m_worker; - //! config whether raise error exactly when invoking op. - //! level 2: both device and user side errors are async; - //! level 1: user side errors are sync; - //! level 0: both sync. 
- int m_async_level = 2; - struct State { std::thread::id tid; OptionManager options; diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h index 924726014c13c779d2851df65d6067208f609c6e..61f73257941ec69bf72e2e83f49829e8e498bdcc 100644 --- a/imperative/src/include/megbrain/imperative/interpreter.h +++ b/imperative/src/include/megbrain/imperative/interpreter.h @@ -60,6 +60,9 @@ struct Interpreter { virtual std::unique_ptr create_channel() = 0; static Interpreter& inst(); + +protected: + Interpreter() = default; }; } // namespace mgb::imperative::interpreter diff --git a/src/core/include/megbrain/utils/metahelper.h b/src/core/include/megbrain/utils/metahelper.h index 4082edc1fdc106ad829a8fe49514e69d49321b81..c493f55ba519e86a082bfbb3bf95a62ec9b7bfc9 100644 --- a/src/core/include/megbrain/utils/metahelper.h +++ b/src/core/include/megbrain/utils/metahelper.h @@ -151,6 +151,17 @@ public: NonCopyableObj() = default; }; +/*! + * \brief base class for non-moveable objects + */ +class NonMoveableObj { + NonMoveableObj(NonMoveableObj&&) = delete; + NonMoveableObj& operator=(NonMoveableObj&&) = delete; + +public: + NonMoveableObj() = default; +}; + template class ReverseAdaptor { T& m_t; diff --git a/src/core/include/megbrain/utils/thread_impl_1.h b/src/core/include/megbrain/utils/thread_impl_1.h index 2b66bc7cbc7c09c59879a845e76acdb14147e5e6..dec54ef3bfb4e34d9005767b38d94f908045f024 100644 --- a/src/core/include/megbrain/utils/thread_impl_1.h +++ b/src/core/include/megbrain/utils/thread_impl_1.h @@ -253,6 +253,8 @@ public: } } + inline bool worker_started() const { return m_synchronizer.worker_started(); } + protected: ~AsyncQueueSC() noexcept = default;