未验证 提交 3e20ddf7 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat - Error Handling] Fix bug and optimize dy2stat error. (#27225)

上级 ac8afe18
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import six
import sys
import traceback
......@@ -20,6 +21,14 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginI
ERROR_DATA = "Error data about original source code information and traceback."
# A flag to set whether to open the dygraph2static error reporting module
SIMPLIFY_ERROR_ENV_NAME = "TRANSLATOR_SIMPLIFY_NEW_ERROR"
DEFAULT_SIMPLIFY_NEW_ERROR = 1
# A flag to set whether to display the simplified error stack
DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR = 0
def attach_error_data(error, in_runtime=False):
"""
......@@ -103,7 +112,10 @@ class ErrorData(object):
# Simplify error value to improve readability if error is raised in runtime
if self.in_runtime:
self._simplify_error_value()
if int(
os.getenv(SIMPLIFY_ERROR_ENV_NAME,
DEFAULT_SIMPLIFY_NEW_ERROR)):
self._simplify_error_value()
message_lines.append(str(self.error_value))
return '\n'.join(message_lines)
......@@ -150,3 +162,22 @@ class ErrorData(object):
error_value_str = '\n'.join(error_value_lines)
self.error_value = self.error_type(error_value_str)
def raise_new_exception(self):
# Raises the origin error if disable dygraph2static error module,
if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)):
raise
new_exception = self.create_exception()
if six.PY3:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, an new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six.exec_("raise new_exception from None")
else:
raise new_exception
......@@ -32,8 +32,7 @@ from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
from paddle.fluid.dygraph.dygraph_to_static import error
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
......@@ -315,6 +314,7 @@ class StaticLayer(object):
# 2. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try:
concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs)
......@@ -324,27 +324,22 @@ class StaticLayer(object):
partial_program_layer.training = self._class_instance.training
# 4. return outputs.
return partial_program_layer(args)
try:
return partial_program_layer(args)
except Exception as e:
if not hasattr(e, error.ERROR_DATA):
# runtime error
error.attach_error_data(e, in_runtime=True)
raise
except Exception as e:
if not hasattr(e, ERROR_DATA):
# runtime error
attach_error_data(e, in_runtime=True)
error_data = getattr(e, ERROR_DATA, None)
error_data = getattr(e, error.ERROR_DATA, None)
if error_data:
new_exception = error_data.create_exception()
if six.PY3:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, an new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six.exec_("raise new_exception from None")
else:
raise new_exception
error_data.raise_new_exception()
else:
raise
logging_utils.warn(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself.".format(type(e)))
raise e
def _call_dygraph_function(self, *args, **kwargs):
"""
......@@ -593,7 +588,7 @@ class ConcreteProgram(object):
outputs = static_func(*inputs)
except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
attach_error_data(e)
error.attach_error_data(e)
raise
if not isinstance(outputs,
......@@ -813,28 +808,36 @@ class ProgramTranslator(object):
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
"We will just return dygraph output.")
return dygraph_func(*args, **kwargs)
function_spec = FunctionSpec(dygraph_func)
cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs,
getattr(dygraph_func,
'__self__', None))
_, partial_program_layer = self._program_cache[cache_key]
if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute.
partial_program_layer.training = args[0].training
args = args[1:]
try:
return partial_program_layer(args)
function_spec = FunctionSpec(dygraph_func)
cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs,
getattr(dygraph_func,
'__self__', None))
_, partial_program_layer = self._program_cache[cache_key]
if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute.
partial_program_layer.training = args[0].training
args = args[1:]
try:
return partial_program_layer(args)
except BaseException as e:
# NOTE:
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
# 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if not hasattr(e, error.ERROR_DATA):
# runtime error
error.attach_error_data(e, in_runtime=True)
raise
except BaseException as e:
# NOTE:
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
# 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if not hasattr(e, ERROR_DATA):
# runtime error
attach_error_data(e, in_runtime=True)
raise
error_data = getattr(e, error.ERROR_DATA, None)
if error_data:
error_data.raise_new_exception()
else:
logging_utils.warn(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself.".format(type(e)))
raise e
def get_func(self, dygraph_func):
"""
......
......@@ -14,15 +14,15 @@
from __future__ import print_function
import os
import inspect
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.core import EnforceNotMet
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA, ErrorData
from paddle.fluid.dygraph.dygraph_to_static import error
from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap
from paddle.fluid.dygraph.jit import declarative
def inner_func():
......@@ -30,7 +30,7 @@ def inner_func():
return
@declarative
@paddle.jit.to_static
def func_error_in_compile_time(x):
x = fluid.dygraph.to_variable(x)
inner_func()
......@@ -41,14 +41,14 @@ def func_error_in_compile_time(x):
return x_v
@declarative
@paddle.jit.to_static
def func_error_in_compile_time_2(x):
x = fluid.dygraph.to_variable(x)
x = fluid.layers.reshape(x, shape=[1, 2])
return x
@declarative
@paddle.jit.to_static
def func_error_in_runtime(x, iter_num=3):
x = fluid.dygraph.to_variable(x)
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
......@@ -61,6 +61,9 @@ class TestErrorInCompileTime(unittest.TestCase):
self.set_func()
self.set_input()
self.set_exception_type()
self.prog_trans = paddle.jit.ProgramTranslator()
self.simplify_error = 1
self.disable_error = 0
def set_func(self):
self.func = func_error_in_compile_time
......@@ -88,14 +91,38 @@ class TestErrorInCompileTime(unittest.TestCase):
for m in self.expected_message:
self.assertIn(m, error_message)
def test(self):
with fluid.dygraph.guard():
with self.assertRaises(self.exception_type) as cm:
self.func(self.input)
exception = cm.exception
error_data = getattr(exception, ERROR_DATA)
self.assertIsInstance(error_data, ErrorData)
self._test_create_message(error_data)
def _test_attach_and_raise_new_exception(self, func_call):
paddle.disable_static()
with self.assertRaises(self.exception_type) as cm:
func_call()
exception = cm.exception
error_data = getattr(exception, error.ERROR_DATA, None)
self.assertIsInstance(error_data, error.ErrorData)
self._test_create_message(error_data)
def test_static_layer_call(self):
# NOTE: self.func(self.input) is the StaticLayer().__call__(self.input)
call_dy2static = lambda: self.func(self.input)
self.set_flags(0)
self._test_attach_and_raise_new_exception(call_dy2static)
def test_program_translator_get_output(self):
call_dy2static = lambda : self.prog_trans.get_output(unwrap(self.func), self.input)
self.set_flags(0)
self._test_attach_and_raise_new_exception(call_dy2static)
def set_flags(self, disable_error=0, simplify_error=1):
os.environ[error.DISABLE_ERROR_ENV_NAME] = str(disable_error)
self.disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 0))
self.assertEqual(self.disable_error, disable_error)
os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(simplify_error)
self.simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 1))
self.assertEqual(self.simplify_error, simplify_error)
class TestErrorInCompileTime2(TestErrorInCompileTime):
......@@ -143,5 +170,28 @@ class TestErrorInRuntime(TestErrorInCompileTime):
self.assertIn(m, error_message)
@unwrap
@paddle.jit.to_static()
def func_decorated_by_other_1():
return 1
@paddle.jit.to_static()
@unwrap
def func_decorated_by_other_2():
return 1
class TestErrorInOther(unittest.TestCase):
def test(self):
paddle.disable_static()
prog_trans = paddle.jit.ProgramTranslator()
with self.assertRaises(NotImplementedError):
prog_trans.get_output(func_decorated_by_other_1)
with self.assertRaises(NotImplementedError):
func_decorated_by_other_2()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册