未验证 提交 12bf9d71 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat-ErrorMessage]Enhance original error and create new exception (#25798)

* [Dy2Stat-ErrorMessage]Enhance original error and create new exception. test=develop

* Delete redundant code and change func name to create_and_update_origin_info_map. 

* optimize loop_transformer. 

* fix bug in print_transformer.

* Modify code according to the comments.
上级 0a47387b
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map
ERROR_DATA = "Error data about original source code information and traceback."
def attach_error_data(error):
"""
Attachs error data about original source code information and traceback to an error.
Args:
error(Exception): An native error.
Returns:
An error attached data about original source code information and traceback.
"""
e_type, e_value, e_traceback = sys.exc_info()
tb = traceback.extract_tb(e_traceback)[1:]
error_data = ErrorData(e_type, e_value, tb, global_origin_info_map)
setattr(error, ERROR_DATA, error_data)
return error
class TraceBackFrame(OriginInfo):
"""
Traceback frame information.
"""
def __init__(self, location, function_name, source_code):
self.location = location
self.function_name = function_name
self.source_code = source_code
class ErrorData(object):
"""
Error data attached to an exception which is raised in un-transformed code.
TODO(liym27): Consider the case that op_callstack when error raised from c++ code
"""
def __init__(self, error_type, error_value, origin_traceback,
origin_info_map):
self.error_type = error_type
self.error_value = error_value
self.origin_traceback = origin_traceback
self.origin_info_map = origin_info_map
def create_exception(self):
message = self.create_message()
new_exception = self.error_type(message)
setattr(new_exception, ERROR_DATA, self)
return new_exception
def create_message(self):
"""
Creates a custom error message which includes trace stack with source code information of dygraph from user.
"""
message_lines = []
# Step1: Adds header message to prompt users that the following is the original information.
header_message = "In user code:"
message_lines.append(header_message)
message_lines.append("")
# Step2: Optimizes stack information with source code information of dygraph from user.
for filepath, lineno, funcname, code in self.origin_traceback:
loc = Location(filepath, lineno)
dygraph_func_info = self.origin_info_map.get(loc.line_location,
None)
if dygraph_func_info:
# TODO(liym27): more information to prompt users that this is the original information.
# Replaces trace stack information about transformed static code with original dygraph code.
traceback_frame = self.origin_info_map[loc.line_location]
else:
traceback_frame = TraceBackFrame(loc, funcname, code)
message_lines.append(traceback_frame.formated_message())
# Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
error_message = " " * 4 + traceback.format_exception_only(
self.error_type, self.error_value)[0].strip("\n")
message_lines.append(error_message)
return '\n'.join(message_lines)
......@@ -39,32 +39,21 @@ GENERATE_VARIABLE_PREFIX = 'generate_variable'
def create_while_node(condition_name, body_name, loop_var_names):
while_args = []
while_args.append(
gast.Name(
id=condition_name,
ctx=gast.Param(),
annotation=None,
type_comment=None))
while_args.append(
gast.Name(
id=body_name, ctx=gast.Param(), annotation=None, type_comment=None))
assign_targets = [
gast.Name(
id=var_name, ctx=gast.Param(), annotation=None, type_comment=None)
for var_name in loop_var_names
]
while_args.append(gast.List(elts=assign_targets, ctx=gast.Param()))
while_func_id = gast.parse(
'fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop'
).body[0].value
while_node = gast.Call(func=while_func_id, args=while_args, keywords=[])
assign_node = gast.Assign(
targets=[gast.Tuple(
elts=assign_targets, ctx=gast.Store())],
value=while_node)
return assign_node
# NOTE(liym27):
# It's better to parse the source code into an AST node than to customize an AST node
# including child nodes, because it is easy to mistake the ast node type when customizing the node.
#
# For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name,
# but the type of `foo.x` gast.Attribute.
while_func_name = "fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop"
while_node_str = "[{}] = {}({}, {}, [{}])".format(
",".join(loop_var_names), while_func_name, condition_name, body_name,
",".join(loop_var_names))
while_node = gast.parse(while_node_str).body[0]
return while_node
class NameVisitor(gast.NodeVisitor):
......
......@@ -21,6 +21,7 @@ import gast
# NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node.
ORIGI_INFO = "Original information of source code for ast node."
ORIGI_INFO_MAP = "Original information map of source code."
class Location(object):
......@@ -64,6 +65,11 @@ class OriginInfo(object):
return "{} \nsource_code: {} in function {}\n ".format(
self.location, self.source_code, self.function_name)
def formated_message(self):
return ' File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip())
class OriginInfoAttacher(gast.NodeTransformer):
"""
......@@ -119,7 +125,12 @@ class OriginInfoAttacher(gast.NodeTransformer):
return self.col_offset + node.col_offset
def create_origin_info_map(transformed_node, static_func):
global_origin_info_map = {}
def create_and_update_origin_info_map(transformed_node,
static_func,
is_global=True):
"""
Creates a original information map between transformed static function and original dygraph function.
......@@ -156,6 +167,10 @@ def create_origin_info_map(transformed_node, static_func):
origin_info_map[static_loc] = dygraph_info
global_origin_info_map.update(origin_info_map)
if is_global:
return global_origin_info_map
return origin_info_map
......
......@@ -47,8 +47,7 @@ class PrintTransformer(gast.NodeTransformer):
# NOTE: deal with print in PY3
def visit_Call(self, node):
if isinstance(node.func, gast.Name) and node.func.id == 'print':
convert_print_node = self._create_print_node(node.args)
return gast.Expr(value=convert_print_node)
node = self._create_print_node(node.args)
return node
# NOTE: deal with print in PY2
......
......@@ -36,6 +36,8 @@ from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA
__all__ = ['ProgramTranslator', 'convert_to_static']
......@@ -88,15 +90,23 @@ class FunctionCache(object):
# with decorator directly and function.__wrapped__ holds the actual function.
func = getattr(func, '__wrapped__', func)
source_code = func_to_source_code(func)
# TODO(liym27):
# Consider this case: source_code in self._code_to_ast_caches,
# but actually they are methods in different classes.
# Maybe use (__class__, source_code) as key
if source_code in self._code_to_ast_caches:
root_wrapper = self._code_to_ast_caches[source_code]
else:
root = gast.parse(source_code)
root = attach_origin_info(root, func)
root_wrapper = self._dygraph_to_static.get_static_ast(root)
self._code_to_ast_caches[source_code] = root_wrapper
# Get static function from AST
static_func, file_name = ast_to_func(root_wrapper.node, func)
create_and_update_origin_info_map(root_wrapper.node, static_func)
return static_func
def exist(self, func):
......@@ -125,6 +135,7 @@ class FunctionSpec(object):
self._args = args
self._kwargs = kwargs
# TODO(liym27): func has multi layer decorator
dyfunc = getattr(func, '__wrapped__', func)
self._dyfunc_code = inspect.getsource(dyfunc)
......@@ -282,7 +293,13 @@ class ConcreteProgram(object):
# 3. Builds program only once and returns the output Variables.
with param_guard(func_spec.parameters(False)), param_guard(
func_spec.buffers(False)):
outputs = static_func(*inputs)
try:
outputs = static_func(*inputs)
except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
attach_error_data(e)
raise
if not isinstance(outputs,
(tuple, list)) and outputs is not None:
outputs = [outputs]
......@@ -483,14 +500,24 @@ class ProgramTranslator(object):
return dygraph_func(*args, **kwargs)
function_spec = FunctionSpec(dygraph_func, args, kwargs)
_, partial_program_layer = self._program_cache[function_spec]
concrete_program, partial_program_layer = self._program_cache[
function_spec]
if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute.
partial_program_layer.training = args[0].training
args = args[1:]
return partial_program_layer(args)
try:
return partial_program_layer(args)
except BaseException as e:
# NOTE:
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
# 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if not hasattr(e, ERROR_DATA):
# runtime error
attach_error_data(e)
raise
def get_func(self, dygraph_func):
"""
......
......@@ -15,20 +15,23 @@
from __future__ import print_function
import os
import six
import pickle
import warnings
import six
from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, FunctionSpec
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.program_translator import FunctionSpec, ProgramTranslator
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.framework import Program, Block, Variable, ParamBase, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
from paddle.fluid.framework import dygraph_only, in_dygraph_mode
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph.io import TranslatedLayer, VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME
__all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']
......@@ -167,7 +170,15 @@ def _declarative_(dygraph_func):
"The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.")
return dygraph_func(*args, **kwargs)
return program_translator.get_output(dygraph_func, *args, **kwargs)
try:
return program_translator.get_output(dygraph_func, *args, **kwargs)
except Exception as e:
error_data = getattr(e, ERROR_DATA, None)
if error_data:
new_exception = error_data.create_exception()
raise new_exception
else:
raise
return __impl__
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import inspect
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.core import EnforceNotMet
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA, ErrorData
from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap
from paddle.fluid.dygraph.jit import declarative
def inner_func():
fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")
return
@declarative
def func_error_in_compile_time(x):
x = fluid.dygraph.to_variable(x)
inner_func()
if fluid.layers.mean(x) < 0:
x_v = x - 1
else:
x_v = x + 1
return x_v
@declarative
def func_error_in_compile_time_2(x):
x = fluid.dygraph.to_variable(x)
x = fluid.layers.reshape(x, shape=[1, 2])
return x
@declarative
def func_error_in_runtime(x, iter_num=3):
x = fluid.dygraph.to_variable(x)
a = []
iter_num = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32")
for i in range(iter_num):
a.append(b)
a = fluid.layers.concat(a, axis=0)
return a
class TestErrorInCompileTime(unittest.TestCase):
def setUp(self):
self.set_func()
self.set_input()
self.set_exception_type()
def set_func(self):
self.func = func_error_in_compile_time
def set_exception_type(self):
self.exception_type = TypeError
def set_input(self):
self.input = np.ones([3, 2])
def set_message(self):
self.expected_message = \
['File "{}", line 36, in func_error_in_compile_time'.format(self.filepath),
'inner_func()',
'File "{}", line 29, in inner_func'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
]
def _test_create_message(self, error_data):
self.filepath = inspect.getfile(unwrap(self.func))
self.set_message()
error_message = error_data.create_message()
self.assertIn('In user code:', error_message)
for m in self.expected_message:
self.assertIn(m, error_message)
def test(self):
with fluid.dygraph.guard():
with self.assertRaises(self.exception_type) as cm:
self.func(self.input)
exception = cm.exception
error_data = getattr(exception, ERROR_DATA)
self.assertIsInstance(error_data, ErrorData)
self._test_create_message(error_data)
class TestErrorInCompileTime2(TestErrorInCompileTime):
def set_func(self):
self.func = func_error_in_compile_time_2
def set_exception_type(self):
self.exception_type = EnforceNotMet
def set_message(self):
self.expected_message = \
[
'File "{}", line 47, in func_error_in_compile_time_2'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, 2])'
]
# TODO(liym27): Consider the case that op_callstack when error raised from c++ code
class TestErrorInRuntime(TestErrorInCompileTime):
def set_func(self):
self.func = func_error_in_runtime
def set_exception_type(self):
self.exception_type = EnforceNotMet
def test(self):
with fluid.dygraph.guard():
with self.assertRaises(self.exception_type) as cm:
self.func(self.input)
if __name__ == '__main__':
unittest.main()
......@@ -90,7 +90,8 @@ class TestOriginInfo(unittest.TestCase):
# step3
self.static_func, _ = ast_to_func(transformed_ast, self.dygraph_func)
info_map = create_origin_info_map(dygraph_ast, self.static_func)
info_map = create_and_update_origin_info_map(dygraph_ast,
self.static_func)
return info_map
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册