From e02d80338b57cc0a5cff59683090b2334d0fe391 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 31 Mar 2023 11:15:19 +0800 Subject: [PATCH] fix(trace): add error message for dumping func without return value GitOrigin-RevId: 1a5885a7688378aa01fbfe530dd8bdfe31dd995d --- imperative/python/megengine/jit/tracing.py | 11 +++++++++- .../python/test/unit/jit/test_tracing.py | 21 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 28d8eeb77..2c2e3e12e 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -9,7 +9,7 @@ import pickle import re import struct import sys -from typing import Any +from typing import Any, Sequence import cv2 import numpy as np @@ -241,6 +241,8 @@ class trace: def _process_outputs(self, outputs): if isinstance(outputs, RawTensor): outputs = [outputs] + if not isinstance(outputs, Sequence): + outputs = [outputs] if isinstance(outputs, collections.abc.Mapping): output_names, outputs = zip(*sorted(outputs.items())) else: @@ -248,6 +250,9 @@ class trace: output_names = None self._output_names = output_names for i, output in enumerate(outputs): + assert isinstance( + output, RawTensor + ), "Only support return tensors when capture_as_const is enabled" name_tensor("output_{}".format(i), output) if self._output_bindings is None: self._output_bindings = ["output_{}".format(i) for i in range(len(outputs))] @@ -679,6 +684,10 @@ class trace: raise ValueError( "you must specify capture_as_const=True at __init__ to use dump" ) + if not hasattr(self, "_output_names"): + raise ValueError( + "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]" + ) if self._output_names and output_names: raise TypeError( "cannot specify output_names when output is already in dict format" diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index dc0d3f5f3..f0eb376db 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -782,3 +782,24 @@ def test_invalid_inp_error(): ) else: assert False + + +def test_dump_without_output_error(): + def forward(x): + return x * x + + @trace(symbolic=True, capture_as_const=True) + def f(x): + y = forward(x) + + data = tensor([1.0, 2.0, 3.0]) + + f(data) + try: + file = io.BytesIO() + f.dump(file, arg_names=["x"]) + except Exception as e: + assert ( + str(e) + == "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]" + ) -- GitLab