提交 7affca7e 编写于 作者: M Megvii Engine Team

fix(imperative): fix try_cast usage of TensorWrapper

GitOrigin-RevId: 5510bef70dafb54ed1dde49e32b87d5be834f07f
上级 986fc998
...@@ -239,6 +239,13 @@ class trace: ...@@ -239,6 +239,13 @@ class trace:
} }
def _process_outputs(self, outputs): def _process_outputs(self, outputs):
assert (
isinstance(outputs, RawTensor)
or (
isinstance(outputs, Sequence) and not (isinstance(outputs[0], Sequence))
)
or isinstance(outputs, collections.abc.Mapping)
), "Unsupport outputs type, should be Tensor, List[Tensor] or Dict[tensor_name, Tensor]"
if isinstance(outputs, RawTensor): if isinstance(outputs, RawTensor):
outputs = [outputs] outputs = [outputs]
if not isinstance(outputs, Sequence): if not isinstance(outputs, Sequence):
......
...@@ -51,10 +51,14 @@ void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list g ...@@ -51,10 +51,14 @@ void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list g
std::vector<ValueRef> args; std::vector<ValueRef> args;
mgb_assert(tensors.size() == grads.size()); mgb_assert(tensors.size() == grads.size());
for (auto&& tensor : tensors) { for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); auto* tw = TensorWrapper::try_cast(tensor.ptr());
mgb_assert(tw, "expect Tensor");
args.push_back(tw->m_tensor->data());
} }
for (auto&& grad : grads) { for (auto&& grad : grads) {
args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data()); auto* tw = TensorWrapper::try_cast(grad.ptr());
mgb_assert(tw, "expect Tensor");
args.push_back(tw->m_tensor->data());
} }
imperative::apply(GradBackward(self->m_key), {args.data(), args.size()}); imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
} }
...@@ -63,7 +67,9 @@ pybind11::function GradKeyWrapper::get_backward_closure( ...@@ -63,7 +67,9 @@ pybind11::function GradKeyWrapper::get_backward_closure(
GradKeyWrapper* self, py::list tensors) { GradKeyWrapper* self, py::list tensors) {
std::vector<ValueRef> args; std::vector<ValueRef> args;
for (auto&& tensor : tensors) { for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); auto* tw = TensorWrapper::try_cast(tensor.ptr());
mgb_assert(tw, "expect Tensor");
args.push_back(tw->m_tensor->data());
} }
auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0]; auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
auto closure = closure_value.as_ref<FunctionValue>(); auto closure = closure_value.as_ref<FunctionValue>();
......
...@@ -24,7 +24,9 @@ private: ...@@ -24,7 +24,9 @@ private:
ValueRefList outputs(output_tws.size()); ValueRefList outputs(output_tws.size());
auto it = outputs.begin(); auto it = outputs.begin();
for (auto&& output_tw : output_tws) { for (auto&& output_tw : output_tws) {
*(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data(); auto* tw = TensorWrapper::try_cast(output_tw.ptr());
mgb_assert(tw, "expect Tensor");
*(it++) = tw->m_tensor->data();
} }
return outputs; return outputs;
} }
......
...@@ -418,7 +418,8 @@ py::object get_res_by_refhdl( ...@@ -418,7 +418,8 @@ py::object get_res_by_refhdl(
} }
// Return the dtype of a Python-level Tensor.
//
// @param tensor  py::handle expected to wrap a TensorWrapper instance.
// @return        the mgb::DType of the underlying tensor.
//
// try_cast returns nullptr when `tensor` is not a TensorWrapper, so we
// assert before dereferencing instead of crashing on a null pointer
// (this is the fix this commit applies across all try_cast call sites).
mgb::DType _get_dtype(py::handle tensor) {
    auto* tw = TensorWrapper::try_cast(tensor.ptr());
    mgb_assert(tw, "expect Tensor");
    return tw->m_tensor->dtype();
}
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册