From d6aee7597cc3c94adf897991860fef9744047c03 Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Fri, 20 Nov 2020 18:57:13 +0800
Subject: [PATCH] [Dy2Stat]Set buff.persistable=False when it's not
 initialized (#28749)

---
 python/paddle/fluid/dygraph/base.py           |  6 ++++-
 .../unittests/dygraph_to_static/test_lstm.py  | 24 +++++++++++--------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py
index 5f0d8e08982..a26b903493a 100644
--- a/python/paddle/fluid/dygraph/base.py
+++ b/python/paddle/fluid/dygraph/base.py
@@ -79,8 +79,12 @@ def param_guard(parameters):
         # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
         # and necessary for inferring. It will be pruned if it's not necessary for inferring.
         else:
+            # But if its shape is empty while created from `create_variable()`, we treat
+            # this buffer as non-persistable. See the case of `drop_state` in the lstm API.
+            is_persistable = len(var_base.shape) > 0
+
             new_var = var_base._to_static_var(
-                to_parameter=False, persistable=True)
+                to_parameter=False, persistable=is_persistable)
             parameters[name] = new_var
     yield
     parameters.update(origin_parameters)
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
index cab858f0480..cce2a383dd8 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
@@ -61,25 +61,26 @@ class TestLstm(unittest.TestCase):
             msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
                                                                 static_out))
 
-    def test_save_in_eval(self):
+    def test_save_in_eval(self, with_training=True):
         paddle.jit.ProgramTranslator().enable(True)
         net = Net(12, 2)
         x = paddle.randn((2, 10, 12))
-        x.stop_gradient = False
-        dygraph_out = net(x)
-        loss = paddle.mean(dygraph_out)
-        sgd = paddle.optimizer.SGD(learning_rate=0.001,
-                                   parameters=net.parameters())
-        loss.backward()
-        sgd.step()
+        if with_training:
+            x.stop_gradient = False
+            dygraph_out = net(x)
+            loss = paddle.mean(dygraph_out)
+            sgd = paddle.optimizer.SGD(learning_rate=0.001,
+                                       parameters=net.parameters())
+            loss.backward()
+            sgd.step()
 
         # switch eval mode firstly
         net.eval()
         x = paddle.randn((2, 10, 12))
-        dygraph_out = net(x)
-        dropout_out = net(x)
         net = paddle.jit.to_static(
             net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
         paddle.jit.save(net, 'simple_lstm')
+
+        dygraph_out = net(x)
         # load saved model
         load_net = paddle.jit.load('simple_lstm')
@@ -96,6 +97,9 @@ class TestLstm(unittest.TestCase):
             msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
                                                                 train_out))
 
+    def test_save_without_training(self):
+        self.test_save_in_eval(with_training=False)
+
 
 class LinearNet(nn.Layer):
     def __init__(self):
-- 
GitLab
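
For context, the scenario this patch enables can be reproduced with a short
standalone script. The sketch below is not part of the patch: `SimpleLSTMNet`
is a hypothetical stand-in for the `Net` class defined in test_lstm.py, and
the comments restate the rationale from the new code comment in base.py (a
buffer created via `create_variable()` that has never been initialized, such
as the LSTM's `drop_state`, still has an empty shape and should not be marked
persistable).

    import paddle
    import paddle.nn as nn


    class SimpleLSTMNet(nn.Layer):
        # Hypothetical stand-in for the `Net` class used in test_lstm.py.
        def __init__(self, in_size, hidden_size):
            super(SimpleLSTMNet, self).__init__()
            self.lstm = nn.LSTM(in_size, hidden_size)

        def forward(self, x):
            out, _ = self.lstm(x)
            return out


    net = SimpleLSTMNet(12, 2)
    # No forward pass or training step is run here, so internal buffers such
    # as the LSTM's `drop_state` remain uninitialized (empty shape). With this
    # patch, param_guard() marks such buffers non-persistable, which is what
    # the new test_save_without_training case exercises.
    net.eval()
    net = paddle.jit.to_static(
        net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
    paddle.jit.save(net, 'simple_lstm')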