From 7affca7e60e677d6e96f26227b3063ee472f907b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Apr 2023 12:38:18 +0800 Subject: [PATCH] fix(imperative): fix try_cast usage of TensorWrapper GitOrigin-RevId: 5510bef70dafb54ed1dde49e32b87d5be834f07f --- imperative/python/megengine/jit/tracing.py | 7 +++++++ imperative/python/src/grad.cpp | 12 +++++++++--- imperative/python/src/module_trace.h | 4 +++- imperative/python/src/tensor_utils.cpp | 3 ++- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 2c2e3e12e..f3a88eef2 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -239,6 +239,13 @@ class trace: } def _process_outputs(self, outputs): + assert ( + isinstance(outputs, RawTensor) + or ( + isinstance(outputs, Sequence) and not (isinstance(outputs[0], Sequence)) + ) + or isinstance(outputs, collections.abc.Mapping) + ), "Unsupport outputs type, should be Tensor, List[Tensor] or Dict[tensor_name, Tensor]" if isinstance(outputs, RawTensor): outputs = [outputs] if not isinstance(outputs, Sequence): diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index a7c9a3e8d..1f6e92049 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -51,10 +51,14 @@ void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list g std::vector args; mgb_assert(tensors.size() == grads.size()); for (auto&& tensor : tensors) { - args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw, "expect Tensor"); + args.push_back(tw->m_tensor->data()); } for (auto&& grad : grads) { - args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data()); + auto* tw = TensorWrapper::try_cast(grad.ptr()); + mgb_assert(tw, "expect Tensor"); + args.push_back(tw->m_tensor->data()); } imperative::apply(GradBackward(self->m_key), {args.data(), args.size()}); } @@ -63,7 +67,9 @@ pybind11::function GradKeyWrapper::get_backward_closure( GradKeyWrapper* self, py::list tensors) { std::vector args; for (auto&& tensor : tensors) { - args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw, "expect Tensor"); + args.push_back(tw->m_tensor->data()); } auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0]; auto closure = closure_value.as_ref(); diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h index c9f195a8d..50cd013da 100644 --- a/imperative/python/src/module_trace.h +++ b/imperative/python/src/module_trace.h @@ -24,7 +24,9 @@ private: ValueRefList outputs(output_tws.size()); auto it = outputs.begin(); for (auto&& output_tw : output_tws) { - *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data(); + auto* tw = TensorWrapper::try_cast(output_tw.ptr()); + mgb_assert(tw, "expect Tensor"); + *(it++) = tw->m_tensor->data(); } return outputs; } diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 638b538ec..28f7e4ee6 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -418,7 +418,8 @@ py::object get_res_by_refhdl( } mgb::DType _get_dtype(py::handle tensor) { - auto tw = TensorWrapper::try_cast(tensor.ptr()); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw, "expect Tensor"); return tw->m_tensor->dtype(); } -- GitLab