diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index ecf63e65479c8367374cbd4e1c241148ffc379ae..45fda4dbfe218f74da72236ce5e378a681c424cc 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -935,13 +935,14 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); bool require_host = prop == TensorProp::HostValue; auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; - bool wait_host = !host_available(); - if (require_host && wait_host) { + bool wait_host = false; + if (require_host && !host_available()) { // avoid dead lock lock.unlock(); m_buffer.enqueue(GetValue{info}); m_buffer.flush(); lock.lock(); + wait_host = true; } m_cv.wait(lock, [&]() { check_worker_exc_unsafe(); @@ -949,7 +950,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { }); MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); m_waitee = nullptr; - if (require_host && wait_host) { + if (wait_host) { auto err = info->ptr->comp_node().check_async_error(); mgb_assert(!err, "%s", err->what()); }