未验证 提交 a6cc567f 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2Stat]Modify dy2stat error message in runtime and format error message (#35365)

上级 ef7bc367
......@@ -17,6 +17,7 @@ import six
import sys
import traceback
import linecache
import re
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
......@@ -106,22 +107,34 @@ class TraceBackFrameRange(OriginInfo):
begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))
for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE):
line = linecache.getline(self.location.filepath, i)
line_lstrip = line.strip()
line = linecache.getline(self.location.filepath, i).rstrip('\n')
line_lstrip = line.lstrip()
self.source_code.append(line_lstrip)
blank_count.append(len(line) - len(line_lstrip))
if not line_lstrip: # empty line from source code
blank_count.append(-1)
else:
blank_count.append(len(line) - len(line_lstrip))
if i == self.location.lineno:
hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
self.source_code.append(hint_msg)
blank_count.append(blank_count[-1])
linecache.clearcache()
min_black_count = min(blank_count)
# remove top and bottom empty line in source code
while len(self.source_code) > 0 and not self.source_code[0]:
self.source_code.pop(0)
blank_count.pop(0)
while len(self.source_code) > 0 and not self.source_code[-1]:
self.source_code.pop(-1)
blank_count.pop(-1)
min_black_count = min([i for i in blank_count if i >= 0])
for i in range(len(self.source_code)):
self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]
# if source_code[i] is empty line between two code line, dont add blank
if self.source_code[i]:
self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]
def formated_message(self):
msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format(
......@@ -212,6 +225,7 @@ class ErrorData(object):
1. Need a more robust way because the code of start_trace may change.
2. Set the switch to determine whether to simplify error_value
"""
assert self.in_runtime is True
error_value_lines = str(self.error_value).split("\n")
......@@ -219,9 +233,43 @@ class ErrorData(object):
start_trace = "outputs = static_func(*inputs)"
start_idx = error_value_lines_strip.index(start_trace)
error_value_lines = error_value_lines[start_idx + 1:]
error_value_lines_strip = error_value_lines_strip[start_idx + 1:]
# use empty line to locate the bottom_error_message
empty_line_idx = error_value_lines_strip.index('')
bottom_error_message = error_value_lines[empty_line_idx + 1:]
filepath = ''
error_from_user_code = []
pattern = 'File "(?P<filepath>.+)", line (?P<lineno>.+), in (?P<function_name>.+)'
for i in range(0, len(error_value_lines_strip), 2):
if error_value_lines_strip[i].startswith("File "):
re_result = re.search(pattern, error_value_lines_strip[i])
tmp_filepath, lineno_str, function_name = re_result.groups()
code = error_value_lines_strip[i + 1] if i + 1 < len(
error_value_lines_strip) else ''
if i == 0:
filepath = tmp_filepath
if tmp_filepath == filepath:
error_from_user_code.append(
(tmp_filepath, int(lineno_str), function_name, code))
error_frame = []
whether_source_range = True
for filepath, lineno, funcname, code in error_from_user_code[::-1]:
loc = Location(filepath, lineno)
if whether_source_range:
traceback_frame = TraceBackFrameRange(loc, funcname)
whether_source_range = False
else:
traceback_frame = TraceBackFrame(loc, funcname, code)
error_frame.insert(0, traceback_frame.formated_message())
error_value_str = '\n'.join(error_value_lines)
error_frame.extend(bottom_error_message)
error_value_str = '\n'.join(error_frame)
self.error_value = self.error_type(error_value_str)
def raise_new_exception(self):
......
......@@ -98,6 +98,16 @@ class LayerErrorInCompiletime2(fluid.dygraph.Layer):
return
@paddle.jit.to_static
def func_error_in_runtime_with_empty_line(x):
x = fluid.dygraph.to_variable(x)
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
x = fluid.layers.reshape(x, shape=[1, two])
return x
class TestFlags(unittest.TestCase):
def setUp(self):
self.reset_flags_to_default()
......@@ -293,7 +303,26 @@ class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime):
self.expected_message = \
[
'File "{}", line 54, in func_error_in_runtime'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, two])'
'x = fluid.dygraph.to_variable(x)',
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])',
'<--- HERE',
'return x'
]
class TestErrorStaticLayerCallInRuntime2(TestErrorStaticLayerCallInRuntime):
def set_func(self):
self.func = func_error_in_runtime_with_empty_line
def set_message(self):
self.expected_message = \
[
'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format(self.filepath),
'two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")',
'x = fluid.layers.reshape(x, shape=[1, two])',
'<--- HERE',
'return x'
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册