From c3d4a3d8fec0fdb4b99c643be1e2cef1804bcd52 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Tue, 2 Aug 2022 14:10:36 +0800 Subject: [PATCH] [Dy2St]Raise TypeError when call to_static to convert a method of a common class (#44781) * Fix to_static error when call to_static to convert a method of a common class * raise typerror when class no inherits from layer * Fix @to_static --- .../dygraph_to_static/program_translator.py | 8 ++++++++ .../dygraph_to_static/test_declarative.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 43ce1fae16f..77b55f35e2e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -257,6 +257,10 @@ class StaticFunction(object): self._dygraph_function = getattr(function, '__func__') self._class_instance = getattr(function, '__self__') + if not hasattr(self._class_instance, '_original_funcs'): + raise TypeError( + "When using 'to_static' to convert method of a class, " + "please ensure the class inherits from nn.Layer") self._class_instance._original_funcs[ function.__name__] = self._dygraph_function else: @@ -406,6 +410,10 @@ class StaticFunction(object): def _is_train_mode(self): if self._class_instance is not None: + if not hasattr(self._class_instance, 'training'): + raise TypeError( + "When using 'to_static' to convert method of a class, " + "please ensure the class inherits from nn.Layer") return self._class_instance.training else: return self._training diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index ef9eff26518..46c847938c6 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -478,5 +478,21 @@ class TestSetBuffers(unittest.TestCase): paddle.enable_static() +class ClassNoInheritLayer: + + def func(self, x): + return x + 1 + + +class TestClassNoInheritLayer(unittest.TestCase): + + def test_to_static(self): + paddle.disable_static() + net = ClassNoInheritLayer() + input_spec = [paddle.static.InputSpec(name='x', shape=[1])] + with self.assertRaises(TypeError): + static_func = paddle.jit.to_static(net.func, input_spec=input_spec) + + if __name__ == '__main__': unittest.main() -- GitLab