From 51fa530d2ab052fa6d7877685513d7983c999435 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 10 Nov 2020 15:16:16 +0800 Subject: [PATCH] fix(mge/interpreter): add check for invalid tensor ptr GitOrigin-RevId: e8edcd92a45150439d164d9ef7fbef44805d8d46 --- imperative/src/impl/interpreter_impl.cpp | 14 ++++++++++++++ imperative/src/impl/interpreter_impl.h | 1 + 2 files changed, 15 insertions(+) diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index f66428486..35c82aed6 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -10,6 +10,7 @@ */ #include "./interpreter_impl.h" +#include "megbrain/common.h" using namespace mgb; @@ -58,11 +59,14 @@ SmallVector ChannelImpl::apply_op( input_infos.reserve(inputs.size()); SmallVector input_descs; input_descs.reserve(inputs.size()); + std::unique_lock lock(m_mutex); for (auto i : inputs) { auto info = reinterpret_cast(i); + mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); input_infos.push_back(info); input_descs.push_back(info->desc); } + lock.unlock(); auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); ApplyOp cmd{std::move(op)}; @@ -101,6 +105,7 @@ HostTensorND ChannelImpl::get_value(void* handle) { std::unique_lock lock(m_mutex); mgb_assert(!m_waitee); if (!info->value_fetched) { + mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); m_waitee = info; m_worker.add_task(GetValue{info}); m_cv.wait(lock, [&]() { @@ -222,6 +227,7 @@ void ChannelImpl::process_one_task(Command& cmd) { SmallVector tensor_inputs; tensor_inputs.reserve(cmd.inputs.size()); for (auto i : cmd.inputs) { + mgb_assert(i->ptr, "Invalid input tensor ptr!"); tensor_inputs.push_back(i->ptr); } auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs); @@ -232,6 +238,7 @@ void ChannelImpl::process_one_task(Command& cmd) { } else if constexpr (std::is_same_v) { free(cmd.dest); } else if constexpr (std::is_same_v) { + mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!"); cmd.dest->ptr->fetch_value(); MGB_LOCK_GUARD(m_mutex); cmd.dest->value_fetched = true; @@ -243,6 +250,13 @@ void ChannelImpl::process_one_task(Command& cmd) { } } catch (...) { MGB_LOCK_GUARD(m_mutex); + if constexpr (std::is_same_v) { + for (auto oup : cmd.outputs) { + oup->invalid = true; + } + } else if constexpr (std::is_same_v) { + cmd.dest->invalid = true; + } m_worker_exc = std::current_exception(); m_cv.notify_all(); } diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h index 652a31ea2..508e7c463 100644 --- a/imperative/src/impl/interpreter_impl.h +++ b/imperative/src/impl/interpreter_impl.h @@ -28,6 +28,7 @@ struct TensorInfo { TensorPtr ptr; LogicalTensorDesc desc; bool value_fetched = false; + bool invalid = false; }; struct Put { -- GitLab