From e59b6e13a33da79d080cb45a671d14b419e9edd5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 11 Apr 2022 16:20:17 +0800 Subject: [PATCH] fix(imperative/src): fix empty_tensor bug of convbwd&rng GitOrigin-RevId: 4c948f41f04649620ce7b34c5f3dac69d66705e2 --- imperative/src/impl/ops/convolution.cpp | 4 ++-- imperative/src/impl/ops/rng.cpp | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/imperative/src/impl/ops/convolution.cpp b/imperative/src/impl/ops/convolution.cpp index a4d22c88c..3832b7a1b 100644 --- a/imperative/src/impl/ops/convolution.cpp +++ b/imperative/src/impl/ops/convolution.cpp @@ -354,8 +354,8 @@ std::tuple, bool> infer_output_attrs_fallible( TensorLayout diff = inputs[1].layout; size_t filter_ndim = filter.ndim; size_t diff_ndim = diff.ndim; - if (filter_ndim == 0) { - desc.layout = filter; + if (diff_ndim == 0) { + desc.layout = diff; return {dests, false}; } diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index a59e44f83..561ac269c 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -548,6 +548,7 @@ Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { template std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { + bool success = inputs[0].layout.ndim != 0; LogicalTensorDesc dest; auto&& xxx_rng_def = def.cast_final_safe(); size_t nr_inp = inputs.size(); @@ -558,7 +559,11 @@ std::tuple, bool> infer_output_attrs_fallible( xxx_rng_def.dyn_typeinfo()->name, nr_inp); } dest.comp_node = inputs[0].comp_node; - dest.layout = _InferLayout::do_infer(inputs[0], xxx_rng_def); + if (success) { + dest.layout = _InferLayout::do_infer(inputs[0], xxx_rng_def); + } else { + dest.layout = TensorLayout(inputs[0].layout.dtype); + } return {{dest}, inputs[0].layout.ndim != 0}; } -- GitLab