diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index be21ab6d5394ed5f89c23988a9405b57e05b56fb..350e0ad5d72f157cbeec5332de652da4153e9fce 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -79,6 +79,11 @@ class TraceBackFrame(OriginInfo): self.function_name = function_name self.source_code = source_code + def formated_message(self): + return ' File "{}", line {}, in {}\n\t{}'.format( + self.location.filepath, self.location.lineno, self.function_name, + self.source_code.lstrip()) + class ErrorData(object): """ @@ -106,7 +111,7 @@ class ErrorData(object): message_lines = [] # Step1: Adds header message to prompt users that the following is the original information. - header_message = "In user code:" + header_message = "In transformed code:" message_lines.append(header_message) message_lines.append("") diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py index 76e732d4d37f6a2056afba72649077acf16ba30e..b2f4060b106828865d5e2ffc6ce9215cd0c19503 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py @@ -69,9 +69,10 @@ class OriginInfo(object): self.location, self.source_code, self.function_name) def formated_message(self): - return ' File "{}", line {}, in {}\n\t{}'.format( + flag_for_origin_info = "(* user code *)" + return ' File "{}", line {}, in {} {}\n\t{}'.format( self.location.filepath, self.location.lineno, self.function_name, - self.source_code.lstrip()) + flag_for_origin_info, self.source_code.lstrip()) def as_frame(self): return (self.location.filepath, self.location.lineno, diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 82c3e2602869541b96104a7214c67e08e7e6dda2..31ca24e3c125421828b24e09f10de637661747ae 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -609,6 +609,9 @@ class ConcreteProgram(object): except BaseException as e: # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here. error.attach_error_data(e) + error_data = getattr(e, error.ERROR_DATA, None) + if error_data: + error_data.raise_new_exception() raise if outputs is not None: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index 82f4bd763a29e0b8b4e54ab7d66eec809cd28c89..3c43cbc518b7c4e31cf6abf8f830d045f4ac8631 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -48,83 +48,174 @@ def func_error_in_compile_time_2(x): @paddle.jit.to_static -def func_error_in_runtime(x, iter_num=3): +def func_error_in_runtime(x): x = fluid.dygraph.to_variable(x) two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32") x = fluid.layers.reshape(x, shape=[1, two]) return x -class TestErrorInCompileTime(unittest.TestCase): +@unwrap +@paddle.jit.to_static() +def func_decorated_by_other_1(): + return 1 + + +@paddle.jit.to_static() +@unwrap +def func_decorated_by_other_2(): + return 1 + + +class LayerErrorInCompiletime(fluid.dygraph.Layer): + def __init__(self, fc_size=20): + super(LayerErrorInCompiletime, self).__init__() + self._linear = fluid.dygraph.Linear(fc_size, fc_size) + + @paddle.jit.to_static( + input_spec=[paddle.static.InputSpec( + shape=[20, 20], dtype='float32')]) + def forward(self, x): + y = self._linear(x) + z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int") + out = fluid.layers.mean(y[z]) + return out + + +class TestFlags(unittest.TestCase): + def setUp(self): + self.reset_flags_to_default() + + def reset_flags_to_default(self): + # Reset flags to use defaut value + + # 1. A flag to set whether to open the dygraph2static error reporting module + os.environ[error.DISABLE_ERROR_ENV_NAME] = str( + error.DEFAULT_DISABLE_NEW_ERROR) + disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 999)) + self.assertEqual(disable_error, 0) + + # 2. A flag to set whether to display the simplified error stack + os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str( + error.DEFAULT_SIMPLIFY_NEW_ERROR) + simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 999)) + self.assertEqual(simplify_error, 1) + + def _test_set_flag(self, flag_name, set_value): + os.environ[flag_name] = str(set_value) + new_value = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 999)) + self.assertEqual(new_value, set_value) + + def test_translator_disable_new_error(self): + self._test_set_flag(error.DISABLE_ERROR_ENV_NAME, 1) + + def test_translator_simplify_new_error(self): + self._test_set_flag(error.SIMPLIFY_ERROR_ENV_NAME, 0) + + +class TestErrorBase(unittest.TestCase): def setUp(self): - self.set_func() self.set_input() + self.set_func() + self.set_func_call() + self.filepath = inspect.getfile(unwrap(self.func_call)) self.set_exception_type() + self.set_message() self.prog_trans = paddle.jit.ProgramTranslator() - self.simplify_error = 1 - self.disable_error = 0 - - def set_func(self): - self.func = func_error_in_compile_time - - def set_exception_type(self): - self.exception_type = TypeError def set_input(self): self.input = np.ones([3, 2]) - def set_message(self): - self.expected_message = \ - ['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath), - 'inner_func()', - 'File "{}", line 28, in inner_func'.format(self.filepath), - 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', - ] + def set_func(self): + raise NotImplementedError("Error test should implement set_func") - def _test_create_message(self, error_data): - self.filepath = inspect.getfile(unwrap(self.func)) - self.set_message() - error_message = error_data.create_message() + def set_func_call(self): + raise NotImplementedError("Error test should implement set_func_call") - self.assertIn('In user code:', error_message) - for m in self.expected_message: - self.assertIn(m, error_message) + def set_exception_type(self): + raise NotImplementedError( + "Error test should implement set_exception_type") - def _test_attach_and_raise_new_exception(self, func_call): + def set_message(self): + raise NotImplementedError("Error test should implement set_message") + + def reset_flags_to_default(self): + os.environ[error.DISABLE_ERROR_ENV_NAME] = str( + error.DEFAULT_DISABLE_NEW_ERROR) + os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str( + error.DEFAULT_SIMPLIFY_NEW_ERROR) + + def disable_new_error(self): + os.environ[error.DISABLE_ERROR_ENV_NAME] = str( + 1 - error.DEFAULT_DISABLE_NEW_ERROR) + + def _test_new_error_message(self, new_exception, disable_new_error=0): + error_message = str(new_exception) + + if disable_new_error: + # If disable new error, 'In user code:' should not in error_message. + self.assertNotIn('In transformed code:', error_message) + else: + # 1. 'In user code:' must be in error_message because it indicates that + # this is an optimized error message + self.assertIn('In transformed code:', error_message) + + # 2. Check whether the converted static graph code is mapped to the dygraph code. + for m in self.expected_message: + self.assertIn(m, error_message) + + def _test_raise_new_exception(self, disable_new_error=0): paddle.disable_static() - with self.assertRaises(self.exception_type) as cm: - func_call() - exception = cm.exception - error_data = getattr(exception, error.ERROR_DATA, None) + if disable_new_error: + self.disable_new_error() + else: + self.reset_flags_to_default() + + # 1. Check whether the new exception type is the same as the old one + with self.assertRaises(self.exception_type) as new_cm: + self.func_call() + + new_exception = new_cm.exception + # 2. Check whether the new_exception is attached ErrorData to indicate that this is a new exception + error_data = getattr(new_exception, error.ERROR_DATA, None) self.assertIsInstance(error_data, error.ErrorData) - self._test_create_message(error_data) - def test_static_layer_call(self): - # NOTE: self.func(self.input) is the StaticLayer().__call__(self.input) - call_dy2static = lambda: self.func(self.input) + # 3. Check whether the error message is optimized + self._test_new_error_message(new_exception, disable_new_error) + - self.set_flags(0) - self._test_attach_and_raise_new_exception(call_dy2static) +# Situation 1: Call StaticLayer.__call__ to use Dynamic-to-Static +class TestErrorStaticLayerCallInCompiletime(TestErrorBase): + def set_func(self): + self.func = func_error_in_compile_time + + def set_input(self): + self.input = np.ones([3, 2]) - def test_program_translator_get_output(self): - call_dy2static = lambda : self.prog_trans.get_output(unwrap(self.func), self.input) + def set_exception_type(self): + self.exception_type = TypeError - self.set_flags(0) - self._test_attach_and_raise_new_exception(call_dy2static) + def set_message(self): + self.expected_message = \ + ['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath), + 'inner_func()', + 'File "{}", line 28, in inner_func'.format(self.filepath), + 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + ] - def set_flags(self, disable_error=0, simplify_error=1): - os.environ[error.DISABLE_ERROR_ENV_NAME] = str(disable_error) - self.disable_error = int(os.getenv(error.DISABLE_ERROR_ENV_NAME, 0)) - self.assertEqual(self.disable_error, disable_error) + def set_func_call(self): + # NOTE: self.func(self.input) is the StaticLayer().__call__(self.input) + self.func_call = lambda: self.func(self.input) - os.environ[error.SIMPLIFY_ERROR_ENV_NAME] = str(simplify_error) - self.simplify_error = int(os.getenv(error.SIMPLIFY_ERROR_ENV_NAME, 1)) - self.assertEqual(self.simplify_error, simplify_error) + def test_error(self): + for disable_new_error in [0, 1]: + self._test_raise_new_exception(disable_new_error) -class TestErrorInCompileTime2(TestErrorInCompileTime): +class TestErrorStaticLayerCallInCompiletime_2( + TestErrorStaticLayerCallInCompiletime): def set_func(self): self.func = func_error_in_compile_time_2 @@ -132,7 +223,6 @@ class TestErrorInCompileTime2(TestErrorInCompileTime): self.exception_type = ValueError def set_message(self): - self.expected_message = \ [ 'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath), @@ -140,7 +230,7 @@ class TestErrorInCompileTime2(TestErrorInCompileTime): ] -class TestErrorInRuntime(TestErrorInCompileTime): +class TestErrorStaticLayerCallInRuntime(TestErrorStaticLayerCallInCompiletime): def set_func(self): self.func = func_error_in_runtime @@ -154,33 +244,50 @@ class TestErrorInRuntime(TestErrorInCompileTime): 'x = fluid.layers.reshape(x, shape=[1, two])' ] - def _test_create_message(self, error_data): - self.filepath = inspect.getfile(unwrap(self.func)) - self.set_message() - with self.assertRaises(ValueError): - error_data.create_message() +# Situation 2: Call ProgramTranslator().get_output(...) to use Dynamic-to-Static +class TestErrorGetOutputInCompiletime(TestErrorStaticLayerCallInCompiletime): + def set_func_call(self): + self.func_call = lambda : self.prog_trans.get_output(unwrap(self.func), self.input) - error_data.in_runtime = False - error_message = error_data.create_message() - self.assertIn('In user code:', error_message) - for m in self.expected_message: - self.assertIn(m, error_message) +class TestErrorGetOutputInCompiletime_2( + TestErrorStaticLayerCallInCompiletime_2): + def set_func_call(self): + self.func_call = lambda : self.prog_trans.get_output(unwrap(self.func), self.input) -@unwrap -@paddle.jit.to_static() -def func_decorated_by_other_1(): - return 1 +class TestErrorGetOutputInRuntime(TestErrorStaticLayerCallInRuntime): + def set_func_call(self): + self.func_call = lambda : self.prog_trans.get_output(unwrap(self.func), self.input) -@paddle.jit.to_static() -@unwrap -def func_decorated_by_other_2(): - return 1 +class TestJitSaveInCompiletime(TestErrorBase): + def setUp(self): + self.reset_flags_to_default() + self.set_func_call() + self.filepath = inspect.getfile(unwrap(self.func_call)) + self.set_exception_type() + self.set_message() + + def set_exception_type(self): + self.exception_type = TypeError + + def set_message(self): + self.expected_message = \ + ['File "{}", line 80, in forward'.format(self.filepath), + 'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")', + ] + + def set_func_call(self): + layer = LayerErrorInCompiletime() + self.func_call = lambda : paddle.jit.save(layer, path="./test_dy2stat_error/model") + + def test_error(self): + self._test_raise_new_exception() +# Situation 4: NotImplementedError class TestErrorInOther(unittest.TestCase): def test(self): paddle.disable_static()