diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 43ce1fae16fc2f50a5f7784e3882b7ac0bf19fd9..77b55f35e2eb2865b292217ce544c7fce3b72511 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -257,6 +257,10 @@ class StaticFunction(object): self._dygraph_function = getattr(function, '__func__') self._class_instance = getattr(function, '__self__') + if not hasattr(self._class_instance, '_original_funcs'): + raise TypeError( + "When using 'to_static' to convert method of a class, " + "please ensure the class inherits from nn.Layer") self._class_instance._original_funcs[ function.__name__] = self._dygraph_function else: @@ -406,6 +410,10 @@ class StaticFunction(object): def _is_train_mode(self): if self._class_instance is not None: + if not hasattr(self._class_instance, 'training'): + raise TypeError( + "When using 'to_static' to convert method of a class, " + "please ensure the class inherits from nn.Layer") return self._class_instance.training else: return self._training diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index ef9eff2651853d99c86e03e38893e3db6961f8ee..46c847938c6ad8129e0361c0cb53133b74c3f386 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -478,5 +478,21 @@ class TestSetBuffers(unittest.TestCase): paddle.enable_static() +class ClassNoInheritLayer: + + def func(self, x): + return x + 1 + + +class TestClassNoInheritLayer(unittest.TestCase): + + def test_to_static(self): + paddle.disable_static() + net = ClassNoInheritLayer() + input_spec = [paddle.static.InputSpec(name='x', shape=[1])] + with self.assertRaises(TypeError): + static_func = paddle.jit.to_static(net.func, input_spec=input_spec) + + if __name__ == '__main__': unittest.main()