diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 273961e27efba2b7dfe4c2cb942829fa89c2d8ff..008070fcead5df5f305fcd13e43718ac6ec53ea3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -122,7 +122,7 @@ class TraceBackFrameRange(OriginInfo): msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format( self.location.filepath, self.location.lineno, self.function_name) # add empty line after range code - return msg + '\n'.join(self.source_code) + '\n' + return msg + '\n'.join(self.source_code) class SuggestionDict(object): @@ -183,24 +183,39 @@ class ErrorData(object): return '\n'.join(message_lines) # Step2: Optimizes stack information with source code information of dygraph from user. - whether_source_range = True - for filepath, lineno, funcname, code in self.origin_traceback[::-1]: - loc = Location(filepath, lineno) - dygraph_func_info = self.origin_info_map.get(loc.line_location, + user_code_traceback_index = [] + for i, (filepath, lineno, funcname, + code) in enumerate(self.origin_traceback): + dygraph_func_info = self.origin_info_map.get((filepath, lineno), None) if dygraph_func_info: - if whether_source_range: - traceback_frame = TraceBackFrameRange( - dygraph_func_info.location, - dygraph_func_info.function_name) - whether_source_range = False - else: - traceback_frame = TraceBackFrame( - dygraph_func_info.location, - dygraph_func_info.function_name, - dygraph_func_info.source_code) - # Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2 - message_lines.insert(2, traceback_frame.formated_message()) + user_code_traceback_index.append(i) + + # Add user code traceback + for i in user_code_traceback_index: + filepath, lineno, funcname, code = self.origin_traceback[i] + dygraph_func_info = self.origin_info_map.get((filepath, lineno), + None) + if i == user_code_traceback_index[-1]: + traceback_frame = TraceBackFrameRange( + dygraph_func_info.location, dygraph_func_info.function_name) + else: + traceback_frame = TraceBackFrame( + dygraph_func_info.location, dygraph_func_info.function_name, + dygraph_func_info.source_code) + + message_lines.append(traceback_frame.formated_message()) + message_lines.append("") + + # Add paddle traceback after user code traceback + paddle_traceback_start_idnex = user_code_traceback_index[ + -1] + 1 if user_code_traceback_index else 0 + for filepath, lineno, funcname, code in self.origin_traceback[ + paddle_traceback_start_idnex:]: + traceback_frame = TraceBackFrame( + Location(filepath, lineno), funcname, code) + message_lines.append(traceback_frame.formated_message()) + message_lines.append("") # Step3: Adds error message like "TypeError: dtype must be int32, but received float32". # NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length @@ -258,8 +273,9 @@ class ErrorData(object): bottom_error_message = error_value_lines[empty_line_idx + 1:] revise_suggestion = self._create_revise_suggestion(bottom_error_message) - filepath = '' - error_from_user_code = [] + user_filepath = '' + error_traceback = [] + user_code_traceback_index = [] pattern = 'File "(?P.+)", line (?P.+), in (?P.+)' for i in range(0, len(error_value_lines_strip), 2): if error_value_lines_strip[i].startswith("File "): @@ -268,22 +284,35 @@ class ErrorData(object): code = error_value_lines_strip[i + 1] if i + 1 < len( error_value_lines_strip) else '' if i == 0: - filepath = tmp_filepath - if tmp_filepath == filepath: - error_from_user_code.append( - (tmp_filepath, int(lineno_str), function_name, code)) + user_filepath = tmp_filepath + if tmp_filepath == user_filepath: + user_code_traceback_index.append(len(error_traceback)) + + error_traceback.append( + (tmp_filepath, int(lineno_str), function_name, code)) error_frame = [] - whether_source_range = True - for filepath, lineno, funcname, code in error_from_user_code[::-1]: - loc = Location(filepath, lineno) - if whether_source_range: - traceback_frame = TraceBackFrameRange(loc, funcname) - whether_source_range = False + # Add user code traceback + for i in user_code_traceback_index: + filepath, lineno, funcname, code = error_traceback[i] + if i == user_code_traceback_index[-1]: + traceback_frame = TraceBackFrameRange( + Location(filepath, lineno), funcname) else: - traceback_frame = TraceBackFrame(loc, funcname, code) - - error_frame.insert(0, traceback_frame.formated_message()) + traceback_frame = TraceBackFrame( + Location(filepath, lineno), funcname, code) + error_frame.append(traceback_frame.formated_message()) + error_frame.append("") + + # Add paddle traceback after user code traceback + paddle_traceback_start_idnex = user_code_traceback_index[ + -1] + 1 if user_code_traceback_index else 0 + for filepath, lineno, funcname, code in error_traceback[ + paddle_traceback_start_idnex:]: + traceback_frame = TraceBackFrame( + Location(filepath, lineno), funcname, code) + error_frame.append(traceback_frame.formated_message()) + error_frame.append("") error_frame.extend(bottom_error_message) error_frame.extend(revise_suggestion)