未验证 提交 12bf9d71 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat-ErrorMessage]Enhance original error and create new exception (#25798)

* [Dy2Stat-ErrorMessage]Enhance original error and create new exception. test=develop

* Delete redundant code and change func name to create_and_update_origin_info_map. 

* optimize loop_transformer. 

* fix bug in print_transformer.

* Modify code according to the comments.
上级 0a47387b
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
ERROR_DATA = "Error data about original source code information and traceback."
def attach_error_data(error):
"""
Attachs error data about original source code information and traceback to an error.
Args:
error(Exception): An native error.
Returns:
An error attached data about original source code information and traceback.
"""
e_type, e_value, e_traceback = sys.exc_info()
tb = traceback.extract_tb(e_traceback)[1:]
error_data = ErrorData(e_type, e_value, tb, global_origin_info_map)
setattr(error, ERROR_DATA, error_data)
return error
class TraceBackFrame(OriginInfo):
"""
Traceback frame information.
"""
def __init__(self, location, function_name, source_code):
self.location = location
self.function_name = function_name
self.source_code = source_code
class ErrorData(object):
"""
Error data attached to an exception which is raised in un-transformed code.
TODO(liym27): Consider the case that op_callstack when error raised from c++ code
"""
def __init__(self, error_type, error_value, origin_traceback,
origin_info_map):
self.error_type = error_type
self.error_value = error_value
self.origin_traceback = origin_traceback
self.origin_info_map = origin_info_map
def create_exception(self):
message = self.create_message()
new_exception = self.error_type(message)
setattr(new_exception, ERROR_DATA, self)
return new_exception
def create_message(self):
"""
Creates a custom error message which includes trace stack with source code information of dygraph from user.
"""
message_lines = []
# Step1: Adds header message to prompt users that the following is the original information.
header_message = "In user code:"
message_lines.append(header_message)
message_lines.append("")
# Step2: Optimizes stack information with source code information of dygraph from user.
for filepath, lineno, funcname, code in self.origin_traceback:
loc = Location(filepath, lineno)
dygraph_func_info = self.origin_info_map.get(loc.line_location,
None)
if dygraph_func_info:
# TODO(liym27): more information to prompt users that this is the original information.
# Replaces trace stack information about transformed static code with original dygraph code.
traceback_frame = self.origin_info_map[loc.line_location]
else:
traceback_frame = TraceBackFrame(loc, funcname, code)
message_lines.append(traceback_frame.formated_message())
# Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
error_message = " " * 4 + traceback.format_exception_only(
self.error_type, self.error_value)[0].strip("\n")
message_lines.append(error_message)
return '\n'.join(message_lines)
...@@ -39,32 +39,21 @@ GENERATE_VARIABLE_PREFIX = 'generate_variable' ...@@ -39,32 +39,21 @@ GENERATE_VARIABLE_PREFIX = 'generate_variable'
def create_while_node(condition_name, body_name, loop_var_names): def create_while_node(condition_name, body_name, loop_var_names):
while_args = [] # NOTE(liym27):
while_args.append( # It's better to parse the source code into an AST node than to customize an AST node
gast.Name( # including child nodes, because it is easy to mistake the ast node type when customizing the node.
id=condition_name, #
ctx=gast.Param(), # For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name,
annotation=None, # but the type of `foo.x` gast.Attribute.
type_comment=None))
while_args.append( while_func_name = "fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop"
gast.Name( while_node_str = "[{}] = {}({}, {}, [{}])".format(
id=body_name, ctx=gast.Param(), annotation=None, type_comment=None)) ",".join(loop_var_names), while_func_name, condition_name, body_name,
assign_targets = [ ",".join(loop_var_names))
gast.Name(
id=var_name, ctx=gast.Param(), annotation=None, type_comment=None) while_node = gast.parse(while_node_str).body[0]
for var_name in loop_var_names
] return while_node
while_args.append(gast.List(elts=assign_targets, ctx=gast.Param()))
while_func_id = gast.parse(
'fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop'
).body[0].value
while_node = gast.Call(func=while_func_id, args=while_args, keywords=[])
assign_node = gast.Assign(
targets=[gast.Tuple(
elts=assign_targets, ctx=gast.Store())],
value=while_node)
return assign_node
class NameVisitor(gast.NodeVisitor): class NameVisitor(gast.NodeVisitor):
......
...@@ -21,6 +21,7 @@ import gast ...@@ -21,6 +21,7 @@ import gast
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. # NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
ORIGI_INFO = "Original information of source code for ast node." ORIGI_INFO = "Original information of source code for ast node."
ORIGI_INFO_MAP = "Original information map of source code."
class Location(object): class Location(object):
...@@ -64,6 +65,11 @@ class OriginInfo(object): ...@@ -64,6 +65,11 @@ class OriginInfo(object):
return "{} \nsource_code: {} in function {}\n ".format( return "{} \nsource_code: {} in function {}\n ".format(
self.location, self.source_code, self.function_name) self.location, self.source_code, self.function_name)
def formated_message(self):
return ' File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip())
class OriginInfoAttacher(gast.NodeTransformer): class OriginInfoAttacher(gast.NodeTransformer):
""" """
...@@ -119,7 +125,12 @@ class OriginInfoAttacher(gast.NodeTransformer): ...@@ -119,7 +125,12 @@ class OriginInfoAttacher(gast.NodeTransformer):
return self.col_offset + node.col_offset return self.col_offset + node.col_offset
def create_origin_info_map(transformed_node, static_func): global_origin_info_map = {}
def create_and_update_origin_info_map(transformed_node,
static_func,
is_global=True):
""" """
Creates a original information map between transformed static function and original dygraph function. Creates a original information map between transformed static function and original dygraph function.
...@@ -156,6 +167,10 @@ def create_origin_info_map(transformed_node, static_func): ...@@ -156,6 +167,10 @@ def create_origin_info_map(transformed_node, static_func):
origin_info_map[static_loc] = dygraph_info origin_info_map[static_loc] = dygraph_info
global_origin_info_map.update(origin_info_map)
if is_global:
return global_origin_info_map
return origin_info_map return origin_info_map
......
...@@ -47,8 +47,7 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -47,8 +47,7 @@ class PrintTransformer(gast.NodeTransformer):
# NOTE: deal with print in PY3 # NOTE: deal with print in PY3
def visit_Call(self, node): def visit_Call(self, node):
if isinstance(node.func, gast.Name) and node.func.id == 'print': if isinstance(node.func, gast.Name) and node.func.id == 'print':
convert_print_node = self._create_print_node(node.args) node = self._create_print_node(node.args)
return gast.Expr(value=convert_print_node)
return node return node
# NOTE: deal with print in PY2 # NOTE: deal with print in PY2
......
...@@ -36,6 +36,8 @@ from paddle.fluid.wrapped_decorator import signature_safe_contextmanager ...@@ -36,6 +36,8 @@ from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA
__all__ = ['ProgramTranslator', 'convert_to_static'] __all__ = ['ProgramTranslator', 'convert_to_static']
...@@ -88,15 +90,23 @@ class FunctionCache(object): ...@@ -88,15 +90,23 @@ class FunctionCache(object):
# with decorator directly and function.__wrapped__ holds the actual function. # with decorator directly and function.__wrapped__ holds the actual function.
func = getattr(func, '__wrapped__', func) func = getattr(func, '__wrapped__', func)
source_code = func_to_source_code(func) source_code = func_to_source_code(func)
# TODO(liym27):
# Consider this case: source_code in self._code_to_ast_caches,
# but actually they are methods in different classes.
# Maybe use (__class__, source_code) as key
if source_code in self._code_to_ast_caches: if source_code in self._code_to_ast_caches:
root_wrapper = self._code_to_ast_caches[source_code] root_wrapper = self._code_to_ast_caches[source_code]
else: else:
root = gast.parse(source_code) root = gast.parse(source_code)
root = attach_origin_info(root, func)
root_wrapper = self._dygraph_to_static.get_static_ast(root) root_wrapper = self._dygraph_to_static.get_static_ast(root)
self._code_to_ast_caches[source_code] = root_wrapper self._code_to_ast_caches[source_code] = root_wrapper
# Get static function from AST # Get static function from AST
static_func, file_name = ast_to_func(root_wrapper.node, func) static_func, file_name = ast_to_func(root_wrapper.node, func)
create_and_update_origin_info_map(root_wrapper.node, static_func)
return static_func return static_func
def exist(self, func): def exist(self, func):
...@@ -125,6 +135,7 @@ class FunctionSpec(object): ...@@ -125,6 +135,7 @@ class FunctionSpec(object):
self._args = args self._args = args
self._kwargs = kwargs self._kwargs = kwargs
# TODO(liym27): func has multi layer decorator
dyfunc = getattr(func, '__wrapped__', func) dyfunc = getattr(func, '__wrapped__', func)
self._dyfunc_code = inspect.getsource(dyfunc) self._dyfunc_code = inspect.getsource(dyfunc)
...@@ -282,7 +293,13 @@ class ConcreteProgram(object): ...@@ -282,7 +293,13 @@ class ConcreteProgram(object):
# 3. Builds program only once and returns the output Variables. # 3. Builds program only once and returns the output Variables.
with param_guard(func_spec.parameters(False)), param_guard( with param_guard(func_spec.parameters(False)), param_guard(
func_spec.buffers(False)): func_spec.buffers(False)):
try:
outputs = static_func(*inputs) outputs = static_func(*inputs)
except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
attach_error_data(e)
raise
if not isinstance(outputs, if not isinstance(outputs,
(tuple, list)) and outputs is not None: (tuple, list)) and outputs is not None:
outputs = [outputs] outputs = [outputs]
...@@ -483,15 +500,25 @@ class ProgramTranslator(object): ...@@ -483,15 +500,25 @@ class ProgramTranslator(object):
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
function_spec = FunctionSpec(dygraph_func, args, kwargs) function_spec = FunctionSpec(dygraph_func, args, kwargs)
_, partial_program_layer = self._program_cache[function_spec] concrete_program, partial_program_layer = self._program_cache[
function_spec]
if args and isinstance(args[0], layers.Layer): if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute. # Synchronize self.training attribute.
partial_program_layer.training = args[0].training partial_program_layer.training = args[0].training
args = args[1:] args = args[1:]
try:
return partial_program_layer(args) return partial_program_layer(args)
except BaseException as e:
# NOTE:
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
# 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if not hasattr(e, ERROR_DATA):
# runtime error
attach_error_data(e)
raise
def get_func(self, dygraph_func): def get_func(self, dygraph_func):
""" """
Returns a callable function which converts imperative dygraph APIs of Returns a callable function which converts imperative dygraph APIs of
......
...@@ -15,20 +15,23 @@ ...@@ -15,20 +15,23 @@
from __future__ import print_function from __future__ import print_function
import os import os
import six
import pickle import pickle
import warnings import warnings
import six
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, FunctionSpec from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.program_translator import FunctionSpec, ProgramTranslator
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.framework import Program, Block, Variable, ParamBase, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode from paddle.fluid.framework import Block, ParamBase, Program, Variable
from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
from paddle.fluid.framework import dygraph_only, in_dygraph_mode
from paddle.fluid.wrapped_decorator import wrap_decorator from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph.io import TranslatedLayer, VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME
__all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func'] __all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']
...@@ -167,7 +170,15 @@ def _declarative_(dygraph_func): ...@@ -167,7 +170,15 @@ def _declarative_(dygraph_func):
"The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. " "The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.") "We will just return dygraph output.")
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
try:
return program_translator.get_output(dygraph_func, *args, **kwargs) return program_translator.get_output(dygraph_func, *args, **kwargs)
except Exception as e:
error_data = getattr(e, ERROR_DATA, None)
if error_data:
new_exception = error_data.create_exception()
raise new_exception
else:
raise
return __impl__ return __impl__
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import inspect
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.core import EnforceNotMet
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA, ErrorData
from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap
from paddle.fluid.dygraph.jit import declarative
def inner_func():
fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")
return
@declarative
def func_error_in_compile_time(x):
x = fluid.dygraph.to_variable(x)
inner_func()
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
@declarative
def func_error_in_compile_time_2(x):
x = fluid.dygraph.to_variable(x)
x = fluid.layers.reshape(x, shape=[1, 2])
return x
@declarative
def func_error_in_runtime(x, iter_num=3):
x = fluid.dygraph.to_variable(x)
a = []
iter_num = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32")
for i in range(iter_num):
a.append(b)
a = fluid.layers.concat(a, axis=0)
return a
class TestErrorInCompileTime(unittest.TestCase):
def setUp(self):
self.set_func()
self.set_input()
self.set_exception_type()
def set_func(self):
self.func = func_error_in_compile_time
def set_exception_type(self):
self.exception_type = TypeError
def set_input(self):
self.input = np.ones([3, 2])
def set_message(self):
self.expected_message = \
['File "{}", line 36, in func_error_in_compile_time'.format(self.filepath),
'inner_func()',
'File "{}", line 29, in inner_func'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
]
def _test_create_message(self, error_data):
self.filepath = inspect.getfile(unwrap(self.func))
self.set_message()
error_message = error_data.create_message()
self.assertIn('In user code:', error_message)
for m in self.expected_message:
self.assertIn(m, error_message)
def test(self):
with fluid.dygraph.guard():
with self.assertRaises(self.exception_type) as cm:
self.func(self.input)
exception = cm.exception
error_data = getattr(exception, ERROR_DATA)
self.assertIsInstance(error_data, ErrorData)
self._test_create_message(error_data)
class TestErrorInCompileTime2(TestErrorInCompileTime):
def set_func(self):
self.func = func_error_in_compile_time_2
def set_exception_type(self):
self.exception_type = EnforceNotMet
def set_message(self):
self.expected_message = \
[
'File "{}", line 47, in func_error_in_compile_time_2'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, 2])'
]
# TODO(liym27): Consider the case that op_callstack when error raised from c++ code
class TestErrorInRuntime(TestErrorInCompileTime):
def set_func(self):
self.func = func_error_in_runtime
def set_exception_type(self):
self.exception_type = EnforceNotMet
def test(self):
with fluid.dygraph.guard():
with self.assertRaises(self.exception_type) as cm:
self.func(self.input)
if __name__ == '__main__':
unittest.main()
...@@ -90,7 +90,8 @@ class TestOriginInfo(unittest.TestCase): ...@@ -90,7 +90,8 @@ class TestOriginInfo(unittest.TestCase):
# step3 # step3
self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func) self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
info_map = create_origin_info_map(dygraph_ast, self.static_func) info_map = create_and_update_origin_info_map(dygraph_ast,
self.static_func)
return info_map return info_map
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册