提交 d04b4bc0 编写于 作者: M Megvii Engine Team

fix(interp): thread safety for drop and swapout

GitOrigin-RevId: 7684f160bf1ca239c92c977c7238cac2b51ab4a2
上级 3eb529d6
......@@ -233,18 +233,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(!m_waitee);
// donnot use info->value_fetched, it's unsafe
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!");
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
TensorPtr tensor_ptr = info->ptr;
auto value_fetched = [&]() {
return tensor_ptr && tensor_ptr->value_fetched();
};
if (!value_fetched()) {
std::unique_lock<decltype(m_mutex)> lock(m_mutex);
m_waitee = info;
regenerate(info);
m_buffer.enqueue(GetValue{info});
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
// get tensor ptr in lock to ensure safety
tensor_ptr = info->ptr;
return value_fetched();
});
......@@ -359,6 +358,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
}
}
void ChannelImpl::release_tensor(TensorInfo* dest) {
MGB_LOCK_GUARD(m_mutex);
dest->ptr.reset();
}
void ChannelImpl::regenerate(TensorInfo* dest) {
if (dest->evict_type == DROP) {
recompute(dest->producer);
......@@ -481,9 +485,9 @@ void ChannelImpl::process_one_task(Command& cmd) {
produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value));
} else if constexpr (std::is_same_v<T, SwapOut>) {
cmd.dest->h_value = cmd.dest->ptr->get_value();
cmd.dest->ptr.reset();
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Drop>) {
cmd.dest->ptr.reset();
release_tensor(cmd.dest);
} else if constexpr (std::is_same_v<T, Move>) {
produce_tensor(cmd.dest, cmd.src->ptr);
free(cmd.src);
......
......@@ -249,6 +249,8 @@ private:
void produce_tensor(TensorInfo* dest, TensorPtr ptr);
void release_tensor(TensorInfo* dest);
void regenerate(TensorInfo* dest);
void recompute(TensorInfo::ComputePath* path);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册