提交 c8e04ce4 编写于 作者: M Megvii Engine Team

fix(imperative/interpreter): use info->h_value for host compute

Rather than use host_only Tensor and modify Put's behavior.

GitOrigin-RevId: f890d66acb731fe1208a37f6335128d9ab169aaa
上级 a5a60679
......@@ -34,8 +34,13 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
info->desc.layout = value.layout();
info->desc.comp_node = value.comp_node();
info->desc.value = value.proxy_to_default_cpu();
info->h_value = value;
m_valid_handle.insert(info);
m_buffer.enqueue(Put{info, value, no_cache});
if (m_async_level == 0) {
sync();
info->desc.comp_node.sync();
}
return info;
}
......@@ -90,14 +95,19 @@ void ChannelImpl::dispatch_default_cpu(
{
MGB_LOCK_GUARD(m_mutex);
for (auto&& info : input_infos) {
mgb_assert(info->ptr, "invalid tensor ptr!");
auto input_cn = info->desc.comp_node;
if (!output_cn.valid()) {
output_cn = info->ptr->comp_node();
output_cn = input_cn;
} else {
mgb_assert(output_cn == input_cn, "cannot decide output comp node");
}
if (info->ptr && info->ptr->try_get_value()) {
input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
} else {
mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node");
mgb_assert(!info->h_value.empty(), "inp->h_value is empty!");
input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
}
mgb_assert(info->ptr->try_get_value(), "no valid host value");
input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
}
}
......@@ -116,18 +126,12 @@ void ChannelImpl::dispatch_default_cpu(
SmallVector<TensorInfo*> output_infos;
output_infos.reserve(output_descs.size());
for (auto&& tensornd : output_tensornds) {
// tensornd -> host_tensornd
HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
.proxy_to_comp_node(output_cn);
// tensornd -> desc
LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd};
// tensornd -> tensor
auto info = alloc();
info->desc = desc;
m_valid_handle.insert(info);
// use `put` for consistency
auto info = reinterpret_cast<TensorInfo*>(put(host_tensornd, false));
mgb_assert(info->desc.layout.ndim != 0);
output_infos.push_back(info);
info->ptr = Tensor::make(host_tensornd, true); // host_only=true
info->value_fetched = true;
outputs->push_back(info);
}
......@@ -159,6 +163,11 @@ void ChannelImpl::dispatch_kernel(
for (auto&& desc : output_descs) {
auto info = alloc();
info->desc = desc;
// make sure desc's value is consistent with h_value
if (!info->desc.value.empty()) {
info->h_value = HostTensorND::make_proxy(desc.value)
.proxy_to_comp_node(desc.comp_node);
}
m_valid_handle.insert(info);
cmd.outputs.push_back(info);
outputs->push_back(info);
......@@ -220,7 +229,6 @@ SmallVector<Handle> ChannelImpl::apply_op(
break;
}
}
mgb_assert(outputs.size() > 0, "Invalid dispatch mode!");
return outputs;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册