From e6a8b0256e77b741d16c344b721b36ba9d3819de Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 23 Aug 2021 21:41:39 +0800 Subject: [PATCH] fix(mge): ignore errors caused by earlier async errors GitOrigin-RevId: ce2028d38acba75cf7e31ed5eac7de38f3204b45 --- .../python/megengine/functional/utils.py | 8 ++++++ .../python/test/unit/core/test_interpreter.py | 8 ++++++ .../src/impl/interpreter/interpreter_impl.cpp | 28 +++++++++++++++---- .../include/megbrain/imperative/interpreter.h | 11 ++++++++ 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/functional/utils.py b/imperative/python/megengine/functional/utils.py index 35b14f9e7..fbab36459 100644 --- a/imperative/python/megengine/functional/utils.py +++ b/imperative/python/megengine/functional/utils.py @@ -12,6 +12,7 @@ from ..core.ops.builtin import AssertEqual from ..tensor import Tensor from ..utils.deprecation import deprecated_func from .elemwise import abs, maximum, minimum +from .tensor import ones, zeros __all__ = ["topk_accuracy"] @@ -59,6 +60,13 @@ def _assert_equal( return result +def _simulate_error(): + x1 = zeros(100) + x2 = ones(100) + (ret,) = apply(AssertEqual(maxerr=0, verbose=False), x1, x2, x2) + return ret + + topk_accuracy = deprecated_func( "1.3", "megengine.functional.metric", "topk_accuracy", True ) diff --git a/imperative/python/test/unit/core/test_interpreter.py b/imperative/python/test/unit/core/test_interpreter.py index 07db2a2a6..a78a898c5 100644 --- a/imperative/python/test/unit/core/test_interpreter.py +++ b/imperative/python/test/unit/core/test_interpreter.py @@ -90,3 +90,11 @@ with megengine.core.option("enable_host_compute", 0): y.numpy() """ subprocess.check_call([sys.executable, "-c", prog]) + + +def test_regression_2870(): + x = F.zeros(1000) + y = F.utils._simulate_error() + with pytest.raises(RuntimeError): + y.numpy() + (x + x).numpy() diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index e2507497f..85256d67f 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -373,7 +373,7 @@ SmallVector ChannelImpl::apply_op_impl( MGB_LOCK_GUARD(m_mutex); for (auto i : inputs) { auto info = reinterpret_cast(i); - mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); + mgb_assert(!info->invalid, "an input tensor is unusable due to previous error"); input_infos.push_back(info); input_descs.push_back(info->desc); } @@ -403,7 +403,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) { "invalid handle: %p", handle); auto info = reinterpret_cast(handle); // donnot use info->value_fetched, it's unsafe - mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); + mgb_assert(!info->invalid, "tensor is unusable due to previous error"); return wait_tensor(info, TensorProp::HostValue)->get_value(); } @@ -776,7 +776,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) { RECORD_EVENT(OpExecuteFinishEvent, apply_id); // End profiling operator } - + void ChannelImpl::flush_apply_stack() { m_applying = true; auto& state = get_worker_state(); @@ -1002,7 +1002,7 @@ std::tuple, SmallVector, SmallVectorid, TensorCommandFinishEvent::Put); sample_on_device(cmd.dest->desc.comp_node, false); } else if constexpr (std::is_same_v) { + for (auto& i : cmd.inputs) { + if (i->invalid) { + MGB_LOCK_GUARD(m_mutex); + for (auto& i : cmd.outputs) { + i->invalid = true; + } + return; + } + } m_apply_stack.push({cmd, 0, nullptr}); flush_apply_stack(); for (size_t i = 0; i < cmd.outputs.size(); ++i) { @@ -1085,21 +1094,23 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandFinishEvent::Del); sample_on_device(device, false); } else if constexpr (std::is_same_v) { + if (cmd.dest->invalid) return; imperative_log_profile_begin("GetValue"); if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) { regenerate(cmd.dest); } - mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!"); cmd.dest->ptr->fetch_value(); MGB_LOCK_GUARD(m_mutex); notify_tensor_unsafe(cmd.dest); imperative_log_profile_end("GetValue"); } else if constexpr (std::is_same_v) { + if (cmd.dest->invalid) return; RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapIn); produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapIn); sample_on_device(cmd.dest->desc.comp_node, false); } else if constexpr (std::is_same_v) { + if (cmd.dest->invalid) return; RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapOut); cmd.dest->h_value = cmd.dest->ptr->get_value(); if (cmd.dest->evict_type == EvictType::NONE) { @@ -1110,6 +1121,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapOut); sample_on_device(cmd.dest->desc.comp_node, false); } else if constexpr (std::is_same_v) { + if (cmd.dest->invalid) return; RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Drop); do_drop(cmd.dest, true); RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Drop); @@ -1186,7 +1198,11 @@ void ChannelImpl::check_worker_exc_unsafe() { m_waitee = nullptr; std::exception_ptr exc; std::swap(exc, m_worker_exc); - std::rethrow_exception(exc); + try { + std::rethrow_exception(exc); + } catch (...) { + throw AsyncError(); + } } } diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h index 92de64ed3..81ca9b94f 100644 --- a/imperative/src/include/megbrain/imperative/interpreter.h +++ b/imperative/src/include/megbrain/imperative/interpreter.h @@ -16,6 +16,17 @@ namespace mgb::imperative::interpreter { +struct AsyncError : std::nested_exception, std::exception { + const char* what() const noexcept { + try { + rethrow_nested(); + } catch (const std::exception& e) { + return e.what(); + } catch (...) {} + return "unkown async error"; + } +}; + struct Interpreter { using Handle = void*; -- GitLab