Commit 51fa530d authored by Megvii Engine Team

fix(mge/interpreter): add check for invalid tensor ptr

GitOrigin-RevId: e8edcd92a45150439d164d9ef7fbef44805d8d46
Parent 634de590
@@ -10,6 +10,7 @@
  */
 #include "./interpreter_impl.h"
+#include "megbrain/common.h"
 using namespace mgb;
@@ -58,11 +59,14 @@ SmallVector<void*> ChannelImpl::apply_op(
     input_infos.reserve(inputs.size());
     SmallVector<LogicalTensorDesc> input_descs;
     input_descs.reserve(inputs.size());
+    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     for (auto i : inputs) {
         auto info = reinterpret_cast<TensorInfo*>(i);
+        mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
         input_infos.push_back(info);
         input_descs.push_back(info->desc);
     }
+    lock.unlock();
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};
@@ -101,6 +105,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee);
     if (!info->value_fetched) {
+        mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
         m_waitee = info;
         m_worker.add_task(GetValue{info});
         m_cv.wait(lock, [&]() {
@@ -222,6 +227,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
         SmallVector<TensorPtr> tensor_inputs;
         tensor_inputs.reserve(cmd.inputs.size());
         for (auto i : cmd.inputs) {
+            mgb_assert(i->ptr, "Invalid input tensor ptr!");
             tensor_inputs.push_back(i->ptr);
         }
         auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
@@ -232,6 +238,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
     } else if constexpr (std::is_same_v<T, Del>) {
         free(cmd.dest);
     } else if constexpr (std::is_same_v<T, GetValue>) {
+        mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
         cmd.dest->ptr->fetch_value();
         MGB_LOCK_GUARD(m_mutex);
         cmd.dest->value_fetched = true;
@@ -243,6 +250,13 @@ void ChannelImpl::process_one_task(Command& cmd) {
         }
     } catch (...) {
         MGB_LOCK_GUARD(m_mutex);
+        if constexpr (std::is_same_v<T, ApplyOp>) {
+            for (auto oup : cmd.outputs) {
+                oup->invalid = true;
+            }
+        } else if constexpr (std::is_same_v<T, Put>) {
+            cmd.dest->invalid = true;
+        }
         m_worker_exc = std::current_exception();
         m_cv.notify_all();
     }
...
@@ -28,6 +28,7 @@ struct TensorInfo {
     TensorPtr ptr;
     LogicalTensorDesc desc;
     bool value_fetched = false;
+    bool invalid = false;
 };
 struct Put {
...
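Taken together, the change works as follows: when a command on the worker thread throws, the catch block flags the affected TensorInfo entries as invalid (the outputs of a failed ApplyOp, or the destination of a failed Put), and the API-side entry points (apply_op, get_value) assert on that flag before touching the tensor, so the failure surfaces as a clear error instead of a null TensorPtr dereference later. Below is a minimal, self-contained sketch of that pattern; the names Record, Worker, produce, and read are illustrative placeholders under assumed semantics, not MegEngine API.

```cpp
// Illustrative sketch of the "mark invalid on failure, assert before use" pattern.
// Record/Worker/produce/read are hypothetical names, not MegEngine code.
#include <cassert>
#include <exception>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <vector>

struct Record {
    std::shared_ptr<int> ptr;  // stands in for TensorPtr
    bool invalid = false;      // set when producing this record failed
};

class Worker {
public:
    // Producer side (would run on the worker thread): any exception marks the
    // outputs invalid so later consumers fail fast with a clear message.
    void produce(std::vector<Record*> outputs, bool fail) {
        try {
            if (fail) throw std::runtime_error("kernel failed");
            for (auto* o : outputs) o->ptr = std::make_shared<int>(42);
        } catch (...) {
            std::lock_guard<std::mutex> lk(m_mutex);
            for (auto* o : outputs) o->invalid = true;  // cf. cmd.outputs in ApplyOp
            m_exc = std::current_exception();
        }
    }

    // Consumer side: mirrors the checks added in apply_op / get_value.
    int read(Record& r) {
        std::lock_guard<std::mutex> lk(m_mutex);
        assert(!r.invalid && "Invalid tensor, unable to read!");
        assert(r.ptr && "Invalid tensor ptr!");
        return *r.ptr;
    }

private:
    std::mutex m_mutex;
    std::exception_ptr m_exc;
};

int main() {
    Worker w;
    Record ok, bad;
    w.produce({&ok}, /*fail=*/false);
    w.produce({&bad}, /*fail=*/true);
    int v = w.read(ok);  // fine: ptr is set, invalid is false
    (void)v;
    // w.read(bad);      // would trip the "Invalid tensor" assertion
    return 0;
}
```

Flagging the outputs under the same mutex that guards the consumer-side checks is what makes the assertion reliable: a reader observes either a fully produced record or one already marked invalid, never a half-initialized one.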