Commit d3247bee authored by: Megvii Engine Team

fix(dtr): always write shape when tensor produced

GitOrigin-RevId: d2b23b5c25bb509456b3d77a56f20ec985efa458
Parent 0a266d7a
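In short: rather than checking desc.layout.ndim directly at every call site, this commit gives each TensorInfo a spinlock and funnels every shape read through shape_valid() and every shape write through update_layout(). The layout is therefore always written when a tensor is produced: the first production stores the shape, and any later production merely asserts that static shape inference agreed. Below is a minimal standalone sketch of that pattern, with simplified stand-ins for TensorLayout and the spinlock; it is illustration only, not the commit's code, which follows in the diff.

// Sketch of the guarded-shape pattern; std::mutex stands in for MegEngine's Spinlock.
#include <cassert>
#include <cstddef>
#include <mutex>

struct TensorLayout {
    size_t ndim = 0;  // 0 means "shape not known yet"
    size_t shape[7] = {};
    bool eq_shape(const TensorLayout& rhs) const {
        if (ndim != rhs.ndim)
            return false;
        for (size_t i = 0; i < ndim; ++i)
            if (shape[i] != rhs.shape[i])
                return false;
        return true;
    }
};

struct TensorInfoSketch {
    TensorLayout layout;
    std::mutex lock;

    // Readers check shape availability only under the lock.
    bool shape_valid() {
        std::lock_guard<std::mutex> g(lock);
        return layout.ndim != 0;
    }

    // All writes go through one place: the first call stores the shape,
    // later calls only assert that shape inference agreed.
    void update_layout(const TensorLayout& produced) {
        std::lock_guard<std::mutex> g(lock);
        if (layout.ndim != 0) {
            assert(layout.eq_shape(produced));
        } else {
            layout = produced;
        }
    }
};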
@@ -312,7 +312,7 @@ void ChannelImpl::dispatch_default_cpu(
                 HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
         // use `put` for consistency
         auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
-        mgb_assert(info->desc.layout.ndim != 0);
+        mgb_assert(info->shape_valid());
         output_infos.push_back(info);
         outputs->push_back(reinterpret_cast<Handle>(info));
     }
@@ -406,7 +406,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
-    if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
+    if (op->same_type<GetVarShape>() && input->shape_valid()) {
         size_t ndim = input->desc.layout.ndim;
         auto& gvs = op->cast_final_safe<GetVarShape>();
         if (gvs.axis == MEGDNN_MAX_NDIM) {
@@ -477,11 +477,11 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
             m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
             handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (info->desc.layout.ndim != 0) {
+    if (info->shape_valid()) {
         return info->desc.layout;
     }
     TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
-    mgb_assert(ret.ndim != 0);
+    mgb_assert(ret.ndim > 0);
     return ret;
 }
@@ -694,12 +694,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
             TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
             ptr->raw_ptr_not_for_readwrite());
     // update tensor desc for static infer
-    if (dest->desc.layout.ndim) {
-        mgb_assert(
-                dest->desc.layout.eq_shape(ptr->layout()),
-                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
-                ptr->layout().to_string().c_str());
-    }
+    dest->update_layout(ptr->layout());
     // in order to avoid performance impact,
     // memory forwarding is disabled when DTR is enabled
     if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
@@ -48,6 +48,7 @@ struct TensorInfo {
     // Lock interpreter when visiting `ptr`.
     TensorPtr ptr;
     LogicalTensorDesc desc;
+    Spinlock lock;
     double compute_time;
     size_t memory;
@@ -158,6 +159,26 @@ struct TensorInfo {
     // UINT_MAX as a magic default value
     size_t cand_index = UINT_MAX;
+
+    bool shape_valid() {
+        MGB_LOCK_GUARD(lock);
+        return desc.layout.ndim;
+    }
+
+    void update_layout(const TensorLayout& layout) {
+        MGB_LOCK_GUARD(lock);
+        mgb_assert(desc.layout.dtype == layout.dtype, "dtype mismatch");
+        mgb_assert(desc.layout.format == layout.format, "format mismatch");
+        if (desc.layout.ndim) {
+            mgb_assert(
+                    desc.layout.eq_shape(layout), "shape infer error, %s vs %s",
+                    desc.layout.to_string().c_str(), layout.to_string().c_str());
+            // ignore strides
+        } else {
+            static_cast<TensorShape&>(desc.layout) = layout;
+            desc.layout.init_contiguous_stride();
+        }
+    }
 };
 } // namespace interpreter::intl
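For illustration only, a hypothetical driver for the sketch given before the diff (not part of the commit): a producer thread publishes the shape once via update_layout(), while the consumer polls shape_valid() before reading the layout, mirroring how get_shape() above takes the fast path only once the shape is known and otherwise falls back to wait_tensor().

// Hypothetical usage of TensorInfoSketch from the sketch above.
#include <thread>

int main() {
    TensorInfoSketch info;
    std::thread producer([&] {
        TensorLayout l;
        l.ndim = 2;
        l.shape[0] = 3;
        l.shape[1] = 4;
        info.update_layout(l);  // first write stores the shape
        info.update_layout(l);  // an equal repeat write only asserts
    });
    while (!info.shape_valid()) {
        std::this_thread::yield();  // spin until the shape is published
    }
    producer.join();
    return 0;
}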