From 6a2348f4f7c81cad5c33267309ba4d2b804b9781 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 28 Jan 2022 14:00:18 +0800
Subject: [PATCH] fix(trace): assume result is not scalar when shape is valid

GitOrigin-RevId: beee2d0f28620cc3410d5c4172e0413e012114fd
---
 .../python/megengine/core/tensor/indexing.py  | 14 +++++++++-----
 imperative/python/src/tensor.cpp              | 10 +++-------
 .../src/impl/transformations/scalar.cpp       | 21 ++++++++++++++-------
 3 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py
index f55136392..4dc1da3bd 100644
--- a/imperative/python/megengine/core/tensor/indexing.py
+++ b/imperative/python/megengine/core/tensor/indexing.py
@@ -119,12 +119,16 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
                 else 1
             )
     else:
-        if ndim_indexed > inp.ndim:
-            raise IndexError(
-                "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
-                    inp.ndim, len(tuple_val)
+        try:
+            if ndim_indexed > inp.ndim:
+                raise IndexError(
+                    "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
+                        inp.ndim, len(tuple_val)
+                    )
                 )
-            )
+        except ValueError:
+            # ignore
+            pass
 
     tuple_val = remove_ellipsis(inp, tuple_val)
     use_subtensor = True
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 8cede64a8..d81de0415 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -272,16 +272,12 @@ PyObject* TensorWrapper::device() {
 
 PyObject* TensorWrapper::numpy() {
     auto hv = m_tensor->numpy();
-    // if (!hv) {
-    //     PyErr_SetString(PyExc_ValueError, "tensor invalid");
-    //     return nullptr;
-    // }
-    auto arr = py::reinterpret_steal<py::object>(
-            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
-    if (!arr) {
+    if (!hv) {
         PyErr_SetString(PyExc_ValueError, "tensor invalid");
         return nullptr;
     }
+    auto arr = py::reinterpret_steal<py::object>(
+            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
     if (hv->shape().is_scalar()) {
         mgb_assert(PyArray_Check(arr.ptr()));
         return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp
index 891abdd7c..8daa5827e 100644
--- a/imperative/src/impl/transformations/scalar.cpp
+++ b/imperative/src/impl/transformations/scalar.cpp
@@ -51,6 +51,7 @@ bool is_scalar_shape(ValueRef shape) {
     if (shape.is<ScalarValue>()) {
         return false;
     }
+    // may have performance issue
     auto shape_of_shape = shape.shape();
     if (!shape_of_shape) {
         // assume not scalar
@@ -211,14 +212,21 @@ std::vector<ValueRef> subtensor_rule(
         const Subtensor& subtensor, Span<ValueRef> inputs) {
     mgb_assert(inputs.size() >= 1);
     auto input = inputs[0];
-    size_t ndim = input.is<ScalarValue>() ? 0 : input.shape()->ndim;
-    for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
-        if (idx) {
-            ndim--;
+    bool is_scalar;
+    mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
+    if (auto shape = input.shape()) {
+        size_t ndim = input.shape()->ndim;
+        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
+            if (idx) {
+                ndim--;
+            }
         }
+        is_scalar = ndim == 0;
+    } else {
+        is_scalar = false;
     }
     auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
-    if (!ndim) {
+    if (is_scalar) {
         return {ScalarValue::make(output)};
     } else {
         return {output};
@@ -261,8 +269,7 @@ std::vector<ValueRef> fastpath_copy_rule(
 
 std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
     mgb_assert(inputs.size() == 2);
-    bool is_scalar =
-            (!inputs[1].is<ScalarValue>()) && *inputs[1].shape() == ValueShape{0};
+    bool is_scalar = is_scalar_shape(inputs[1]);
     auto unwrapped_input = inputs[0].is<ScalarValue>()
                                    ? inputs[0].cast<ScalarValue>().value()
                                    : inputs[0];
--
GitLab
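
Reviewer note (not part of the patch): the indexing.py hunk matters because under
megengine.jit.trace a tensor's shape can be temporarily invalid, so reading
inp.ndim raises ValueError; the try/except lets tracing continue. A minimal
sketch of the situation, assuming the public megengine and megengine.jit APIs;
the function name first_row is illustrative only:

    import numpy as np
    import megengine
    from megengine.jit import trace

    @trace(symbolic=True)
    def first_row(x):
        # While tracing symbolically, x's shape may not be valid yet; before
        # this patch the "too many indices" check in unpack_getitem could
        # raise ValueError here instead of letting the trace proceed.
        return x[0]

    x = megengine.Tensor(np.arange(6, dtype="float32").reshape(2, 3))
    print(first_row(x).numpy())  # -> [0. 1. 2.]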
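The subtensor_rule hunk can be read the same way: only claim a scalar result
when the input's shape is actually known, otherwise assume not scalar. A
hedged pure-Python mirror of that bookkeeping (the helper result_is_scalar is
hypothetical, not part of the codebase):

    def result_is_scalar(input_ndim, num_integer_indices):
        # input_ndim is None when the shape is not yet valid (e.g. mid-trace);
        # the patch then assumes "not scalar" rather than treating ndim as 0.
        if input_ndim is None:
            return False
        # each integer index removes one axis, mirroring the `idx` branch
        # of the loop over subtensor.items
        return input_ndim - num_integer_indices == 0

    assert result_is_scalar(2, 2)         # x[i, j] -> 0-d, wrapped as scalar
    assert not result_is_scalar(2, 1)     # x[i] -> still 1-d
    assert not result_is_scalar(None, 2)  # unknown shape -> assume not scalar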