From 13ef444fa63106c31e43fe488b05c8303c356f6d Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Wed, 27 Jan 2021 15:15:58 +0800 Subject: [PATCH] [Dy2Stat] Fix error message when the message has more than one lines. (#30714) --- .../fluid/dygraph/dygraph_to_static/error.py | 9 ++-- .../unittests/dygraph_to_static/test_error.py | 42 +++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index a994fbb107..913b7cec60 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -143,9 +143,12 @@ class ErrorData(object): message_lines.append(traceback_frame.formated_message()) # Step3: Adds error message like "TypeError: dtype must be int32, but received float32". - error_message = " " * 4 + traceback.format_exception_only( - self.error_type, self.error_value)[0].strip("\n") - message_lines.append(error_message) + # NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length + # is gather than 1, for example, the error_type is IndentationError. + format_exception = traceback.format_exception_only(self.error_type, + self.error_value) + error_message = [" " * 4 + line for line in format_exception] + message_lines.extend(error_message) return '\n'.join(message_lines) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index 3c43cbc518..c177b556b8 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -82,6 +82,22 @@ class LayerErrorInCompiletime(fluid.dygraph.Layer): return out +class LayerErrorInCompiletime2(fluid.dygraph.Layer): + def __init__(self): + super(LayerErrorInCompiletime2, self).__init__() + + @paddle.jit.to_static + def forward(self): + self.test_func() + + def test_func(self): + """ + NOTE: The next line has a tab. And this test to check the IndentationError when spaces and tabs are mixed. + A tab here. + """ + return + + class TestFlags(unittest.TestCase): def setUp(self): self.reset_flags_to_default() @@ -230,6 +246,32 @@ class TestErrorStaticLayerCallInCompiletime_2( ] +class TestErrorStaticLayerCallInCompiletime_3( + TestErrorStaticLayerCallInCompiletime): + def setUp(self): + self.reset_flags_to_default() + self.set_func_call() + self.filepath = inspect.getfile(unwrap(self.func_call)) + self.set_exception_type() + self.set_message() + + def set_exception_type(self): + self.exception_type = IndentationError + + def set_message(self): + self.expected_message = \ + ['File "{}", line 91, in forward'.format(self.filepath), + 'self.test_func()', + ] + + def set_func_call(self): + layer = LayerErrorInCompiletime2() + self.func_call = lambda: layer() + + def test_error(self): + self._test_raise_new_exception() + + class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime): def set_func(self): self.func = func_error_in_runtime -- GitLab