提交 85ea882c 编写于 作者: M Megvii Engine Team

fix(mgb/ops): immutable tensor support empty storage

GitOrigin-RevId: 2851498fce49d2c0801e1ef11cead37fbfddb974
上级 d8d5edb3
...@@ -100,7 +100,7 @@ void intl::HostIONodeBase::init_output_static_infer_desc() { ...@@ -100,7 +100,7 @@ void intl::HostIONodeBase::init_output_static_infer_desc() {
if (fill_in_static_infer(nullptr)) { if (fill_in_static_infer(nullptr)) {
auto infer_val = [this](DeviceTensorND& dest, const InpVal&) -> bool { auto infer_val = [this](DeviceTensorND& dest, const InpVal&) -> bool {
if (fill_in_static_infer(&dest) && !dest.empty()) { if (fill_in_static_infer(&dest) && dest.shape_valid()) {
return true; return true;
} }
return false; return false;
...@@ -423,8 +423,8 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND& val) { ...@@ -423,8 +423,8 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND& val) {
DeviceTensorND& ImmutableTensor::Value::static_infer() { DeviceTensorND& ImmutableTensor::Value::static_infer() {
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_mtx);
if (m_static_infer.empty()) { if (!m_static_infer.shape_valid()) {
mgb_assert(!m_dev.empty()); mgb_assert(m_dev.shape_valid());
m_static_infer.comp_node(CompNode::default_cpu()).copy_from(m_dev); m_static_infer.comp_node(CompNode::default_cpu()).copy_from(m_dev);
} }
return m_static_infer; return m_static_infer;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册