未验证 提交 d3c9c079 编写于 作者: W WangZhen 提交者: GitHub

Fix rollback error when non-forward function to_static (#56042)

上级 947c5fa7
......@@ -421,6 +421,13 @@ class StaticFunction:
# Note(Aurelius84): To construct new instance of StaticFunction when we
# first encouter the bound function of layer and cache it.
new_static_layer = self._clone()
if (
self._dygraph_function.__name__
not in instance._original_funcs.keys()
):
instance._original_funcs[
self._dygraph_function.__name__
] = self._dygraph_function
new_static_layer._class_instance = instance
self._descriptor_cache[instance] = new_static_layer
......@@ -581,7 +588,7 @@ class StaticFunction:
assert (
func_name in self._class_instance._original_funcs
), "Not Found function '{}' in class '{}'.".format(
func_name, self._class_instance.__name__
func_name, self._class_instance.__class__
)
func = self._class_instance._original_funcs[func_name]
setattr(
......
......@@ -123,5 +123,26 @@ class TestRollBackNet(unittest.TestCase):
)
class FuncRollback(paddle.nn.Layer):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x + 1
@paddle.jit.to_static
def func(self, x):
return x + 2
class TestRollBackNotForward(unittest.TestCase):
def test_rollback(self):
x = paddle.zeros([2, 2])
net = FuncRollback()
out = net.func(x)
net.func.rollback()
self.assertTrue(not isinstance(net.func, StaticFunction))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册