From 2500dca87843177ab401f295323f9ff7307eed61 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 5 Nov 2020 10:05:46 +0800 Subject: [PATCH] [Dy2Stat] Fix bug in convert_call (#28368) * Fix bug in convert_call * refine unittest * refine code * refine code * fix unittest failed * add assert --- .../dygraph_to_static/convert_call_func.py | 10 +++--- .../unittests/dygraph_to_static/test_lstm.py | 32 +++++++++++++++++-- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index 57c36e80fd..9654a23852 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -142,7 +142,7 @@ def convert_call(func): # Note(Aurelius84): Because `@declarative` returns a class instance instead of # a function. This will modify the value referring to itself in `__globals__`. - # For example: + # For example: # # @declarative # def foo(x): @@ -150,7 +150,7 @@ def convert_call(func): # # `foo` will be converted into a wrapper class, suppose as `StaticFunction`. # And `foo.__globals__['foo']` will still return this `StaticFunction` instead of - # `foo` function. So `isinstance(fn, StaticFunction)` is added here. + # `foo` function. So `isinstance(fn, StaticFunction)` is added here. global_functions = set() for fn in func.__globals__.values(): if inspect.isfunction(fn): @@ -193,8 +193,10 @@ def convert_call(func): try: _, forward_func = unwrap_decorators(func.forward) forward_func = convert_to_static(forward_func) - setattr(func, 'forward', forward_func) - func_self = func + # Bound mothod will be convert into plain function after `convert_to_static`. + # So descriptor mechanism is used to bound `self` instance on function to + # keep it as bound method. + setattr(func, 'forward', forward_func.__get__(func)) except Exception: # NOTE: func.forward may have been decorated. func_self = None if func_self else func_self diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py index 279c44d324..cfb4bb69a2 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py @@ -18,14 +18,24 @@ import unittest from paddle import nn +class LSTMLayer(nn.Layer): + def __init__(self, in_channels, hidden_size): + super(LSTMLayer, self).__init__() + self.cell = nn.LSTM( + in_channels, hidden_size, direction='bidirectional', num_layers=2) + + def forward(self, x): + x, _ = self.cell(x) + return x + + class Net(nn.Layer): def __init__(self, in_channels, hidden_size): super(Net, self).__init__() - self.lstm = nn.LSTM( - in_channels, hidden_size, direction='bidirectional', num_layers=2) + self.lstm = LSTMLayer(in_channels, hidden_size) def forward(self, x): - x, _ = self.lstm(x) + x = self.lstm(x) return x @@ -115,5 +125,21 @@ class TestSaveInEvalMode(unittest.TestCase): infer_out)) +class TestEvalAfterSave(unittest.TestCase): + def test_eval_after_save(self): + x = paddle.randn((2, 10, 12)).astype('float32') + net = Net(12, 2) + dy_out = net(x) + # save model + paddle.jit.save(net, 'jit.save/lstm', input_spec=[x]) + load_net = paddle.jit.load('jit.save/lstm') + load_out = load_net(x) + self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy())) + # eval + net.eval() + out = net(x) + self.assertTrue(np.allclose(dy_out.numpy(), out.numpy())) + + if __name__ == "__main__": unittest.main() -- GitLab