未验证 提交 d6aee759 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Set buff.persistable=False when it's not initialized (#28749)

上级 1a532d51
...@@ -79,8 +79,12 @@ def param_guard(parameters): ...@@ -79,8 +79,12 @@ def param_guard(parameters):
# `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
# and necessary for inferring. It will be pruned if it's not necessary for inferring. # and necessary for inferring. It will be pruned if it's not necessary for inferring.
else: else:
# But if its shape is empty while created from `create_variable()`, we consider this buffer
# non-persistable. See case of `drop_state` in lstm api.
is_persistable = len(var_base.shape) > 0
new_var = var_base._to_static_var( new_var = var_base._to_static_var(
to_parameter=False, persistable=True) to_parameter=False, persistable=is_persistable)
parameters[name] = new_var parameters[name] = new_var
yield yield
parameters.update(origin_parameters) parameters.update(origin_parameters)
......
...@@ -61,25 +61,26 @@ class TestLstm(unittest.TestCase): ...@@ -61,25 +61,26 @@ class TestLstm(unittest.TestCase):
msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out, msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
static_out)) static_out))
def test_save_in_eval(self): def test_save_in_eval(self, with_training=True):
paddle.jit.ProgramTranslator().enable(True) paddle.jit.ProgramTranslator().enable(True)
net = Net(12, 2) net = Net(12, 2)
x = paddle.randn((2, 10, 12)) x = paddle.randn((2, 10, 12))
x.stop_gradient = False if with_training:
dygraph_out = net(x) x.stop_gradient = False
loss = paddle.mean(dygraph_out) dygraph_out = net(x)
sgd = paddle.optimizer.SGD(learning_rate=0.001, loss = paddle.mean(dygraph_out)
parameters=net.parameters()) sgd = paddle.optimizer.SGD(learning_rate=0.001,
loss.backward() parameters=net.parameters())
sgd.step() loss.backward()
sgd.step()
# switch eval mode firstly # switch eval mode firstly
net.eval() net.eval()
x = paddle.randn((2, 10, 12)) x = paddle.randn((2, 10, 12))
dygraph_out = net(x)
dropout_out = net(x)
net = paddle.jit.to_static( net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])]) net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
paddle.jit.save(net, 'simple_lstm') paddle.jit.save(net, 'simple_lstm')
dygraph_out = net(x)
# load saved model # load saved model
load_net = paddle.jit.load('simple_lstm') load_net = paddle.jit.load('simple_lstm')
...@@ -96,6 +97,9 @@ class TestLstm(unittest.TestCase): ...@@ -96,6 +97,9 @@ class TestLstm(unittest.TestCase):
msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out, msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
train_out)) train_out))
def test_save_without_training(self):
    """Cover jit.save on a net that was never trained.

    Delegates to test_save_in_eval with with_training=False so the
    save/load path is exercised on a freshly-constructed LSTM whose
    buffers (e.g. `drop_state`) were never initialized by a forward
    pass — the scenario this commit's persistable=False fix targets.
    """
    self.test_save_in_eval(with_training=False)
class LinearNet(nn.Layer): class LinearNet(nn.Layer):
def __init__(self): def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册