From 3e20ddf73d627a2e63900d65815cc9e5bc800f84 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 11 Sep 2020 11:42:59 +0800 Subject: [PATCH] [Dy2Stat - Error Handling] Fix bug and optimize dy2stat error. (#27225) --- .../fluid/dygraph/dygraph_to_static/error.py | 33 ++++++- .../dygraph_to_static/program_translator.py | 85 ++++++++++--------- .../unittests/dygraph_to_static/test_error.py | 78 ++++++++++++++--- 3 files changed, 140 insertions(+), 56 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 5aba7ca0fd..be21ab6d53 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import six import sys import traceback @@ -20,6 +21,14 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginI ERROR_DATA = "Error data about original source code information and traceback." +# A flag to set whether to open the dygraph2static error reporting module +SIMPLIFY_ERROR_ENV_NAME = "TRANSLATOR_SIMPLIFY_NEW_ERROR" +DEFAULT_SIMPLIFY_NEW_ERROR = 1 + +# A flag to set whether to display the simplified error stack +DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR" +DEFAULT_DISABLE_NEW_ERROR = 0 + def attach_error_data(error, in_runtime=False): """ @@ -103,7 +112,10 @@ class ErrorData(object): # Simplify error value to improve readability if error is raised in runtime if self.in_runtime: - self._simplify_error_value() + if int( + os.getenv(SIMPLIFY_ERROR_ENV_NAME, + DEFAULT_SIMPLIFY_NEW_ERROR)): + self._simplify_error_value() message_lines.append(str(self.error_value)) return '\n'.join(message_lines) @@ -150,3 +162,22 @@ class ErrorData(object): error_value_str = '\n'.join(error_value_lines) self.error_value = self.error_type(error_value_str) + + def raise_new_exception(self): + + # Raises the origin error if disable dygraph2static error module, + if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)): + raise + + new_exception = self.create_exception() + if six.PY3: + # NOTE(liym27): + # 1. Why `raise new_exception from None`? + # In Python 3, by default, an new exception is raised with trace information of the caught exception. + # This only raises new_exception and hides unwanted implementation details from tracebacks of the + # caught exception. + # 2. Use exec to bypass syntax error checking in Python 2. + + six.exec_("raise new_exception from None") + else: + raise new_exception diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 3d27810f1d..e5fce3e6ed 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -32,8 +32,7 @@ from paddle.fluid.layers.utils import flatten from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst -from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA -from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data +from paddle.fluid.dygraph.dygraph_to_static import error from paddle.fluid.dygraph.dygraph_to_static import logging_utils from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map @@ -315,6 +314,7 @@ class StaticLayer(object): # 2. trace ops from dygraph layers and cache the generated program. args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs) + try: concrete_program, partial_program_layer = self.get_concrete_program( *args, **kwargs) @@ -324,27 +324,22 @@ class StaticLayer(object): partial_program_layer.training = self._class_instance.training # 4. return outputs. - return partial_program_layer(args) + try: + return partial_program_layer(args) + except Exception as e: + if not hasattr(e, error.ERROR_DATA): + # runtime error + error.attach_error_data(e, in_runtime=True) + raise except Exception as e: - if not hasattr(e, ERROR_DATA): - # runtime error - attach_error_data(e, in_runtime=True) - error_data = getattr(e, ERROR_DATA, None) + error_data = getattr(e, error.ERROR_DATA, None) if error_data: - new_exception = error_data.create_exception() - if six.PY3: - # NOTE(liym27): - # 1. Why `raise new_exception from None`? - # In Python 3, by default, an new exception is raised with trace information of the caught exception. - # This only raises new_exception and hides unwanted implementation details from tracebacks of the - # caught exception. - # 2. Use exec to bypass syntax error checking in Python 2. - - six.exec_("raise new_exception from None") - else: - raise new_exception + error_data.raise_new_exception() else: - raise + logging_utils.warn( + "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'" + " if you can't handle this {} yourself.".format(type(e))) + raise e def _call_dygraph_function(self, *args, **kwargs): """ @@ -593,7 +588,7 @@ class ConcreteProgram(object): outputs = static_func(*inputs) except BaseException as e: # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here. - attach_error_data(e) + error.attach_error_data(e) raise if not isinstance(outputs, @@ -813,28 +808,36 @@ class ProgramTranslator(object): "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. " "We will just return dygraph output.") return dygraph_func(*args, **kwargs) - - function_spec = FunctionSpec(dygraph_func) - cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs, - getattr(dygraph_func, - '__self__', None)) - _, partial_program_layer = self._program_cache[cache_key] - - if args and isinstance(args[0], layers.Layer): - # Synchronize self.training attribute. - partial_program_layer.training = args[0].training - args = args[1:] try: - return partial_program_layer(args) - + function_spec = FunctionSpec(dygraph_func) + cache_key = CacheKey.from_func_and_args(function_spec, args, kwargs, + getattr(dygraph_func, + '__self__', None)) + _, partial_program_layer = self._program_cache[cache_key] + + if args and isinstance(args[0], layers.Layer): + # Synchronize self.training attribute. + partial_program_layer.training = args[0].training + args = args[1:] + try: + return partial_program_layer(args) + except BaseException as e: + # NOTE: + # 1. If e is raised in compile time, e should have been attached to ERROR_DATA before; + # 2. If e raised in runtime, e should be attached to ERROR_DATA here. + if not hasattr(e, error.ERROR_DATA): + # runtime error + error.attach_error_data(e, in_runtime=True) + raise except BaseException as e: - # NOTE: - # 1. If e is raised in compile time, e should have been attached to ERROR_DATA before; - # 2. If e raised in runtime, e should be attached to ERROR_DATA here. - if not hasattr(e, ERROR_DATA): - # runtime error - attach_error_data(e, in_runtime=True) - raise + error_data = getattr(e, error.ERROR_DATA, None) + if error_data: + error_data.raise_new_exception() + else: + logging_utils.warn( + "Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'" + " if you can't handle this {} yourself.".format(type(e))) + raise e def get_func(self, dygraph_func): """ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index 586020d434..2998ba8575 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -14,15 +14,15 @@ from __future__ import print_function +import os import inspect import unittest - import numpy as np +import paddle import paddle.fluid as fluid from paddle.fluid.core import EnforceNotMet -from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA, ErrorData +from paddle.fluid.dygraph.dygraph_to_static import error from paddle.fluid.dygraph.dygraph_to_static.origin_info import unwrap -from paddle.fluid.dygraph.jit import declarative def inner_func(): @@ -30,7 +30,7 @@ def inner_func(): return -@declarative +@paddle.jit.to_static def func_error_in_compile_time(x): x = fluid.dygraph.to_variable(x) inner_func() @@ -41,14 +41,14 @@ def func_error_in_compile_time(x): return x_v -@declarative +@paddle.jit.to_static def func_error_in_compile_time_2(x): x = fluid.dygraph.to_variable(x) x = fluid.layers.reshape(x, shape=[1, 2]) return x -@declarative +@paddle.jit.to_static def func_error_in_runtime(x, iter_num=3): x = fluid.dygraph.to_variable(x) two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32") @@ -61,6 +61,9 @@ class TestErrorInCompileTime(unittest.TestCase): self.set_func() self.set_input() self.set_exception_type() + self.prog_trans = paddle.jit.ProgramTranslator() + self.simplify_error = 1 + self.disable_error = 0 def set_func(self): self.func = func_error_in_compile_time @@ -88,14 +91,38 @@ class TestErrorInCompileTime(unittest.TestCase): for m in self.expected_message: self.assertIn(m, error_message) - def test(self): - with fluid.dygraph.guard(): - with self.assertRaises(self.exception_type) as cm: - self.func(self.input) - exception = cm.exception - error_data = getattr(exception, ERROR_DATA) - self.assertIsInstance(error_data, ErrorData) - self._test_create_message(error_data) + def _test_attach_and_raise_new_exception(self, func_call): + paddle.disable_static() + with self.assertRaises(self.exception_type) as cm: + func_call() + exception = cm.exception + + error_data = getattr(exception, error.ERROR_DATA, None) + + self.assertIsInstance(error_data, error.ErrorData) + self._test_create_message(error_data) + + def test_static_layer_call(self): + # NOTE: self.func(self.input) is the StaticLayer().__call__(self.input) + call_dy2static = lambda: self.func(self.input) + + self.set_flags(0) + self._test_attach_and_raise_new_exception(call_dy2static) + + def test_program_translator_get_output(self): + call_dy2static = lambda : self.prog_trans.get_output(unwrap(self.func), self.input) + + self.set_flags(0) + self._test_attach_and_raise_new_exception(call_dy2static) + + def set_flags(self, disable_error=0, simplify_error=1): + os.environ[error.DISABLE_ERROR_ENV_NAME] = str(disable_error) + self.disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 0)) + self.assertEqual(self.disable_error, disable_error) + + os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(simplify_error) + self.simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 1)) + self.assertEqual(self.simplify_error, simplify_error) class TestErrorInCompileTime2(TestErrorInCompileTime): @@ -143,5 +170,28 @@ class TestErrorInRuntime(TestErrorInCompileTime): self.assertIn(m, error_message) +@unwrap +@paddle.jit.to_static() +def func_decorated_by_other_1(): + return 1 + + +@paddle.jit.to_static() +@unwrap +def func_decorated_by_other_2(): + return 1 + + +class TestErrorInOther(unittest.TestCase): + def test(self): + paddle.disable_static() + prog_trans = paddle.jit.ProgramTranslator() + with self.assertRaises(NotImplementedError): + prog_trans.get_output(func_decorated_by_other_1) + + with self.assertRaises(NotImplementedError): + func_decorated_by_other_2() + + if __name__ == '__main__': unittest.main() -- GitLab