提交 7affca7e 编写于 作者: M Megvii Engine Team

fix(imperative): fix try_cast usage of TensorWrapper

GitOrigin-RevId: 5510bef70dafb54ed1dde49e32b87d5be834f07f
上级 986fc998
......@@ -239,6 +239,13 @@ class trace:
}
def _process_outputs(self, outputs):
assert (
isinstance(outputs, RawTensor)
or (
isinstance(outputs, Sequence) and not (isinstance(outputs[0], Sequence))
)
or isinstance(outputs, collections.abc.Mapping)
), "Unsupport outputs type, should be Tensor, List[Tensor] or Dict[tensor_name, Tensor]"
if isinstance(outputs, RawTensor):
outputs = [outputs]
if not isinstance(outputs, Sequence):
......
......@@ -51,10 +51,14 @@ void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list g
std::vector<ValueRef> args;
mgb_assert(tensors.size() == grads.size());
for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
auto* tw = TensorWrapper::try_cast(tensor.ptr());
mgb_assert(tw, "expect Tensor");
args.push_back(tw->m_tensor->data());
}
for (auto&& grad : grads) {
args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data());
auto* tw = TensorWrapper::try_cast(grad.ptr());
mgb_assert(tw, "expect Tensor");
args.push_back(tw->m_tensor->data());
}
imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
}
......@@ -63,7 +67,9 @@ pybind11::function GradKeyWrapper::get_backward_closure(
GradKeyWrapper* self, py::list tensors) {
std::vector<ValueRef> args;
for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
auto* tw = TensorWrapper::try_cast(tensor.ptr());
mgb_assert(tw, "expect Tensor");
args.push_back(tw->m_tensor->data());
}
auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
auto closure = closure_value.as_ref<FunctionValue>();
......
......@@ -24,7 +24,9 @@ private:
ValueRefList outputs(output_tws.size());
auto it = outputs.begin();
for (auto&& output_tw : output_tws) {
*(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data();
auto* tw = TensorWrapper::try_cast(output_tw.ptr());
mgb_assert(tw, "expect Tensor");
*(it++) = tw->m_tensor->data();
}
return outputs;
}
......
......@@ -418,7 +418,8 @@ py::object get_res_by_refhdl(
}
mgb::DType _get_dtype(py::handle tensor) {
auto tw = TensorWrapper::try_cast(tensor.ptr());
auto* tw = TensorWrapper::try_cast(tensor.ptr());
mgb_assert(tw, "expect Tensor");
return tw->m_tensor->dtype();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册