diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index f6642848687987a43f5fde64f9169820e0979220..35c82aed6bb80c7784fd71f8fe93ba7c5f01cea6 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -10,6 +10,7 @@ */ #include "./interpreter_impl.h" +#include "megbrain/common.h" using namespace mgb; @@ -58,11 +59,14 @@ SmallVector ChannelImpl::apply_op( input_infos.reserve(inputs.size()); SmallVector input_descs; input_descs.reserve(inputs.size()); + std::unique_lock lock(m_mutex); for (auto i : inputs) { auto info = reinterpret_cast(i); + mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); input_infos.push_back(info); input_descs.push_back(info->desc); } + lock.unlock(); auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); ApplyOp cmd{std::move(op)}; @@ -101,6 +105,7 @@ HostTensorND ChannelImpl::get_value(void* handle) { std::unique_lock lock(m_mutex); mgb_assert(!m_waitee); if (!info->value_fetched) { + mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); m_waitee = info; m_worker.add_task(GetValue{info}); m_cv.wait(lock, [&]() { @@ -222,6 +227,7 @@ void ChannelImpl::process_one_task(Command& cmd) { SmallVector tensor_inputs; tensor_inputs.reserve(cmd.inputs.size()); for (auto i : cmd.inputs) { + mgb_assert(i->ptr, "Invalid input tensor ptr!"); tensor_inputs.push_back(i->ptr); } auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs); @@ -232,6 +238,7 @@ void ChannelImpl::process_one_task(Command& cmd) { } else if constexpr (std::is_same_v) { free(cmd.dest); } else if constexpr (std::is_same_v) { + mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!"); cmd.dest->ptr->fetch_value(); MGB_LOCK_GUARD(m_mutex); cmd.dest->value_fetched = true; @@ -243,6 +250,13 @@ void ChannelImpl::process_one_task(Command& cmd) { } } catch (...) { MGB_LOCK_GUARD(m_mutex); + if constexpr (std::is_same_v) { + for (auto oup : cmd.outputs) { + oup->invalid = true; + } + } else if constexpr (std::is_same_v) { + cmd.dest->invalid = true; + } m_worker_exc = std::current_exception(); m_cv.notify_all(); } diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h index 652a31ea27352280e255cede01ee1bcd9d749bbc..508e7c46374b293fb8d05a5a217015321fc66ffc 100644 --- a/imperative/src/impl/interpreter_impl.h +++ b/imperative/src/impl/interpreter_impl.h @@ -28,6 +28,7 @@ struct TensorInfo { TensorPtr ptr; LogicalTensorDesc desc; bool value_fetched = false; + bool invalid = false; }; struct Put {