diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 2c2e3e12edf64f5acdc081aec350c7293f21fabb..f3a88eef23f7a464ec2eeab72409c4136b019056 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -239,6 +239,13 @@ class trace: } def _process_outputs(self, outputs): + assert ( + isinstance(outputs, RawTensor) + or ( + isinstance(outputs, Sequence) and not (isinstance(outputs[0], Sequence)) + ) + or isinstance(outputs, collections.abc.Mapping) + ), "Unsupport outputs type, should be Tensor, List[Tensor] or Dict[tensor_name, Tensor]" if isinstance(outputs, RawTensor): outputs = [outputs] if not isinstance(outputs, Sequence): diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index a7c9a3e8df706277199cc3671b2370738678779c..1f6e920493d2b3e60e55972bc3ed3fb2be883181 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -51,10 +51,14 @@ void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list g std::vector args; mgb_assert(tensors.size() == grads.size()); for (auto&& tensor : tensors) { - args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw, "expect Tensor"); + args.push_back(tw->m_tensor->data()); } for (auto&& grad : grads) { - args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data()); + auto* tw = TensorWrapper::try_cast(grad.ptr()); + mgb_assert(tw, "expect Tensor"); + args.push_back(tw->m_tensor->data()); } imperative::apply(GradBackward(self->m_key), {args.data(), args.size()}); } @@ -63,7 +67,9 @@ pybind11::function GradKeyWrapper::get_backward_closure( GradKeyWrapper* self, py::list tensors) { std::vector args; for (auto&& tensor : tensors) { - args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw, "expect Tensor"); + args.push_back(tw->m_tensor->data()); } auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0]; auto closure = closure_value.as_ref(); diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h index c9f195a8d682562162474c2cf26e8fbe881ede89..50cd013daa5e67b72c4fbc76f45c3ca68740be51 100644 --- a/imperative/python/src/module_trace.h +++ b/imperative/python/src/module_trace.h @@ -24,7 +24,9 @@ private: ValueRefList outputs(output_tws.size()); auto it = outputs.begin(); for (auto&& output_tw : output_tws) { - *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data(); + auto* tw = TensorWrapper::try_cast(output_tw.ptr()); + mgb_assert(tw, "expect Tensor"); + *(it++) = tw->m_tensor->data(); } return outputs; } diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 638b538ecd73f38f9c3d767df9543693742f5d86..28f7e4ee6d77e52ad44103a0c3f508804aa9f5b9 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -418,7 +418,8 @@ py::object get_res_by_refhdl( } mgb::DType _get_dtype(py::handle tensor) { - auto tw = TensorWrapper::try_cast(tensor.ptr()); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw, "expect Tensor"); return tw->m_tensor->dtype(); }