From 4615af2ccdaf3ba967d06192af87839869d35e7e Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Tue, 9 Aug 2022 19:28:51 +0800 Subject: [PATCH] [Dy2St]Fix abnormal format of message when raise KeyError in Dy2St (#44996) * Fix abnormal format of message when raise KeyError in Dy2St * Format code * Format code * Add UT * Rename method --- .../fluid/dygraph/dygraph_to_static/error.py | 9 ++++++++- .../unittests/dygraph_to_static/test_error.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/error.py b/python/paddle/fluid/dygraph/dygraph_to_static/error.py index 3b868ade4e2..93670758dae 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/error.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/error.py @@ -143,6 +143,10 @@ class SuggestionDict(object): return self.suggestion_dict[key] +class Dy2StKeyError(Exception): + pass + + class ErrorData(object): """ Error data attached to an exception which is raised in un-transformed code. @@ -159,7 +163,10 @@ class ErrorData(object): def create_exception(self): message = self.create_message() - new_exception = self.error_type(message) + if self.error_type is KeyError: + new_exception = Dy2StKeyError(message) + else: + new_exception = self.error_type(message) setattr(new_exception, ERROR_DATA, self) return new_exception diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py index 7d980b5f75a..27d7389b903 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py @@ -443,6 +443,20 @@ class TestSuggestionErrorInRuntime(TestErrorBase): for disable_new_error in [0, 1]: self._test_raise_new_exception(disable_new_error) +@paddle.jit.to_static +def func_ker_error(x): + d = { + 'x': x + } + y = d['y'] + x + return y + +class TestKeyError(unittest.TestCase): + def test_key_error(self): + paddle.disable_static() + with self.assertRaises(error.Dy2StKeyError): + x = paddle.to_tensor([1]) + func_ker_error(x) if __name__ == '__main__': unittest.main() -- GitLab