提交 e6a8b025 编写于 作者: M Megvii Engine Team

fix(mge): ignore errors caused by earlier async errors

GitOrigin-RevId: ce2028d38acba75cf7e31ed5eac7de38f3204b45
上级 0708bc78
......@@ -12,6 +12,7 @@ from ..core.ops.builtin import AssertEqual
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from .elemwise import abs, maximum, minimum
from .tensor import ones, zeros
__all__ = ["topk_accuracy"]
......@@ -59,6 +60,13 @@ def _assert_equal(
return result
def _simulate_error():
x1 = zeros(100)
x2 = ones(100)
(ret,) = apply(AssertEqual(maxerr=0, verbose=False), x1, x2, x2)
return ret
topk_accuracy = deprecated_func(
"1.3", "megengine.functional.metric", "topk_accuracy", True
)
......
......@@ -90,3 +90,11 @@ with megengine.core.option("enable_host_compute", 0):
y.numpy()
"""
subprocess.check_call([sys.executable, "-c", prog])
def test_regression_2870():
x = F.zeros(1000)
y = F.utils._simulate_error()
with pytest.raises(RuntimeError):
y.numpy()
(x + x).numpy()
......@@ -373,7 +373,7 @@ SmallVector<Handle> ChannelImpl::apply_op_impl(
MGB_LOCK_GUARD(m_mutex);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
mgb_assert(!info->invalid, "an input tensor is unusable due to previous error");
input_infos.push_back(info);
input_descs.push_back(info->desc);
}
......@@ -403,7 +403,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
// donnot use info->value_fetched, it's unsafe
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
mgb_assert(!info->invalid, "tensor is unusable due to previous error");
return wait_tensor(info, TensorProp::HostValue)->get_value();
}
......@@ -1021,6 +1021,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put);
sample_on_device(cmd.dest->desc.comp_node, false);
} else if constexpr (std::is_same_v<T, ApplyOp>) {
for (auto& i : cmd.inputs) {
if (i->invalid) {
MGB_LOCK_GUARD(m_mutex);
for (auto& i : cmd.outputs) {
i->invalid = true;
}
return;
}
}
m_apply_stack.push({cmd, 0, nullptr});
flush_apply_stack();
for (size_t i = 0; i < cmd.outputs.size(); ++i) {
......@@ -1085,21 +1094,23 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
RECORD_EVENT(TensorCommandFinishEvent, tensor_id, TensorCommandFinishEvent::Del);
sample_on_device(device, false);
} else if constexpr (std::is_same_v<T, GetValue>) {
if (cmd.dest->invalid) return;
imperative_log_profile_begin("GetValue");
if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) {
regenerate(cmd.dest);
}
mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex);
notify_tensor_unsafe(cmd.dest);
imperative_log_profile_end("GetValue");
} else if constexpr (std::is_same_v<T, SwapIn>) {
if (cmd.dest->invalid) return;
RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapIn);
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapIn);
sample_on_device(cmd.dest->desc.comp_node, false);
} else if constexpr (std::is_same_v<T, SwapOut>) {
if (cmd.dest->invalid) return;
RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::SwapOut);
cmd.dest->h_value = cmd.dest->ptr->get_value();
if (cmd.dest->evict_type == EvictType::NONE) {
......@@ -1110,6 +1121,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::SwapOut);
sample_on_device(cmd.dest->desc.comp_node, false);
} else if constexpr (std::is_same_v<T, Drop>) {
if (cmd.dest->invalid) return;
RECORD_EVENT(TensorCommandEvent, cmd.dest->id, TensorCommandEvent::Drop);
do_drop(cmd.dest, true);
RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Drop);
......@@ -1186,7 +1198,11 @@ void ChannelImpl::check_worker_exc_unsafe() {
m_waitee = nullptr;
std::exception_ptr exc;
std::swap(exc, m_worker_exc);
try {
std::rethrow_exception(exc);
} catch (...) {
throw AsyncError();
}
}
}
......
......@@ -16,6 +16,17 @@
namespace mgb::imperative::interpreter {
struct AsyncError : std::nested_exception, std::exception {
const char* what() const noexcept {
try {
rethrow_nested();
} catch (const std::exception& e) {
return e.what();
} catch (...) {}
return "unkown async error";
}
};
struct Interpreter {
using Handle = void*;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册