diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 28d8eeb7729b8b2c530c0b25504f0447f45fd697..2c2e3e12edf64f5acdc081aec350c7293f21fabb 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -9,7 +9,7 @@ import pickle import re import struct import sys -from typing import Any +from typing import Any, Sequence import cv2 import numpy as np @@ -241,6 +241,8 @@ class trace: def _process_outputs(self, outputs): if isinstance(outputs, RawTensor): outputs = [outputs] + if not isinstance(outputs, Sequence): + outputs = [outputs] if isinstance(outputs, collections.abc.Mapping): output_names, outputs = zip(*sorted(outputs.items())) else: @@ -248,6 +250,9 @@ class trace: output_names = None self._output_names = output_names for i, output in enumerate(outputs): + assert isinstance( + output, RawTensor + ), "Only support return tensors when capture_as_const is enabled" name_tensor("output_{}".format(i), output) if self._output_bindings is None: self._output_bindings = ["output_{}".format(i) for i in range(len(outputs))] @@ -679,6 +684,10 @@ class trace: raise ValueError( "you must specify capture_as_const=True at __init__ to use dump" ) + if not hasattr(self, "_output_names"): + raise ValueError( + "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]" + ) if self._output_names and output_names: raise TypeError( "cannot specify output_names when output is already in dict format" diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index dc0d3f5f312c7f933329f829c3bb0391da42e492..f0eb376db2eb4309a85d542759886ad1fc0951f8 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -782,3 +782,24 @@ def test_invalid_inp_error(): ) else: assert False + + +def test_dump_without_output_error(): + def forward(x): + return x * x + + @trace(symbolic=True, capture_as_const=True) + def f(x): + y = forward(x) + + data = tensor([1.0, 2.0, 3.0]) + + f(data) + try: + file = io.BytesIO() + f.dump(file, arg_names=["x"]) + except Exception as e: + assert ( + str(e) + == "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]" + )