Commit c8e04ce4 authored by Megvii Engine Team

fix(imperative/interpreter): use info->h_value for host compute

Rather than using a host_only Tensor and modifying Put's behavior.

GitOrigin-RevId: f890d66acb731fe1208a37f6335128d9ab169aaa
Parent a5a60679
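For context, here is a minimal, self-contained C++ sketch of the pattern this commit moves to: put() records a host copy in info->h_value, and the host-compute path falls back to that copy when the device value is not yet materialized. All types below (HostValue, DeviceValue, TensorInfo) and the helper gather_host_input are simplified stand-ins for illustration only, not MegEngine's real HostTensorND/Tensor/TensorInfo classes or ChannelImpl API.

```cpp
// Simplified sketch of the h_value fallback pattern (assumed stand-in types,
// not MegEngine's actual classes).
#include <cassert>
#include <memory>
#include <vector>

struct HostValue {
    std::vector<float> data;            // host-side copy of the tensor contents
    bool empty() const { return data.empty(); }
};

struct DeviceValue {
    std::shared_ptr<HostValue> cached;  // non-null once a host copy is ready
    const HostValue* try_get_value() const { return cached.get(); }
};

struct TensorInfo {
    std::shared_ptr<DeviceValue> ptr;   // device result, may still be in flight
    HostValue h_value;                  // host copy recorded eagerly at put() time
};

// put(): besides enqueueing the asynchronous Put command (omitted here),
// record the host value on the TensorInfo so host compute can read it later.
TensorInfo* put(const HostValue& value) {
    auto* info = new TensorInfo{};
    info->h_value = value;              // the key change: keep a host copy around
    // m_buffer.enqueue(Put{info, value, no_cache});  // async path, omitted
    return info;
}

// dispatch_default_cpu()-style input gathering: prefer a device value that is
// already materialized, otherwise fall back to the recorded h_value instead of
// requiring put() to have created a host-only Tensor.
HostValue gather_host_input(const TensorInfo* info) {
    if (info->ptr && info->ptr->try_get_value()) {
        return *info->ptr->try_get_value();
    }
    assert(!info->h_value.empty() && "inp->h_value is empty!");
    return info->h_value;
}

int main() {
    HostValue v{{1.f, 2.f, 3.f}};
    TensorInfo* info = put(v);                  // device-side result not produced yet
    HostValue host = gather_host_input(info);   // falls back to info->h_value
    assert(host.data.size() == 3);
    delete info;
    return 0;
}
```

The point of the fallback is that the host-compute path no longer depends on put() constructing a host-only Tensor; it only needs the host copy recorded on the TensorInfo.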
@@ -34,8 +34,13 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     info->desc.layout = value.layout();
     info->desc.comp_node = value.comp_node();
     info->desc.value = value.proxy_to_default_cpu();
+    info->h_value = value;
     m_valid_handle.insert(info);
     m_buffer.enqueue(Put{info, value, no_cache});
+    if (m_async_level == 0) {
+        sync();
+        info->desc.comp_node.sync();
+    }
     return info;
 }
@@ -90,14 +95,19 @@ void ChannelImpl::dispatch_default_cpu(
     {
         MGB_LOCK_GUARD(m_mutex);
         for (auto&& info : input_infos) {
-            mgb_assert(info->ptr, "invalid tensor ptr!");
+            auto input_cn = info->desc.comp_node;
             if (!output_cn.valid()) {
-                output_cn = info->ptr->comp_node();
+                output_cn = input_cn;
+            } else {
+                mgb_assert(output_cn == input_cn, "cannot decide output comp node");
+            }
+            if (info->ptr && info->ptr->try_get_value()) {
+                input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
             } else {
-                mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node");
+                mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
+                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
             }
-            mgb_assert(info->ptr->try_get_value(), "no valid host value");
-            input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
         }
     }
@@ -116,18 +126,12 @@ void ChannelImpl::dispatch_default_cpu(
     SmallVector<TensorInfo*> output_infos;
     output_infos.reserve(output_descs.size());
     for (auto&& tensornd : output_tensornds) {
-        // tensornd -> host_tensornd
         HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
             .proxy_to_comp_node(output_cn);
-        // tensornd -> desc
-        LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd};
-        // tensornd -> tensor
-        auto info = alloc();
-        info->desc = desc;
-        m_valid_handle.insert(info);
+        // use `put` for consistency
+        auto info = reinterpret_cast<TensorInfo*>(put(host_tensornd, false));
+        mgb_assert(info->desc.layout.ndim != 0);
         output_infos.push_back(info);
-        info->ptr = Tensor::make(host_tensornd, true); // host_only=true
-        info->value_fetched = true;
         outputs->push_back(info);
     }
@@ -159,6 +163,11 @@ void ChannelImpl::dispatch_kernel(
     for (auto&& desc : output_descs) {
         auto info = alloc();
         info->desc = desc;
+        // make sure desc's value is consistent with h_value
+        if (!info->desc.value.empty()) {
+            info->h_value = HostTensorND::make_proxy(desc.value)
+                .proxy_to_comp_node(desc.comp_node);
+        }
         m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
         outputs->push_back(info);
@@ -220,7 +229,6 @@ SmallVector<Handle> ChannelImpl::apply_op(
             break;
         }
     }
-    mgb_assert(outputs.size() > 0, "Invalid dispatch mode!");
     return outputs;
 }