未验证 提交 2500dca8 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Fix bug in convert_call (#28368)

* Fix bug in convert_call

* refine unittest

* refine code

* refine code

* fix unittest failed

* add assert
上级 ca415414
...@@ -142,7 +142,7 @@ def convert_call(func): ...@@ -142,7 +142,7 @@ def convert_call(func):
# Note(Aurelius84): Because `@declarative` returns a class instance instead of # Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`. # a function. This will modify the value referring to itself in `__globals__`.
# For example: # For example:
# #
# @declarative # @declarative
# def foo(x): # def foo(x):
...@@ -150,7 +150,7 @@ def convert_call(func): ...@@ -150,7 +150,7 @@ def convert_call(func):
# #
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`. # `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of # And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here. # `foo` function. So `isinstance(fn, StaticFunction)` is added here.
global_functions = set() global_functions = set()
for fn in func.__globals__.values(): for fn in func.__globals__.values():
if inspect.isfunction(fn): if inspect.isfunction(fn):
...@@ -193,8 +193,10 @@ def convert_call(func): ...@@ -193,8 +193,10 @@ def convert_call(func):
try: try:
_, forward_func = unwrap_decorators(func.forward) _, forward_func = unwrap_decorators(func.forward)
forward_func = convert_to_static(forward_func) forward_func = convert_to_static(forward_func)
setattr(func, 'forward', forward_func) # Bound mothod will be convert into plain function after `convert_to_static`.
func_self = func # So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func))
except Exception: except Exception:
# NOTE: func.forward may have been decorated. # NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self func_self = None if func_self else func_self
......
...@@ -18,14 +18,24 @@ import unittest ...@@ -18,14 +18,24 @@ import unittest
from paddle import nn from paddle import nn
class LSTMLayer(nn.Layer):
def __init__(self, in_channels, hidden_size):
super(LSTMLayer, self).__init__()
self.cell = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2)
def forward(self, x):
x, _ = self.cell(x)
return x
class Net(nn.Layer): class Net(nn.Layer):
def __init__(self, in_channels, hidden_size): def __init__(self, in_channels, hidden_size):
super(Net, self).__init__() super(Net, self).__init__()
self.lstm = nn.LSTM( self.lstm = LSTMLayer(in_channels, hidden_size)
in_channels, hidden_size, direction='bidirectional', num_layers=2)
def forward(self, x): def forward(self, x):
x, _ = self.lstm(x) x = self.lstm(x)
return x return x
...@@ -115,5 +125,21 @@ class TestSaveInEvalMode(unittest.TestCase): ...@@ -115,5 +125,21 @@ class TestSaveInEvalMode(unittest.TestCase):
infer_out)) infer_out))
class TestEvalAfterSave(unittest.TestCase):
def test_eval_after_save(self):
x = paddle.randn((2, 10, 12)).astype('float32')
net = Net(12, 2)
dy_out = net(x)
# save model
paddle.jit.save(net, 'jit.save/lstm', input_spec=[x])
load_net = paddle.jit.load('jit.save/lstm')
load_out = load_net(x)
self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy()))
# eval
net.eval()
out = net(x)
self.assertTrue(np.allclose(dy_out.numpy(), out.numpy()))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册