diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index 8ab5e0aa99ac64e281234dffcb2ff1736d0b5234..7be91f180acaf841580333f2a2a7792892930dc3 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -233,18 +233,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) { mgb_assert(!m_waitee); // donnot use info->value_fetched, it's unsafe mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); + std::unique_lock lock(m_mutex); TensorPtr tensor_ptr = info->ptr; auto value_fetched = [&]() { return tensor_ptr && tensor_ptr->value_fetched(); }; if (!value_fetched()) { - std::unique_lock lock(m_mutex); m_waitee = info; regenerate(info); m_buffer.enqueue(GetValue{info}); m_cv.wait(lock, [&]() { check_worker_exc_unsafe(); - // get tensor ptr in lock to ensure safety tensor_ptr = info->ptr; return value_fetched(); }); @@ -359,6 +358,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { } } +void ChannelImpl::release_tensor(TensorInfo* dest) { + MGB_LOCK_GUARD(m_mutex); + dest->ptr.reset(); +} + void ChannelImpl::regenerate(TensorInfo* dest) { if (dest->evict_type == DROP) { recompute(dest->producer); @@ -481,9 +485,9 @@ void ChannelImpl::process_one_task(Command& cmd) { produce_tensor(cmd.dest, Tensor::make(cmd.dest->h_value)); } else if constexpr (std::is_same_v) { cmd.dest->h_value = cmd.dest->ptr->get_value(); - cmd.dest->ptr.reset(); + release_tensor(cmd.dest); } else if constexpr (std::is_same_v) { - cmd.dest->ptr.reset(); + release_tensor(cmd.dest); } else if constexpr (std::is_same_v) { produce_tensor(cmd.dest, cmd.src->ptr); free(cmd.src); diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h index c9c070c8c07f42980dbd95e9a19174c67bd7a0f0..979201dd090f88ef754b65bfa22c63a5ee1db532 100644 --- a/imperative/src/impl/interpreter_impl.h +++ b/imperative/src/impl/interpreter_impl.h @@ -249,6 +249,8 @@ private: void produce_tensor(TensorInfo* dest, TensorPtr ptr); + void release_tensor(TensorInfo* dest); + void regenerate(TensorInfo* dest); void recompute(TensorInfo::ComputePath* path);