提交 b8c7557b 编写于 作者: M Megvii Engine Team

fix(mm): fix mm error when using sync

GitOrigin-RevId: 63387bda049e51b0a7c65b9a9a0ebd746a713446
上级 73ad06ba
......@@ -162,6 +162,12 @@ def _check_device_initialized(device_type: str, rank: int):
raise RuntimeError(errmsg)
def _check_interpreter_status():
    r"""Assert that the imperative-runtime interpreter channel is usable.

    Reading any option (here ``async_level``) goes through the channel's
    availability check, so this raises early — e.g. when the process was
    forked after MegEngine was initialized — instead of failing later with
    a confusing error.
    """
    from ..core._imperative_rt.core2 import get_option

    # The returned value is irrelevant; the call is made purely for its
    # side effect of validating the channel status.
    get_option("async_level")
# Deprecated since 1.5 — kept as a shim; use megengine.device.get_device_count
# instead.
get_device_count_by_fork = deprecated_func(
"1.5", "megengine.device", "get_device_count", False
)
......
......@@ -9,7 +9,7 @@ from ..core._imperative_rt.core2 import full_sync
from ..device import get_device_count
from ..logger import get_logger
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized
from .helper import _check_device_initialized, _check_interpreter_status
from .server import Client, Server
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
......@@ -33,6 +33,7 @@ def _run_wrapped(
machine_ranks: list,
):
r"""Init distributed process group and run wrapped function."""
_check_interpreter_status()
_check_device_initialized(device_type, dev)
init_process_group(
master_ip=master_ip,
......
......@@ -115,7 +115,18 @@ void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
#define m_worker_state
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
    auto ret = std::make_unique<ChannelImpl>();
#if !(defined(_WIN32) || defined(_WIN64))
    // Register the fork handler exactly once: pthread_atfork accumulates
    // registrations, so registering on every create_channel() call would run
    // the same handler N times in a forked child.
    static const bool atfork_registered = [] {
        // Child-side handler: worker threads do not survive fork(), so mark
        // every channel whose worker had started as FORKED; later API calls
        // then fail fast via assert_available() with a clear message.
        auto disable_channels = [](void) -> void {
            for (ChannelImpl* channel : ChannelImpl::m_all_active_channels) {
                if (channel->worker_started()) {
                    channel->update_status_to_forked();
                }
            }
        };
        pthread_atfork(
                nullptr, nullptr, static_cast<void (*)(void)>(disable_channels));
        return true;
    }();
    (void)atfork_registered;
#endif
    return ret;
}
Interpreter& Interpreter::inst() {
......@@ -125,7 +136,7 @@ Interpreter& Interpreter::inst() {
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
std::optional<StackManager::Guard> guard;
if (Profiler::is_profiling()) {
auto& state = get_channel_state();
......@@ -158,7 +169,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
Put{info, value, no_cache},
});
}
if (m_async_level == 0) {
if (get_channel_state().options.async_level == 0) {
sync_impl();
info->desc.comp_node.sync();
auto err = info->desc.comp_node.check_async_error();
......@@ -169,7 +181,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
return reinterpret_cast<Handle>(put_impl(data, hvalue));
}
TensorInfo* ChannelImpl::put_impl(
......@@ -221,7 +233,7 @@ void ChannelImpl::del_impl(Handle handle) {
void ChannelImpl::drop(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
if (state.options.enable_drop) {
mgb_assert(
......@@ -404,7 +416,7 @@ void ChannelImpl::dispatch_kernel(
SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
if (op->same_type<GetVarShape>() && input->shape_valid()) {
size_t ndim = input->desc.layout.ndim;
......@@ -460,7 +472,7 @@ SmallVector<Handle> ChannelImpl::apply_op_impl(
HostTensorND ChannelImpl::get_value(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
......@@ -472,7 +484,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
TensorShape ChannelImpl::get_shape(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
......@@ -487,7 +499,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
DType ChannelImpl::get_dtype(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
......@@ -500,7 +512,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
CompNode ChannelImpl::get_device(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
......@@ -513,7 +525,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
......@@ -523,7 +535,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
// Block until all pending commands of this channel have been executed.
// Removed the leftover pre-refactor `mgb_assert(check_available(), ...)`
// line; assert_available() now reports CLOSED vs FORKED distinctly.
void ChannelImpl::sync() {
    MGB_LOCK_GUARD(m_spin);
    assert_available();
    sync_impl();
}
......@@ -545,19 +557,19 @@ void ChannelImpl::close() {
mgb_assert(m_valid_handle.empty());
mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
sync_impl();
m_closed = true;
m_status = ChannelRunningStatus::CLOSED;
}
// Return the current value of the named channel option.
// Removed the leftover pre-refactor `mgb_assert(check_available(), ...)`
// line; assert_available() now reports CLOSED vs FORKED distinctly.
size_t ChannelImpl::get_option(std::string name) {
    MGB_LOCK_GUARD(m_spin);
    assert_available();
    auto& state = get_channel_state();
    return state.options.get_option(name);
}
void ChannelImpl::set_option(std::string name, size_t value) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
state.options.set_option(name, value);
// FIXME
......@@ -583,7 +595,7 @@ void ChannelImpl::set_option(std::string name, size_t value) {
// Drop all DTR (dynamic tensor rematerialization) eviction candidates.
// Removed the leftover pre-refactor `mgb_assert(check_available(), ...)`
// line; assert_available() now reports CLOSED vs FORKED distinctly.
void ChannelImpl::clear_candidates() {
    MGB_LOCK_GUARD(m_spin);
    assert_available();
    m_dtr.candidates.clear();
}
......@@ -681,10 +693,18 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
m_pool.free(ptr);
}
// Registry of every live channel, walked by the pthread_atfork child handler
// to mark channels as FORKED. Guarded by m_all_active_channels_mutex.
// (The old one-line `ChannelImpl::ChannelImpl() : m_worker(this) {}` definition
// was leftover diff residue and is removed — it would be a redefinition.)
std::unordered_set<ChannelImpl*> ChannelImpl::m_all_active_channels{};
MGB_MUTEX ChannelImpl::m_all_active_channels_mutex{};

ChannelImpl::ChannelImpl() : m_worker(this) {
    // Track this channel so the fork handler can disable it in a child.
    MGB_LOCK_GUARD(m_all_active_channels_mutex);
    m_all_active_channels.emplace(this);
}

ChannelImpl::~ChannelImpl() {
    close();
    // Unregister before the object dies so the fork handler never touches
    // a destroyed channel.
    MGB_LOCK_GUARD(m_all_active_channels_mutex);
    m_all_active_channels.erase(this);
}
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
......@@ -992,7 +1012,7 @@ void ChannelImpl::detach_users(TensorInfo* dest) {
}
// A channel can serve requests only while RUNING; both CLOSED and FORKED
// channels must reject calls. (Removed the leftover pre-refactor
// `return !m_closed;` line — diff residue that made this unreachable.)
bool ChannelImpl::check_available() {
    return m_status == ChannelRunningStatus::RUNING;
}
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
......@@ -1352,7 +1372,7 @@ void ChannelImpl::check_worker_exc_unsafe() {
void ChannelImpl::start_profile() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto capture_tensors = collect_valid_tensors();
if (capture_tensors.size() > 0) {
if (Profiler::is_profiling()) {
......@@ -1370,7 +1390,7 @@ void ChannelImpl::start_profile() {
void ChannelImpl::stop_profile() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto escape_tensors = collect_valid_tensors();
if (escape_tensors.size() > 0) {
if (Profiler::is_profiling()) {
......@@ -1388,7 +1408,7 @@ void ChannelImpl::stop_profile() {
void ChannelImpl::push_scope(std::string name) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
state.stack_manager.enter(name);
MGB_RECORD_EVENT(ScopeEvent, name);
......@@ -1406,7 +1426,7 @@ void ChannelImpl::push_scope(std::string name) {
void ChannelImpl::pop_scope(std::string name) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
state.stack_manager.exit(name);
MGB_RECORD_EVENT(ScopeFinishEvent, name);
......@@ -1422,6 +1442,31 @@ void ChannelImpl::pop_scope(std::string name) {
}
}
// Whether the async worker thread backing this channel has been started.
// Used by the pthread_atfork child handler to decide which channels must be
// marked FORKED in the child process.
bool ChannelImpl::worker_started() const {
return m_worker.worker_started();
}
// Mark this channel as FORKED so subsequent API calls fail fast via
// assert_available() instead of touching worker state that did not survive
// fork(). Registered work: called from the pthread_atfork child handler.
void ChannelImpl::update_status_to_forked(void) {
MGB_LOCK_GUARD(m_spin);
m_status = ChannelRunningStatus::FORKED;
}
// Abort with a status-specific message unless the channel is RUNING.
// CLOSED: the channel was shut down normally; FORKED: the process forked
// after the worker thread started, so the channel is unusable in the child.
// Fixed the ungrammatical user-facing message ("is be disabled").
void ChannelImpl::assert_available() const {
    if (m_status == ChannelRunningStatus::RUNING) {
        return;
    } else if (m_status == ChannelRunningStatus::CLOSED) {
        mgb_assert(false, "Channel already closed");
    } else if (m_status == ChannelRunningStatus::FORKED) {
        mgb_assert(
                false,
                "your program is forked and megengine is disabled in the "
                "subprocess, if you want to use megengine in the subprocess, "
                "please DO NOT setup and use megengine before fork");
    } else {
        mgb_assert(false, "impossible, Channel status is undefined");
    }
}
void ChannelImpl::assert_in_channel() {
mgb_assert(
get_worker_tid() != std::this_thread::get_id(),
......
......@@ -27,7 +27,7 @@ struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override;
};
struct ChannelImpl : Interpreter::Channel {
struct ChannelImpl : Interpreter::Channel, NonCopyableObj, NonMoveableObj {
ChannelImpl();
~ChannelImpl() override;
......@@ -61,6 +61,13 @@ struct ChannelImpl : Interpreter::Channel {
void push_scope(std::string) override;
void pop_scope(std::string) override;
bool worker_started() const;
void update_status_to_forked(void);
void assert_available() const;
static std::unordered_set<ChannelImpl*> m_all_active_channels;
static MGB_MUTEX m_all_active_channels_mutex;
private:
struct WorkQueue;
struct State;
......@@ -130,7 +137,9 @@ private:
// TODO: use explicit struct
std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack;
bool m_applying = false;
bool m_closed = false;
enum class ChannelRunningStatus { RUNING, CLOSED, FORKED };
ChannelRunningStatus m_status = ChannelRunningStatus::RUNING;
struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
// set max_spin=0 to prevent Queue fetch task in busy wait manner.
......@@ -159,12 +168,6 @@ private:
ChannelImpl* m_owner;
} m_worker;
//! config whether raise error exactly when invoking op.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
//! level 0: both sync.
int m_async_level = 2;
struct State {
std::thread::id tid;
OptionManager options;
......
......@@ -60,6 +60,9 @@ struct Interpreter {
virtual std::unique_ptr<Channel> create_channel() = 0;
static Interpreter& inst();
protected:
Interpreter() = default;
};
} // namespace mgb::imperative::interpreter
......@@ -151,6 +151,17 @@ public:
NonCopyableObj() = default;
};
/*!
 * \brief mixin base class that forbids moving the derived object
 *
 * Deriving from NonMoveableObj deletes the move constructor and move
 * assignment operator (which also suppresses the implicitly generated
 * copy operations); only default construction remains available.
 */
class NonMoveableObj {
public:
    NonMoveableObj() = default;

private:
    NonMoveableObj(NonMoveableObj&&) = delete;
    NonMoveableObj& operator=(NonMoveableObj&&) = delete;
};
template <typename T>
class ReverseAdaptor {
T& m_t;
......
......@@ -253,6 +253,8 @@ public:
}
}
inline bool worker_started() const { return m_synchronizer.worker_started(); }
protected:
~AsyncQueueSC() noexcept = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册