diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index 57c36e80fda88314926319a11200987428b387df..9654a23852024b8c31009eb68b7ea2b8cad864a8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -142,7 +142,7 @@ def convert_call(func): # Note(Aurelius84): Because `@declarative` returns a class instance instead of # a function. This will modify the value referring to itself in `__globals__`. - # For example: + # For example: # # @declarative # def foo(x): @@ -150,7 +150,7 @@ def convert_call(func): # # `foo` will be converted into a wrapper class, suppose as `StaticFunction`. # And `foo.__globals__['foo']` will still return this `StaticFunction` instead of - # `foo` function. So `isinstance(fn, StaticFunction)` is added here. + # `foo` function. So `isinstance(fn, StaticFunction)` is added here. global_functions = set() for fn in func.__globals__.values(): if inspect.isfunction(fn): @@ -193,8 +193,10 @@ def convert_call(func): try: _, forward_func = unwrap_decorators(func.forward) forward_func = convert_to_static(forward_func) - setattr(func, 'forward', forward_func) - func_self = func + # Bound mothod will be convert into plain function after `convert_to_static`. + # So descriptor mechanism is used to bound `self` instance on function to + # keep it as bound method. + setattr(func, 'forward', forward_func.__get__(func)) except Exception: # NOTE: func.forward may have been decorated. func_self = None if func_self else func_self diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py index 279c44d3245ea04b8c0593622b4ad0742ebfcfc2..cfb4bb69a2ea57312e007aa7a7df2f07d8daf3d5 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py @@ -18,14 +18,24 @@ import unittest from paddle import nn +class LSTMLayer(nn.Layer): + def __init__(self, in_channels, hidden_size): + super(LSTMLayer, self).__init__() + self.cell = nn.LSTM( + in_channels, hidden_size, direction='bidirectional', num_layers=2) + + def forward(self, x): + x, _ = self.cell(x) + return x + + class Net(nn.Layer): def __init__(self, in_channels, hidden_size): super(Net, self).__init__() - self.lstm = nn.LSTM( - in_channels, hidden_size, direction='bidirectional', num_layers=2) + self.lstm = LSTMLayer(in_channels, hidden_size) def forward(self, x): - x, _ = self.lstm(x) + x = self.lstm(x) return x @@ -115,5 +125,21 @@ class TestSaveInEvalMode(unittest.TestCase): infer_out)) +class TestEvalAfterSave(unittest.TestCase): + def test_eval_after_save(self): + x = paddle.randn((2, 10, 12)).astype('float32') + net = Net(12, 2) + dy_out = net(x) + # save model + paddle.jit.save(net, 'jit.save/lstm', input_spec=[x]) + load_net = paddle.jit.load('jit.save/lstm') + load_out = load_net(x) + self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy())) + # eval + net.eval() + out = net(x) + self.assertTrue(np.allclose(dy_out.numpy(), out.numpy())) + + if __name__ == "__main__": unittest.main()