未验证 提交 a6cc567f 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2Stat]Modify dy2stat error message in runtime and format error message (#35365)

上级 ef7bc367
...@@ -17,6 +17,7 @@ import six ...@@ -17,6 +17,7 @@ import six
import sys import sys
import traceback import traceback
import linecache import linecache
import re
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
...@@ -106,22 +107,34 @@ class TraceBackFrameRange(OriginInfo): ...@@ -106,22 +107,34 @@ class TraceBackFrameRange(OriginInfo):
begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2)) begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))
for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE): for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE):
line = linecache.getline(self.location.filepath, i) line = linecache.getline(self.location.filepath, i).rstrip('\n')
line_lstrip = line.strip() line_lstrip = line.lstrip()
self.source_code.append(line_lstrip) self.source_code.append(line_lstrip)
blank_count.append(len(line) - len(line_lstrip)) if not line_lstrip: # empty line from source code
blank_count.append(-1)
else:
blank_count.append(len(line) - len(line_lstrip))
if i == self.location.lineno: if i == self.location.lineno:
hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE' hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
self.source_code.append(hint_msg) self.source_code.append(hint_msg)
blank_count.append(blank_count[-1]) blank_count.append(blank_count[-1])
linecache.clearcache() linecache.clearcache()
# remove top and bottom empty line in source code
min_black_count = min(blank_count) while len(self.source_code) > 0 and not self.source_code[0]:
self.source_code.pop(0)
blank_count.pop(0)
while len(self.source_code) > 0 and not self.source_code[-1]:
self.source_code.pop(-1)
blank_count.pop(-1)
min_black_count = min([i for i in blank_count if i >= 0])
for i in range(len(self.source_code)): for i in range(len(self.source_code)):
self.source_code[i] = ' ' * (blank_count[i] - min_black_count + # if source_code[i] is empty line between two code line, dont add blank
BLANK_COUNT_BEFORE_FILE_STR * 2 if self.source_code[i]:
) + self.source_code[i] self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]
def formated_message(self): def formated_message(self):
msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format( msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format(
...@@ -212,6 +225,7 @@ class ErrorData(object): ...@@ -212,6 +225,7 @@ class ErrorData(object):
1. Need a more robust way because the code of start_trace may change. 1. Need a more robust way because the code of start_trace may change.
2. Set the switch to determine whether to simplify error_value 2. Set the switch to determine whether to simplify error_value
""" """
assert self.in_runtime is True assert self.in_runtime is True
error_value_lines = str(self.error_value).split("\n") error_value_lines = str(self.error_value).split("\n")
...@@ -219,9 +233,43 @@ class ErrorData(object): ...@@ -219,9 +233,43 @@ class ErrorData(object):
start_trace = "outputs = static_func(*inputs)" start_trace = "outputs = static_func(*inputs)"
start_idx = error_value_lines_strip.index(start_trace) start_idx = error_value_lines_strip.index(start_trace)
error_value_lines = error_value_lines[start_idx + 1:] error_value_lines = error_value_lines[start_idx + 1:]
error_value_lines_strip = error_value_lines_strip[start_idx + 1:]
# use empty line to locate the bottom_error_message
empty_line_idx = error_value_lines_strip.index('')
bottom_error_message = error_value_lines[empty_line_idx + 1:]
filepath = ''
error_from_user_code = []
pattern = 'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)'
for i in range(0, len(error_value_lines_strip), 2):
if error_value_lines_strip[i].startswith("File "):
re_result = re.search(pattern, error_value_lines_strip[i])
tmp_filepath, lineno_str, function_name = re_result.groups()
code = error_value_lines_strip[i + 1] if i + 1 < len(
error_value_lines_strip) else ''
if i == 0:
filepath = tmp_filepath
if tmp_filepath == filepath:
error_from_user_code.append(
(tmp_filepath, int(lineno_str), function_name, code))
error_frame = []
whether_source_range = True
for filepath, lineno, funcname, code in error_from_user_code[::-1]:
loc = Location(filepath, lineno)
if whether_source_range:
traceback_frame = TraceBackFrameRange(loc, funcname)
whether_source_range = False
else:
traceback_frame = TraceBackFrame(loc, funcname, code)
error_frame.insert(0, traceback_frame.formated_message())
error_value_str = '\n'.join(error_value_lines) error_frame.extend(bottom_error_message)
error_value_str = '\n'.join(error_frame)
self.error_value = self.error_type(error_value_str) self.error_value = self.error_type(error_value_str)
def raise_new_exception(self): def raise_new_exception(self):
......
...@@ -98,6 +98,16 @@ class LayerErrorInCompiletime2(fluid.dygraph.Layer): ...@@ -98,6 +98,16 @@ class LayerErrorInCompiletime2(fluid.dygraph.Layer):
return return
@paddle.jit.to_static
def func_error_in_runtime_with_empty_line(x):
x = fluid.dygraph.to_variable(x)
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
x = fluid.layers.reshape(x, shape=[1, two])
return x
class TestFlags(unittest.TestCase): class TestFlags(unittest.TestCase):
def setUp(self): def setUp(self):
self.reset_flags_to_default() self.reset_flags_to_default()
...@@ -293,7 +303,26 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime): ...@@ -293,7 +303,26 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
self.expected_message = \ self.expected_message = \
[ [
'File "{}", line 54, in func_error_in_runtime'.format(self.filepath), 'File "{}", line 54, in func_error_in_runtime'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, two])' 'x = fluid.dygraph.to_variable(x)',
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])',
'<--- HERE',
'return x'
]
class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime):
def set_func(self):
self.func = func_error_in_runtime_with_empty_line
def set_message(self):
self.expected_message = \
[
'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format(self.filepath),
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])',
'<--- HERE',
'return x'
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册