提交 f70d644a 编写于 作者: M Megvii Engine Team

fix(trace): add error message for dumping func without return value

GitOrigin-RevId: 1a5885a7688378aa01fbfe530dd8bdfe31dd995d
上级 08f7a957
......@@ -9,7 +9,7 @@ import pickle
import re
import struct
import sys
from typing import Any
from typing import Any, Sequence
import cv2
import numpy as np
......@@ -241,6 +241,8 @@ class trace:
def _process_outputs(self, outputs):
if isinstance(outputs, RawTensor):
outputs = [outputs]
if not isinstance(outputs, Sequence):
outputs = [outputs]
if isinstance(outputs, collections.abc.Mapping):
output_names, outputs = zip(*sorted(outputs.items()))
else:
......@@ -248,6 +250,9 @@ class trace:
output_names = None
self._output_names = output_names
for i, output in enumerate(outputs):
assert isinstance(
output, RawTensor
), "Only support return tensors when capture_as_const is enabled"
name_tensor("output_{}".format(i), output)
if self._output_bindings is None:
self._output_bindings = ["output_{}".format(i) for i in range(len(outputs))]
......@@ -679,6 +684,10 @@ class trace:
raise ValueError(
"you must specify capture_as_const=True at __init__ to use dump"
)
if not hasattr(self, "_output_names"):
raise ValueError(
"the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]"
)
if self._output_names and output_names:
raise TypeError(
"cannot specify output_names when output is already in dict format"
......
......@@ -782,3 +782,24 @@ def test_invalid_inp_error():
)
else:
assert False
def test_dump_without_output_error():
def forward(x):
return x * x
@trace(symbolic=True, capture_as_const=True)
def f(x):
y = forward(x)
data = tensor([1.0, 2.0, 3.0])
f(data)
try:
file = io.BytesIO()
f.dump(file, arg_names=["x"])
except Exception as e:
assert (
str(e)
== "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]"
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册