From 2d72de8a87bbe0411464fe709092cd2d123b1895 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 30 Mar 2022 12:43:42 +0800 Subject: [PATCH] fix(mge): fix infer output attrs fallible GitOrigin-RevId: ea18c7f753aa1b48fd7ff66f7035c70b4745164d --- imperative/python/megengine/functional/nn.py | 2 +- imperative/src/impl/interpreter/interpreter_impl.cpp | 1 - imperative/src/impl/ops/reduce.cpp | 6 ++++++ imperative/src/impl/ops/tensor_manip.cpp | 3 +++ 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 4fbff5b40..afe30ccff 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1509,7 +1509,7 @@ def sync_batch_norm( """ _eps_mode = eps_mode.lower() assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode) - if _eps_mode == "additive" and not (is_distributed() or training): + if _eps_mode == "additive" and not (is_distributed() and training): return batch_norm( inp, running_mean, diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index e3d5d0595..a2c3b54c3 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -717,7 +717,6 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) { ptr->to_contiguous_inplace(); } - dest->desc.layout = ptr->layout(); dest->desc.comp_node = ptr->comp_node(); dest->memory = ptr->blob()->size(); dest->ptr = std::move(ptr); diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp index 5000fea6c..ffd2400e0 100644 --- a/imperative/src/impl/ops/reduce.cpp +++ b/imperative/src/impl/ops/reduce.cpp @@ -205,6 +205,12 @@ std::tuple, bool> infer_output_attrs_fallible( size_t size = inputs.size(); SmallVector dests(size); + for (size_t i = 0; i < size; i++) { + if (inputs[i].layout.ndim == 0) { + return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}}, + false}; + } + } if (size > 1) { auto [output_descs, validated] = proxy_graph_detail::infer_output_attrs_fallible(def, inputs); diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index c24b06b75..7f06da31b 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -115,6 +115,9 @@ std::tuple, bool> infer_output_attrs_fallible( TensorShapeArray src(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { src[i] = inputs[i].layout; + if (!src[i].ndim) { + return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false}; + } } megdnn::Elemwise::deduce_shape(src, shp); } -- GitLab