未验证 提交 168ea223 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat-ErrorMessage]Optimize error value to improve readability when error...

[Dy2Stat-ErrorMessage]Optimize error value to improve readability when error raised in runtime.  (#25970)

* don't remove op_callstack

* [Dy2Stat-ErrorMessage] Optimize error value to improve readability if error raised in run-time. 
  1. update op_callstack with original information;
  2. simplify error value to improve readability if error raised in run-time.

* Fix error in Python3. 
上级 3755564a
...@@ -20,13 +20,13 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginI ...@@ -20,13 +20,13 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginI
ERROR_DATA = "Error data about original source code information and traceback." ERROR_DATA = "Error data about original source code information and traceback."
def attach_error_data(error): def attach_error_data(error, in_runtime=False):
""" """
Attachs error data about original source code information and traceback to an error. Attachs error data about original source code information and traceback to an error.
Args: Args:
error(Exception): An native error. error(Exception): An native error.
in_runtime(bool): `error` is raised in runtime if in_runtime is True, otherwise in compile time
Returns: Returns:
An error attached data about original source code information and traceback. An error attached data about original source code information and traceback.
""" """
...@@ -34,6 +34,8 @@ def attach_error_data(error): ...@@ -34,6 +34,8 @@ def attach_error_data(error):
tb = traceback.extract_tb(e_traceback)[1:] tb = traceback.extract_tb(e_traceback)[1:]
error_data = ErrorData(e_type, e_value, tb, global_origin_info_map) error_data = ErrorData(e_type, e_value, tb, global_origin_info_map)
error_data.in_runtime = in_runtime
setattr(error, ERROR_DATA, error_data) setattr(error, ERROR_DATA, error_data)
return error return error
...@@ -53,8 +55,6 @@ class TraceBackFrame(OriginInfo): ...@@ -53,8 +55,6 @@ class TraceBackFrame(OriginInfo):
class ErrorData(object): class ErrorData(object):
""" """
Error data attached to an exception which is raised in un-transformed code. Error data attached to an exception which is raised in un-transformed code.
TODO(liym27): Consider the case that op_callstack when error raised from c++ code
""" """
def __init__(self, error_type, error_value, origin_traceback, def __init__(self, error_type, error_value, origin_traceback,
...@@ -63,6 +63,7 @@ class ErrorData(object): ...@@ -63,6 +63,7 @@ class ErrorData(object):
self.error_value = error_value self.error_value = error_value
self.origin_traceback = origin_traceback self.origin_traceback = origin_traceback
self.origin_info_map = origin_info_map self.origin_info_map = origin_info_map
self.in_runtime = False
def create_exception(self): def create_exception(self):
message = self.create_message() message = self.create_message()
...@@ -81,6 +82,12 @@ class ErrorData(object): ...@@ -81,6 +82,12 @@ class ErrorData(object):
message_lines.append(header_message) message_lines.append(header_message)
message_lines.append("") message_lines.append("")
# Simplify error value to improve readability if error is raised in runtime
if self.in_runtime:
self._simplify_error_value()
message_lines.append(str(self.error_value))
return '\n'.join(message_lines)
# Step2: Optimizes stack information with source code information of dygraph from user. # Step2: Optimizes stack information with source code information of dygraph from user.
for filepath, lineno, funcname, code in self.origin_traceback: for filepath, lineno, funcname, code in self.origin_traceback:
loc = Location(filepath, lineno) loc = Location(filepath, lineno)
...@@ -102,3 +109,25 @@ class ErrorData(object): ...@@ -102,3 +109,25 @@ class ErrorData(object):
message_lines.append(error_message) message_lines.append(error_message)
return '\n'.join(message_lines) return '\n'.join(message_lines)
def _simplify_error_value(self):
    """
    Trims the runtime error value down to the part that follows the static-call marker line.

    The text of `self.error_value` is split into lines; everything up to and including the
    line that (ignoring leading spaces) equals `outputs = static_func(*inputs)` is dropped,
    and the remainder is wrapped in a new `self.error_type` instance stored back on
    `self.error_value`.

    NOTE(liym27): The op callstack information about transformed static code has been replaced with original dygraph code.

    TODO(liym27):
        1. Need a more robust way because the code of start_trace may change.
        2. Set the switch to determine whether to simplify error_value

    Raises:
        ValueError: If the marker line is not present in the error value
            (propagated from `list.index`).
    """
    assert self.in_runtime is True

    marker = "outputs = static_func(*inputs)"
    raw_lines = str(self.error_value).split("\n")

    # Compare against left-stripped copies, but slice the untouched lines so the
    # kept part retains its original indentation.
    marker_pos = [line.lstrip(" ") for line in raw_lines].index(marker)

    self.error_value = self.error_type("\n".join(raw_lines[marker_pos + 1:]))
...@@ -19,6 +19,9 @@ import inspect ...@@ -19,6 +19,9 @@ import inspect
import gast import gast
from paddle.fluid import core
from paddle.fluid.framework import Program
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. # NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
ORIGI_INFO = "Original information of source code for ast node." ORIGI_INFO = "Original information of source code for ast node."
ORIGI_INFO_MAP = "Original information map of source code." ORIGI_INFO_MAP = "Original information map of source code."
...@@ -70,6 +73,10 @@ class OriginInfo(object): ...@@ -70,6 +73,10 @@ class OriginInfo(object):
self.location.filepath, self.location.lineno, self.function_name, self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip()) self.source_code.lstrip())
def as_frame(self):
    """
    Returns this origin info as a traceback-style frame tuple:
    `(filepath, lineno, function_name, source_code)`, where the source code has its
    leading whitespace removed.
    """
    where = self.location
    return (where.filepath, where.lineno, self.function_name,
            self.source_code.lstrip())
class OriginInfoAttacher(gast.NodeTransformer): class OriginInfoAttacher(gast.NodeTransformer):
""" """
...@@ -249,3 +256,63 @@ def ast_walk(transformed_node, static_node): ...@@ -249,3 +256,63 @@ def ast_walk(transformed_node, static_node):
if isinstance(d_item, gast.AST): if isinstance(d_item, gast.AST):
transformed_node_list.append(d_item) transformed_node_list.append(d_item)
static_node_list.append(s_item) static_node_list.append(s_item)
def update_op_callstack_with_origin_info(program):
    """
    Replaces op callstack information about transformed static code with original dygraph code.

    Every op in every block of `program` that carries the op-creation callstack
    attribute gets that attribute rewritten: each frame whose location is found in
    `global_origin_info_map` is replaced with the frame of the original dygraph code.

    Args:
        program(Program): The program whose ops are updated in place.

    Returns:
        The same `program` object that was passed in.
    """
    # Local import keeps this function self-contained (no new module-level dependency).
    import re

    assert isinstance(program, Program)

    # A frame header looks like: File "path/to/file.py", line 10, in func_1
    # The pattern is anchored on the literal `", line <digits>, in ` separator, so
    # commas inside the file path or the function name no longer break parsing
    # (the previous `split(",")` + fixed-offset slicing did).
    frame_header = re.compile(r'File "(.*)", line (\d+), in (.*)', re.DOTALL)

    def get_new_op_callstack(callstack):
        """
        Rewrites `callstack` in place and returns it.

        The callstack is a flat list of (header, code) string pairs.
        An example of callstack:

            File "path1/to/file.py", line 10, in func_1
              y = fluid.layers.fill_constant(x, shape=[1], dtype="int32")
            File "path2/to/file.py", line 740, in fill_constant
              stop_gradient=True)
            File "path3/to/file.py", line 43, in append_op
              return self.main_program.current_block().append_op(*args, **kwargs)
            File "path4/to/file.py", line 2811, in append_op
              attrs=kwargs.get("attrs", None))
            File "path5/to/file.py", line 1919, in __init__
              for frame in traceback.extract_stack():
        """
        assert len(callstack) % 2 == 0
        for i in range(0, len(callstack), 2):
            matched = frame_header.match(callstack[i].lstrip(" "))
            if matched is None:
                # Not a recognisable frame header: leave this frame as it is
                # rather than aborting the whole update over cosmetic data.
                continue

            filepath, lineno, funcname = matched.groups()
            lineno = int(lineno)
            code = callstack[i + 1].lstrip(" ")

            loc = Location(filepath, lineno)
            dygraph_func_info = global_origin_info_map.get(loc.line_location)
            if dygraph_func_info:
                # The frame points at transformed code: substitute the original frame.
                filepath, lineno, funcname, code = \
                    dygraph_func_info.as_frame()

            callstack[i] = ' File "{}", line {}, in {}'.format(
                filepath, lineno, funcname)
            callstack[i + 1] = ' {}'.format(code)

        return callstack

    op_maker = core.op_proto_and_checker_maker
    callstack_var_name = op_maker.kOpCreationCallstackAttrName()

    for block in program.blocks:
        for op in block.ops:
            if op.has_attr(callstack_var_name):
                callstack = op.attr(callstack_var_name)
                callstack = get_new_op_callstack(callstack)
                op._set_attr(callstack_var_name, callstack)

    return program
...@@ -130,8 +130,6 @@ class PartialProgramLayer(layers.Layer): ...@@ -130,8 +130,6 @@ class PartialProgramLayer(layers.Layer):
self._check_params_all_inited(main_program) self._check_params_all_inited(main_program)
# 2. Prune the parameters not used anywhere in the program. # 2. Prune the parameters not used anywhere in the program.
self._prune_unused_params(main_program) self._prune_unused_params(main_program)
# 3. Remove op's python call stack with redundant low-level error messages.
main_program = self._remove_op_call_stack(main_program)
return main_program return main_program
......
...@@ -37,6 +37,7 @@ from paddle.fluid.dygraph.base import param_guard ...@@ -37,6 +37,7 @@ from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA
__all__ = ['ProgramTranslator', 'convert_to_static'] __all__ = ['ProgramTranslator', 'convert_to_static']
...@@ -304,6 +305,8 @@ class ConcreteProgram(object): ...@@ -304,6 +305,8 @@ class ConcreteProgram(object):
(tuple, list)) and outputs is not None: (tuple, list)) and outputs is not None:
outputs = [outputs] outputs = [outputs]
main_program = update_op_callstack_with_origin_info(main_program)
return ConcreteProgram( return ConcreteProgram(
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
...@@ -516,7 +519,7 @@ class ProgramTranslator(object): ...@@ -516,7 +519,7 @@ class ProgramTranslator(object):
# 2. If e raised in runtime, e should be attached to ERROR_DATA here. # 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if not hasattr(e, ERROR_DATA): if not hasattr(e, ERROR_DATA):
# runtime error # runtime error
attach_error_data(e) attach_error_data(e, in_runtime=True)
raise raise
def get_func(self, dygraph_func): def get_func(self, dygraph_func):
......
...@@ -176,6 +176,16 @@ def _declarative_(dygraph_func): ...@@ -176,6 +176,16 @@ def _declarative_(dygraph_func):
error_data = getattr(e, ERROR_DATA, None) error_data = getattr(e, ERROR_DATA, None)
if error_data: if error_data:
new_exception = error_data.create_exception() new_exception = error_data.create_exception()
if six.PY3:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, a new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six.exec_("raise new_exception from None")
else:
raise new_exception raise new_exception
else: else:
raise raise
......
...@@ -51,13 +51,9 @@ def func_error_in_compile_time_2(x): ...@@ -51,13 +51,9 @@ def func_error_in_compile_time_2(x):
@declarative @declarative
def func_error_in_runtime(x, iter_num=3): def func_error_in_runtime(x, iter_num=3):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
a = [] two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
iter_num = fluid.layers.fill_constant( x = fluid.layers.reshape(x, shape=[1, two])
shape=[1], value=iter_num, dtype="int32") return x
for i in range(iter_num):
a.append(b)
a = fluid.layers.concat(a, axis=0)
return a
class TestErrorInCompileTime(unittest.TestCase): class TestErrorInCompileTime(unittest.TestCase):
...@@ -118,7 +114,6 @@ class TestErrorInCompileTime2(TestErrorInCompileTime): ...@@ -118,7 +114,6 @@ class TestErrorInCompileTime2(TestErrorInCompileTime):
] ]
# TODO(liym27): Consider the case that op_callstack when error raised from c++ code
class TestErrorInRuntime(TestErrorInCompileTime): class TestErrorInRuntime(TestErrorInCompileTime):
def set_func(self): def set_func(self):
self.func = func_error_in_runtime self.func = func_error_in_runtime
...@@ -126,10 +121,26 @@ class TestErrorInRuntime(TestErrorInCompileTime): ...@@ -126,10 +121,26 @@ class TestErrorInRuntime(TestErrorInCompileTime):
def set_exception_type(self): def set_exception_type(self):
self.exception_type = EnforceNotMet self.exception_type = EnforceNotMet
def test(self): def set_message(self):
with fluid.dygraph.guard(): self.expected_message = \
with self.assertRaises(self.exception_type) as cm: [
self.func(self.input) 'File "{}", line 55, in func_error_in_runtime'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, two])'
]
def _test_create_message(self, error_data):
    """
    Checks the message built by `error_data.create_message()`.

    First asserts that `create_message()` raises ValueError with `error_data` as
    received (presumably still flagged `in_runtime`; confirm against the caller),
    then clears `in_runtime` and asserts that the message contains the
    'In user code:' header and every fragment of `self.expected_message`.
    """
    self.filepath = inspect.getfile(unwrap(self.func))
    self.set_message()

    # Callable form of assertRaises: equivalent to the `with` form for a single call.
    self.assertRaises(ValueError, error_data.create_message)

    error_data.in_runtime = False
    rendered = error_data.create_message()

    self.assertIn('In user code:', rendered)
    for fragment in self.expected_message:
        self.assertIn(fragment, rendered)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册