提交 462364eb 编写于 作者: M Megvii Engine Team

perf(interpreter): don't check host value if unnecessary

GitOrigin-RevId: 5306c71328d10cbb9335cdf2b5765f7dbf0023d9
上级 d23d1352
...@@ -935,13 +935,14 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { ...@@ -935,13 +935,14 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
bool require_host = prop == TensorProp::HostValue; bool require_host = prop == TensorProp::HostValue;
auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
bool wait_host = !host_available(); bool wait_host = false;
if (require_host && wait_host) { if (require_host && !host_available()) {
// avoid dead lock // avoid dead lock
lock.unlock(); lock.unlock();
m_buffer.enqueue(GetValue{info}); m_buffer.enqueue(GetValue{info});
m_buffer.flush(); m_buffer.flush();
lock.lock(); lock.lock();
wait_host = true;
} }
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
...@@ -949,7 +950,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { ...@@ -949,7 +950,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
}); });
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr; m_waitee = nullptr;
if (require_host && wait_host) { if (wait_host) {
auto err = info->ptr->comp_node().check_async_error(); auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what()); mgb_assert(!err, "%s", err->what());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册