Commit d3247bee authored by Megvii Engine Team

fix(dtr): always write shape when tensor produced

GitOrigin-RevId: d2b23b5c25bb509456b3d77a56f20ec985efa458
Parent: 0a266d7a
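For context: the commit replaces direct reads of `desc.layout.ndim` with two lock-guarded helpers on `TensorInfo` (`shape_valid()` and `update_layout()`), so that the shape is always written when a tensor is produced and shape queries do not race with that write. Below is a minimal standalone sketch of that pattern, not MegEngine source; `LayoutStub`, `TensorInfoStub`, and the use of `std::mutex` in place of the real `Spinlock` are simplifications for illustration.

#include <cassert>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <vector>

// Simplified stand-in for TensorLayout: just a shape vector.
struct LayoutStub {
    std::vector<std::size_t> shape;
    bool empty() const { return shape.empty(); }
    bool eq_shape(const LayoutStub& rhs) const { return shape == rhs.shape; }
};

// Simplified stand-in for TensorInfo with the new lock-guarded helpers.
struct TensorInfoStub {
    std::mutex lock;   // plays the role of the Spinlock member added by this commit
    LayoutStub layout;

    // Returns true once a concrete shape has been written.
    bool shape_valid() {
        std::lock_guard<std::mutex> guard(lock);
        return !layout.empty();
    }

    // Called whenever the tensor is produced: if a shape was already
    // inferred, it must agree with the produced shape; otherwise write it.
    void update_layout(const LayoutStub& produced) {
        std::lock_guard<std::mutex> guard(lock);
        if (!layout.empty()) {
            assert(layout.eq_shape(produced) && "shape infer error");
        } else {
            layout = produced;
        }
    }
};

int main() {
    TensorInfoStub info;
    std::cout << "shape valid before produce: " << info.shape_valid() << '\n';
    info.update_layout({{2, 3}});  // producer writes the shape exactly once
    std::cout << "shape valid after produce: " << info.shape_valid() << '\n';
}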
@@ -312,7 +312,7 @@ void ChannelImpl::dispatch_default_cpu(
                 HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
         // use `put` for consistency
         auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
-        mgb_assert(info->desc.layout.ndim != 0);
+        mgb_assert(info->shape_valid());
         output_infos.push_back(info);
         outputs->push_back(reinterpret_cast<Handle>(info));
     }
@@ -406,7 +406,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
-    if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
+    if (op->same_type<GetVarShape>() && input->shape_valid()) {
         size_t ndim = input->desc.layout.ndim;
         auto& gvs = op->cast_final_safe<GetVarShape>();
         if (gvs.axis == MEGDNN_MAX_NDIM) {
@@ -477,11 +477,11 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
             m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
             handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (info->desc.layout.ndim != 0) {
+    if (info->shape_valid()) {
         return info->desc.layout;
     }
     TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
-    mgb_assert(ret.ndim != 0);
+    mgb_assert(ret.ndim > 0);
     return ret;
 }
@@ -694,12 +694,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
             TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
             ptr->raw_ptr_not_for_readwrite());
     // update tensor desc for static infer
-    if (dest->desc.layout.ndim) {
-        mgb_assert(
-                dest->desc.layout.eq_shape(ptr->layout()),
-                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
-                ptr->layout().to_string().c_str());
-    }
+    dest->update_layout(ptr->layout());
     // in order to avoid performance impact,
     // memory forwarding is disabled when DTR is enabled
     if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
......
@@ -48,6 +48,7 @@ struct TensorInfo {
     // Lock interpreter when visiting `ptr`.
     TensorPtr ptr;
     LogicalTensorDesc desc;
+    Spinlock lock;
     double compute_time;
     size_t memory;
@@ -158,6 +159,26 @@ struct TensorInfo {
     // UINT_MAX as a magic default value
     size_t cand_index = UINT_MAX;
 
+    bool shape_valid() {
+        MGB_LOCK_GUARD(lock);
+        return desc.layout.ndim;
+    }
+
+    void update_layout(const TensorLayout& layout) {
+        MGB_LOCK_GUARD(lock);
+        mgb_assert(desc.layout.dtype == layout.dtype, "dtype mismatch");
+        mgb_assert(desc.layout.format == layout.format, "format mismatch");
+        if (desc.layout.ndim) {
+            mgb_assert(
+                    desc.layout.eq_shape(layout), "shape infer error, %s vs %s",
+                    desc.layout.to_string().c_str(), layout.to_string().c_str());
+            // ignore strides
+        } else {
+            static_cast<TensorShape&>(desc.layout) = layout;
+            desc.layout.init_contiguous_stride();
+        }
+    }
 };
 }  // namespace interpreter::intl
......