未验证 提交 b24f84c8 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]modify dy2stat error message in compile time (#35320)

* modify dy2stat error message in compile time

* fix variable name
上级 b53887fd
......@@ -16,6 +16,7 @@ import os
import six
import sys
import traceback
import linecache
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
......@@ -29,6 +30,9 @@ DEFAULT_SIMPLIFY_NEW_ERROR = 1
DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR = 0
SOURCE_CODE_RANGE = 5
BLANK_COUNT_BEFORE_FILE_STR = 4
def attach_error_data(error, in_runtime=False):
"""
......@@ -40,6 +44,7 @@ def attach_error_data(error, in_runtime=False):
Returns:
An error attached data about original source code information and traceback.
"""
e_type, e_value, e_traceback = sys.exc_info()
tb = traceback.extract_tb(e_traceback)[1:]
......@@ -82,12 +87,49 @@ class TraceBackFrame(OriginInfo):
def formated_message(self):
# self.source_code may be empty in some functions.
# For example, decorator generated function
return ' File "{}", line {}, in {}\n\t{}'.format(
return ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip()
if isinstance(self.source_code, str) else self.source_code)
class TraceBackFrameRange(OriginInfo):
"""
Traceback frame information.
"""
def __init__(self, location, function_name):
self.location = location
self.function_name = function_name
self.source_code = []
blank_count = []
begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))
for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE):
line = linecache.getline(self.location.filepath, i)
line_lstrip = line.strip()
self.source_code.append(line_lstrip)
blank_count.append(len(line) - len(line_lstrip))
if i == self.location.lineno:
hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
self.source_code.append(hint_msg)
blank_count.append(blank_count[-1])
linecache.clearcache()
min_black_count = min(blank_count)
for i in range(len(self.source_code)):
self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]
def formated_message(self):
msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format(
self.location.filepath, self.location.lineno, self.function_name)
# add empty line after range code
return msg + '\n'.join(self.source_code) + '\n'
class ErrorData(object):
"""
Error data attached to an exception which is raised in un-transformed code.
......@@ -128,26 +170,34 @@ class ErrorData(object):
return '\n'.join(message_lines)
# Step2: Optimizes stack information with source code information of dygraph from user.
for filepath, lineno, funcname, code in self.origin_traceback:
whether_source_range = True
for filepath, lineno, funcname, code in self.origin_traceback[::-1]:
loc = Location(filepath, lineno)
dygraph_func_info = self.origin_info_map.get(loc.line_location,
None)
if dygraph_func_info:
# TODO(liym27): more information to prompt users that this is the original information.
# Replaces trace stack information about transformed static code with original dygraph code.
traceback_frame = self.origin_info_map[loc.line_location]
else:
traceback_frame = TraceBackFrame(loc, funcname, code)
message_lines.append(traceback_frame.formated_message())
if whether_source_range:
traceback_frame = TraceBackFrameRange(
dygraph_func_info.location,
dygraph_func_info.function_name)
whether_source_range = False
else:
traceback_frame = TraceBackFrame(
dygraph_func_info.location,
dygraph_func_info.function_name,
dygraph_func_info.source_code)
# Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2
message_lines.insert(2, traceback_frame.formated_message())
# Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
# NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length
# is gather than 1, for example, the error_type is IndentationError.
format_exception = traceback.format_exception_only(self.error_type,
self.error_value)
error_message = [" " * 4 + line for line in format_exception]
error_message = [
" " * BLANK_COUNT_BEFORE_FILE_STR + line
for line in format_exception
]
message_lines.extend(error_message)
return '\n'.join(message_lines)
......@@ -175,7 +225,6 @@ class ErrorData(object):
self.error_value = self.error_type(error_value_str)
def raise_new_exception(self):
# Raises the origin error if disable dygraph2static error module,
if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)):
raise
......
......@@ -218,7 +218,10 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath),
'inner_func()',
'File "{}", line 28, in inner_func'.format(self.filepath),
'def inner_func():',
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
'return',
]
def set_func_call(self):
......@@ -242,7 +245,11 @@ class TestErrorStaticLayerCallInCompiletime_2(
self.expected_message = \
[
'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, 2])'
'def func_error_in_compile_time_2(x):',
'x = fluid.dygraph.to_variable(x)',
'x = fluid.layers.reshape(x, shape=[1, 2])',
'<--- HERE',
'return x'
]
......@@ -261,7 +268,10 @@ class TestErrorStaticLayerCallInCompiletime_3(
def set_message(self):
self.expected_message = \
['File "{}", line 91, in forward'.format(self.filepath),
'@paddle.jit.to_static',
'def forward(self):',
'self.test_func()',
'<--- HERE'
]
def set_func_call(self):
......@@ -318,7 +328,12 @@ class TestJitSaveInCompiletime(TestErrorBase):
def set_message(self):
self.expected_message = \
['File "{}", line 80, in forward'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'def forward(self, x):',
'y = self._linear(x)',
'z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
'out = fluid.layers.mean(y[z])',
'return out'
]
def set_func_call(self):
......@@ -329,7 +344,7 @@ class TestJitSaveInCompiletime(TestErrorBase):
self._test_raise_new_exception()
# Situation 4: NotImplementedError
# # Situation 4: NotImplementedError
class TestErrorInOther(unittest.TestCase):
def test(self):
paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册