未验证 提交 3e20ddf7 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat - Error Handling] Fix bug and optimize dy2stat error. (#27225)

上级 ac8afe18
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import six
import sys import sys
import traceback import traceback
...@@ -20,6 +21,14 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginI ...@@ -20,6 +21,14 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginI
ERROR_DATA = "Error data about original source code information and traceback." ERROR_DATA = "Error data about original source code information and traceback."
# A flag to set whether to open the dygraph2static error reporting module
SIMPLIFY_ERROR_ENV_NAME = "TRANSLATOR_SIMPLIFY_NEW_ERROR"
DEFAULT_SIMPLIFY_NEW_ERROR = 1
# A flag to set whether to display the simplified error stack
DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR = 0
def attach_error_data(error, in_runtime=False): def attach_error_data(error, in_runtime=False):
""" """
...@@ -103,7 +112,10 @@ class ErrorData(object): ...@@ -103,7 +112,10 @@ class ErrorData(object):
# Simplify error value to improve readability if error is raised in runtime # Simplify error value to improve readability if error is raised in runtime
if self.in_runtime: if self.in_runtime:
self._simplify_error_value() if int(
os.getenv(SIMPLIFY_ERROR_ENV_NAME,
DEFAULT_SIMPLIFY_NEW_ERROR)):
self._simplify_error_value()
message_lines.append(str(self.error_value)) message_lines.append(str(self.error_value))
return '\n'.join(message_lines) return '\n'.join(message_lines)
...@@ -150,3 +162,22 @@ class ErrorData(object): ...@@ -150,3 +162,22 @@ class ErrorData(object):
error_value_str = '\n'.join(error_value_lines) error_value_str = '\n'.join(error_value_lines)
self.error_value = self.error_type(error_value_str) self.error_value = self.error_type(error_value_str)
def raise_new_exception(self):
# Raises the origin error if disable dygraph2static error module,
if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)):
raise
new_exception = self.create_exception()
if six.PY3:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, an new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six.exec_("raise new_exception from None")
else:
raise new_exception
...@@ -32,8 +32,7 @@ from paddle.fluid.layers.utils import flatten ...@@ -32,8 +32,7 @@ from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA from paddle.fluid.dygraph.dygraph_to_static import error
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
...@@ -315,6 +314,7 @@ class StaticLayer(object): ...@@ -315,6 +314,7 @@ class StaticLayer(object):
# 2. trace ops from dygraph layers and cache the generated program. # 2. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs) args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try: try:
concrete_program, partial_program_layer = self.get_concrete_program( concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs) *args, **kwargs)
...@@ -324,27 +324,22 @@ class StaticLayer(object): ...@@ -324,27 +324,22 @@ class StaticLayer(object):
partial_program_layer.training = self._class_instance.training partial_program_layer.training = self._class_instance.training
# 4. return outputs. # 4. return outputs.
return partial_program_layer(args) try:
return partial_program_layer(args)
except Exception as e:
if not hasattr(e, error.ERROR_DATA):
# runtime error
error.attach_error_data(e, in_runtime=True)
raise
except Exception as e: except Exception as e:
if not hasattr(e, ERROR_DATA): error_data = getattr(e, error.ERROR_DATA, None)
# runtime error
attach_error_data(e, in_runtime=True)
error_data = getattr(e, ERROR_DATA, None)
if error_data: if error_data:
new_exception = error_data.create_exception() error_data.raise_new_exception()
if six.PY3:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, an new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six.exec_("raise new_exception from None")
else:
raise new_exception
else: else:
raise logging_utils.warn(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself.".format(type(e)))
raise e
def _call_dygraph_function(self, *args, **kwargs): def _call_dygraph_function(self, *args, **kwargs):
""" """
...@@ -593,7 +588,7 @@ class ConcreteProgram(object): ...@@ -593,7 +588,7 @@ class ConcreteProgram(object):
outputs = static_func(*inputs) outputs = static_func(*inputs)
except BaseException as e: except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here. # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
attach_error_data(e) error.attach_error_data(e)
raise raise
if not isinstance(outputs, if not isinstance(outputs,
...@@ -813,28 +808,36 @@ class ProgramTranslator(object): ...@@ -813,28 +808,36 @@ class ProgramTranslator(object):
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. " "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
"We will just return dygraph output.") "We will just return dygraph output.")
return dygraph_func(*args, **kwargs) return dygraph_func(*args, **kwargs)
function_spec = FunctionSpec(dygraph_func)
cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs,
getattr(dygraph_func,
'__self__', None))
_, partial_program_layer = self._program_cache[cache_key]
if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute.
partial_program_layer.training = args[0].training
args = args[1:]
try: try:
return partial_program_layer(args) function_spec = FunctionSpec(dygraph_func)
cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs,
getattr(dygraph_func,
'__self__', None))
_, partial_program_layer = self._program_cache[cache_key]
if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute.
partial_program_layer.training = args[0].training
args = args[1:]
try:
return partial_program_layer(args)
except BaseException as e:
# NOTE:
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
# 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if not hasattr(e, error.ERROR_DATA):
# runtime error
error.attach_error_data(e, in_runtime=True)
raise
except BaseException as e: except BaseException as e:
# NOTE: error_data = getattr(e, error.ERROR_DATA, None)
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before; if error_data:
# 2. If e raised in runtime, e should be attached to ERROR_DATA here. error_data.raise_new_exception()
if not hasattr(e, ERROR_DATA): else:
# runtime error logging_utils.warn(
attach_error_data(e, in_runtime=True) "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
raise " if you can't handle this {} yourself.".format(type(e)))
raise e
def get_func(self, dygraph_func): def get_func(self, dygraph_func):
""" """
......
...@@ -14,15 +14,15 @@ ...@@ -14,15 +14,15 @@
from __future__ import print_function from __future__ import print_function
import os
import inspect import inspect
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.core import EnforceNotMet from paddle.fluid.core import EnforceNotMet
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA, ErrorData from paddle.fluid.dygraph.dygraph_to_static import error
from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap
from paddle.fluid.dygraph.jit import declarative
def inner_func(): def inner_func():
...@@ -30,7 +30,7 @@ def inner_func(): ...@@ -30,7 +30,7 @@ def inner_func():
return return
@declarative @paddle.jit.to_static
def func_error_in_compile_time(x): def func_error_in_compile_time(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
inner_func() inner_func()
...@@ -41,14 +41,14 @@ def func_error_in_compile_time(x): ...@@ -41,14 +41,14 @@ def func_error_in_compile_time(x):
return x_v return x_v
@declarative @paddle.jit.to_static
def func_error_in_compile_time_2(x): def func_error_in_compile_time_2(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
x = fluid.layers.reshape(x, shape=[1, 2]) x = fluid.layers.reshape(x, shape=[1, 2])
return x return x
@declarative @paddle.jit.to_static
def func_error_in_runtime(x, iter_num=3): def func_error_in_runtime(x, iter_num=3):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32") two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
...@@ -61,6 +61,9 @@ class TestErrorInCompileTime(unittest.TestCase): ...@@ -61,6 +61,9 @@ class TestErrorInCompileTime(unittest.TestCase):
self.set_func() self.set_func()
self.set_input() self.set_input()
self.set_exception_type() self.set_exception_type()
self.prog_trans = paddle.jit.ProgramTranslator()
self.simplify_error = 1
self.disable_error = 0
def set_func(self): def set_func(self):
self.func = func_error_in_compile_time self.func = func_error_in_compile_time
...@@ -88,14 +91,38 @@ class TestErrorInCompileTime(unittest.TestCase): ...@@ -88,14 +91,38 @@ class TestErrorInCompileTime(unittest.TestCase):
for m in self.expected_message: for m in self.expected_message:
self.assertIn(m, error_message) self.assertIn(m, error_message)
def test(self): def _test_attach_and_raise_new_exception(self, func_call):
with fluid.dygraph.guard(): paddle.disable_static()
with self.assertRaises(self.exception_type) as cm: with self.assertRaises(self.exception_type) as cm:
self.func(self.input) func_call()
exception = cm.exception exception = cm.exception
error_data = getattr(exception, ERROR_DATA)
self.assertIsInstance(error_data, ErrorData) error_data = getattr(exception, error.ERROR_DATA, None)
self._test_create_message(error_data)
self.assertIsInstance(error_data, error.ErrorData)
self._test_create_message(error_data)
def test_static_layer_call(self):
# NOTE: self.func(self.input) is the StaticLayer().__call__(self.input)
call_dy2static = lambda: self.func(self.input)
self.set_flags(0)
self._test_attach_and_raise_new_exception(call_dy2static)
def test_program_translator_get_output(self):
call_dy2static = lambda : self.prog_trans.get_output(unwrap(self.func), self.input)
self.set_flags(0)
self._test_attach_and_raise_new_exception(call_dy2static)
def set_flags(self, disable_error=0, simplify_error=1):
os.environ[error.DISABLE_ERROR_ENV_NAME] = str(disable_error)
self.disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 0))
self.assertEqual(self.disable_error, disable_error)
os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(simplify_error)
self.simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 1))
self.assertEqual(self.simplify_error, simplify_error)
class TestErrorInCompileTime2(TestErrorInCompileTime): class TestErrorInCompileTime2(TestErrorInCompileTime):
...@@ -143,5 +170,28 @@ class TestErrorInRuntime(TestErrorInCompileTime): ...@@ -143,5 +170,28 @@ class TestErrorInRuntime(TestErrorInCompileTime):
self.assertIn(m, error_message) self.assertIn(m, error_message)
@unwrap
@paddle.jit.to_static()
def func_decorated_by_other_1():
return 1
@paddle.jit.to_static()
@unwrap
def func_decorated_by_other_2():
return 1
class TestErrorInOther(unittest.TestCase):
def test(self):
paddle.disable_static()
prog_trans = paddle.jit.ProgramTranslator()
with self.assertRaises(NotImplementedError):
prog_trans.get_output(func_decorated_by_other_1)
with self.assertRaises(NotImplementedError):
func_decorated_by_other_2()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册