提交 51fa530d 编写于 作者: M Megvii Engine Team

fix(mge/interpreter): add check for invalid tensor ptr

GitOrigin-RevId: e8edcd92a45150439d164d9ef7fbef44805d8d46
上级 634de590
......@@ -10,6 +10,7 @@
*/
#include "./interpreter_impl.h"
#include "megbrain/common.h"
using namespace mgb;
......@@ -58,11 +59,14 @@ SmallVector<void*> ChannelImpl::apply_op(
input_infos.reserve(inputs.size());
SmallVector<LogicalTensorDesc> input_descs;
input_descs.reserve(inputs.size());
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
input_infos.push_back(info);
input_descs.push_back(info->desc);
}
lock.unlock();
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
ApplyOp cmd{std::move(op)};
......@@ -101,6 +105,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
mgb_assert(!m_waitee);
if (!info->value_fetched) {
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
m_waitee = info;
m_worker.add_task(GetValue{info});
m_cv.wait(lock, [&]() {
......@@ -222,6 +227,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
SmallVector<TensorPtr> tensor_inputs;
tensor_inputs.reserve(cmd.inputs.size());
for (auto i : cmd.inputs) {
mgb_assert(i->ptr, "Invalid input tensor ptr!");
tensor_inputs.push_back(i->ptr);
}
auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs);
......@@ -232,6 +238,7 @@ void ChannelImpl::process_one_task(Command& cmd) {
} else if constexpr (std::is_same_v<T, Del>) {
free(cmd.dest);
} else if constexpr (std::is_same_v<T, GetValue>) {
mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex);
cmd.dest->value_fetched = true;
......@@ -243,6 +250,13 @@ void ChannelImpl::process_one_task(Command& cmd) {
}
} catch (...) {
MGB_LOCK_GUARD(m_mutex);
if constexpr (std::is_same_v<T, ApplyOp>) {
for (auto oup : cmd.outputs) {
oup->invalid = true;
}
} else if constexpr (std::is_same_v<T, Put>) {
cmd.dest->invalid = true;
}
m_worker_exc = std::current_exception();
m_cv.notify_all();
}
......
......@@ -28,6 +28,7 @@ struct TensorInfo {
TensorPtr ptr;
LogicalTensorDesc desc;
bool value_fetched = false;
bool invalid = false;
};
struct Put {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册