From c8e04ce4df9ee0ee492e2db4253087b5ebadfeb2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 11 Jan 2021 17:32:11 +0800 Subject: [PATCH] fix(imperative/interpreter): use info->h_value for host compute Rather than use host_only Tensor and modify Put's behavior. GitOrigin-RevId: f890d66acb731fe1208a37f6335128d9ab169aaa --- imperative/src/impl/interpreter_impl.cpp | 38 ++++++++++++++---------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index 5a012813e..8248adf39 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -34,8 +34,13 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { info->desc.layout = value.layout(); info->desc.comp_node = value.comp_node(); info->desc.value = value.proxy_to_default_cpu(); + info->h_value = value; m_valid_handle.insert(info); m_buffer.enqueue(Put{info, value, no_cache}); + if (m_async_level == 0) { + sync(); + info->desc.comp_node.sync(); + } return info; } @@ -90,14 +95,19 @@ void ChannelImpl::dispatch_default_cpu( { MGB_LOCK_GUARD(m_mutex); for (auto&& info : input_infos) { - mgb_assert(info->ptr, "invalid tensor ptr!"); + auto input_cn = info->desc.comp_node; if (!output_cn.valid()) { - output_cn = info->ptr->comp_node(); + output_cn = input_cn; + } else { + mgb_assert(output_cn == input_cn, "cannot decide output comp node"); + } + + if (info->ptr && info->ptr->try_get_value()) { + input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu()); } else { - mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node"); + mgb_assert(!info->h_value.empty(), "inp->h_value is empty!"); + input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu()); } - mgb_assert(info->ptr->try_get_value(), "no valid host value"); - input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu()); } } @@ -116,18 +126,12 @@ void ChannelImpl::dispatch_default_cpu( SmallVector output_infos; output_infos.reserve(output_descs.size()); for (auto&& tensornd : output_tensornds) { - // tensornd -> host_tensornd HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd) .proxy_to_comp_node(output_cn); - // tensornd -> desc - LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd}; - // tensornd -> tensor - auto info = alloc(); - info->desc = desc; - m_valid_handle.insert(info); + // use `put` for consistency + auto info = reinterpret_cast(put(host_tensornd, false)); + mgb_assert(info->desc.layout.ndim != 0); output_infos.push_back(info); - info->ptr = Tensor::make(host_tensornd, true); // host_only=true - info->value_fetched = true; outputs->push_back(info); } @@ -159,6 +163,11 @@ void ChannelImpl::dispatch_kernel( for (auto&& desc : output_descs) { auto info = alloc(); info->desc = desc; + // make sure desc's value is consistent with h_value + if (!info->desc.value.empty()) { + info->h_value = HostTensorND::make_proxy(desc.value) + .proxy_to_comp_node(desc.comp_node); + } m_valid_handle.insert(info); cmd.outputs.push_back(info); outputs->push_back(info); @@ -220,7 +229,6 @@ SmallVector ChannelImpl::apply_op( break; } } - mgb_assert(outputs.size() > 0, "Invalid dispatch mode!"); return outputs; } -- GitLab