提交 f70d644a 编写于 作者: M Megvii Engine Team

fix(trace): add error message for dumping func without return value

GitOrigin-RevId: 1a5885a7688378aa01fbfe530dd8bdfe31dd995d
上级 08f7a957
...@@ -9,7 +9,7 @@ import pickle ...@@ -9,7 +9,7 @@ import pickle
import re import re
import struct import struct
import sys import sys
from typing import Any from typing import Any, Sequence
import cv2 import cv2
import numpy as np import numpy as np
...@@ -241,6 +241,8 @@ class trace: ...@@ -241,6 +241,8 @@ class trace:
def _process_outputs(self, outputs): def _process_outputs(self, outputs):
if isinstance(outputs, RawTensor): if isinstance(outputs, RawTensor):
outputs = [outputs] outputs = [outputs]
if not isinstance(outputs, Sequence):
outputs = [outputs]
if isinstance(outputs, collections.abc.Mapping): if isinstance(outputs, collections.abc.Mapping):
output_names, outputs = zip(*sorted(outputs.items())) output_names, outputs = zip(*sorted(outputs.items()))
else: else:
...@@ -248,6 +250,9 @@ class trace: ...@@ -248,6 +250,9 @@ class trace:
output_names = None output_names = None
self._output_names = output_names self._output_names = output_names
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
assert isinstance(
output, RawTensor
), "Only support return tensors when capture_as_const is enabled"
name_tensor("output_{}".format(i), output) name_tensor("output_{}".format(i), output)
if self._output_bindings is None: if self._output_bindings is None:
self._output_bindings = ["output_{}".format(i) for i in range(len(outputs))] self._output_bindings = ["output_{}".format(i) for i in range(len(outputs))]
...@@ -679,6 +684,10 @@ class trace: ...@@ -679,6 +684,10 @@ class trace:
raise ValueError( raise ValueError(
"you must specify capture_as_const=True at __init__ to use dump" "you must specify capture_as_const=True at __init__ to use dump"
) )
if not hasattr(self, "_output_names"):
raise ValueError(
"the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]"
)
if self._output_names and output_names: if self._output_names and output_names:
raise TypeError( raise TypeError(
"cannot specify output_names when output is already in dict format" "cannot specify output_names when output is already in dict format"
......
...@@ -782,3 +782,24 @@ def test_invalid_inp_error(): ...@@ -782,3 +782,24 @@ def test_invalid_inp_error():
) )
else: else:
assert False assert False
def test_dump_without_output_error():
def forward(x):
return x * x
@trace(symbolic=True, capture_as_const=True)
def f(x):
y = forward(x)
data = tensor([1.0, 2.0, 3.0])
f(data)
try:
file = io.BytesIO()
f.dump(file, arg_names=["x"])
except Exception as e:
assert (
str(e)
== "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]"
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册