From b24f84c883e09f74820216ab63aa1c76ba294c30 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Wed, 1 Sep 2021 13:17:46 +0800 Subject: [PATCH] [Dy2stat]modify dy2stat error message in compile time (#35320) * modify dy2stat error message in compile time * fix variable name --- .../fluid/dygraph/dygraph_to_static/error.py | 73 ++++++++++++++++--- .../unittests/dygraph_to_static/test_error.py | 21 +++++- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 913b7cec602..66d3b58f4c2 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -16,6 +16,7 @@ import os import six import sys import traceback +import linecache from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map @@ -29,6 +30,9 @@ DEFAULT_SIMPLIFY_NEW_ERROR = 1 DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR" DEFAULT_DISABLE_NEW_ERROR = 0 +SOURCE_CODE_RANGE = 5 +BLANK_COUNT_BEFORE_FILE_STR = 4 + def attach_error_data(error, in_runtime=False): """ @@ -40,6 +44,7 @@ def attach_error_data(error, in_runtime=False): Returns: An error attached data about original source code information and traceback. """ + e_type, e_value, e_traceback = sys.exc_info() tb = traceback.extract_tb(e_traceback)[1:] @@ -82,12 +87,49 @@ class TraceBackFrame(OriginInfo): def formated_message(self): # self.source_code may be empty in some functions. # For example, decorator generated function - return ' File "{}", line {}, in {}\n\t{}'.format( + return ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n\t{}'.format( self.location.filepath, self.location.lineno, self.function_name, self.source_code.lstrip() if isinstance(self.source_code, str) else self.source_code) +class TraceBackFrameRange(OriginInfo): + """ + Traceback frame information. + """ + + def __init__(self, location, function_name): + self.location = location + self.function_name = function_name + self.source_code = [] + blank_count = [] + begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2)) + + for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE): + line = linecache.getline(self.location.filepath, i) + line_lstrip = line.strip() + self.source_code.append(line_lstrip) + blank_count.append(len(line) - len(line_lstrip)) + + if i == self.location.lineno: + hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE' + self.source_code.append(hint_msg) + blank_count.append(blank_count[-1]) + linecache.clearcache() + + min_black_count = min(blank_count) + for i in range(len(self.source_code)): + self.source_code[i] = ' ' * (blank_count[i] - min_black_count + + BLANK_COUNT_BEFORE_FILE_STR * 2 + ) + self.source_code[i] + + def formated_message(self): + msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format( + self.location.filepath, self.location.lineno, self.function_name) + # add empty line after range code + return msg + '\n'.join(self.source_code) + '\n' + + class ErrorData(object): """ Error data attached to an exception which is raised in un-transformed code. @@ -128,26 +170,34 @@ class ErrorData(object): return '\n'.join(message_lines) # Step2: Optimizes stack information with source code information of dygraph from user. - for filepath, lineno, funcname, code in self.origin_traceback: + whether_source_range = True + for filepath, lineno, funcname, code in self.origin_traceback[::-1]: loc = Location(filepath, lineno) - dygraph_func_info = self.origin_info_map.get(loc.line_location, None) if dygraph_func_info: - # TODO(liym27): more information to prompt users that this is the original information. - # Replaces trace stack information about transformed static code with original dygraph code. - traceback_frame = self.origin_info_map[loc.line_location] - else: - traceback_frame = TraceBackFrame(loc, funcname, code) - - message_lines.append(traceback_frame.formated_message()) + if whether_source_range: + traceback_frame = TraceBackFrameRange( + dygraph_func_info.location, + dygraph_func_info.function_name) + whether_source_range = False + else: + traceback_frame = TraceBackFrame( + dygraph_func_info.location, + dygraph_func_info.function_name, + dygraph_func_info.source_code) + # Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2 + message_lines.insert(2, traceback_frame.formated_message()) # Step3: Adds error message like "TypeError: dtype must be int32, but received float32". # NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length # is gather than 1, for example, the error_type is IndentationError. format_exception = traceback.format_exception_only(self.error_type, self.error_value) - error_message = [" " * 4 + line for line in format_exception] + error_message = [ + " " * BLANK_COUNT_BEFORE_FILE_STR + line + for line in format_exception + ] message_lines.extend(error_message) return '\n'.join(message_lines) @@ -175,7 +225,6 @@ class ErrorData(object): self.error_value = self.error_type(error_value_str) def raise_new_exception(self): - # Raises the origin error if disable dygraph2static error module, if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)): raise diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index c177b556b86..aafb0287099 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -218,7 +218,10 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase): ['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath), 'inner_func()', 'File "{}", line 28, in inner_func'.format(self.filepath), + 'def inner_func():', 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + '<--- HERE', + 'return', ] def set_func_call(self): @@ -242,7 +245,11 @@ class TestErrorStaticLayerCallInCompiletime_2( self.expected_message = \ [ 'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath), - 'x = fluid.layers.reshape(x, shape=[1, 2])' + 'def func_error_in_compile_time_2(x):', + 'x = fluid.dygraph.to_variable(x)', + 'x = fluid.layers.reshape(x, shape=[1, 2])', + '<--- HERE', + 'return x' ] @@ -261,7 +268,10 @@ class TestErrorStaticLayerCallInCompiletime_3( def set_message(self): self.expected_message = \ ['File "{}", line 91, in forward'.format(self.filepath), + '@paddle.jit.to_static', + 'def forward(self):', 'self.test_func()', + '<--- HERE' ] def set_func_call(self): @@ -318,7 +328,12 @@ class TestJitSaveInCompiletime(TestErrorBase): def set_message(self): self.expected_message = \ ['File "{}", line 80, in forward'.format(self.filepath), - 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + 'def forward(self, x):', + 'y = self._linear(x)', + 'z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + '<--- HERE', + 'out = fluid.layers.mean(y[z])', + 'return out' ] def set_func_call(self): @@ -329,7 +344,7 @@ class TestJitSaveInCompiletime(TestErrorBase): self._test_raise_new_exception() -# Situation 4: NotImplementedError +# # Situation 4: NotImplementedError class TestErrorInOther(unittest.TestCase): def test(self): paddle.disable_static() -- GitLab