未验证 提交 c3d4a3d8 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Raise TypeError when call to_static to convert a method of a common class (#44781)

* Fix to_static error when call to_static to convert a method of a common class

* raise typerror when class no inherits from layer

* Fix @to_static
上级 acfdb8b3
......@@ -257,6 +257,10 @@ class StaticFunction(object):
self._dygraph_function = getattr(function, '__func__')
self._class_instance = getattr(function, '__self__')
if not hasattr(self._class_instance, '_original_funcs'):
raise TypeError(
"When using 'to_static' to convert method of a class, "
"please ensure the class inherits from nn.Layer")
self._class_instance._original_funcs[
function.__name__] = self._dygraph_function
else:
......@@ -406,6 +410,10 @@ class StaticFunction(object):
def _is_train_mode(self):
if self._class_instance is not None:
if not hasattr(self._class_instance, 'training'):
raise TypeError(
"When using 'to_static' to convert method of a class, "
"please ensure the class inherits from nn.Layer")
return self._class_instance.training
else:
return self._training
......
......@@ -478,5 +478,21 @@ class TestSetBuffers(unittest.TestCase):
paddle.enable_static()
class ClassNoInheritLayer:
def func(self, x):
return x + 1
class TestClassNoInheritLayer(unittest.TestCase):
def test_to_static(self):
paddle.disable_static()
net = ClassNoInheritLayer()
input_spec = [paddle.static.InputSpec(name='x', shape=[1])]
with self.assertRaises(TypeError):
static_func = paddle.jit.to_static(net.func, input_spec=input_spec)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册