Unverified commit d6aee759, authored by Aurelius84, committed by GitHub

[Dy2Stat]Set buff.persistable=False when it's not initialized (#28749)

Parent 1a532d51
@@ -79,8 +79,12 @@ def param_guard(parameters):
             # `mask` Tensor or `hidden_0` in RNN layers, which is equivalent to a Parameter
             # and necessary for inferring. It will be pruned if it's not necessary for inferring.
             else:
+                # But if its shape is empty while created from `create_variable()`, we consider this buffer
+                # non-persistable. See case of `drop_state` in lstm api.
+                is_persistable = len(var_base.shape) > 0
                 new_var = var_base._to_static_var(
-                    to_parameter=False, persistable=True)
+                    to_parameter=False, persistable=is_persistable)
             parameters[name] = new_var
         yield
         parameters.update(origin_parameters)
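The rule added above keys off the buffer's shape alone: a buffer created through `create_variable()` without data keeps an empty shape until a forward pass fills it, while an initialized buffer such as `mask` or `hidden_0` always carries a concrete shape. A minimal standalone sketch of that rule (the helper name `buffer_is_persistable` is ours, not Paddle's):

def buffer_is_persistable(shape):
    # Mirrors `len(var_base.shape) > 0` from the patch: an empty shape marks
    # an uninitialized buffer, which is safe to prune from the inference program.
    return len(shape) > 0

print(buffer_is_persistable([]))       # False, e.g. `drop_state` before any run
print(buffer_is_persistable([2, 10]))  # True, e.g. an initialized `mask` buffer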
@@ -61,25 +61,26 @@ class TestLstm(unittest.TestCase):
             msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
                                                                 static_out))
 
-    def test_save_in_eval(self):
+    def test_save_in_eval(self, with_training=True):
         paddle.jit.ProgramTranslator().enable(True)
         net = Net(12, 2)
         x = paddle.randn((2, 10, 12))
-        x.stop_gradient = False
-        dygraph_out = net(x)
-        loss = paddle.mean(dygraph_out)
-        sgd = paddle.optimizer.SGD(learning_rate=0.001,
-                                   parameters=net.parameters())
-        loss.backward()
-        sgd.step()
+        if with_training:
+            x.stop_gradient = False
+            dygraph_out = net(x)
+            loss = paddle.mean(dygraph_out)
+            sgd = paddle.optimizer.SGD(learning_rate=0.001,
+                                       parameters=net.parameters())
+            loss.backward()
+            sgd.step()
         # switch eval mode firstly
         net.eval()
         x = paddle.randn((2, 10, 12))
         dygraph_out = net(x)
         dropout_out = net(x)
         net = paddle.jit.to_static(
             net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
         paddle.jit.save(net, 'simple_lstm')
         dygraph_out = net(x)
         # load saved model
         load_net = paddle.jit.load('simple_lstm')
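The reworked test makes the training step optional so the same body can cover the bug this commit fixes: exporting an LSTM model that has never run a forward or backward pass, where `drop_state` is still an uninitialized buffer when `param_guard` converts it. A hedged standalone repro sketch (assuming `Net` wraps `paddle.nn.LSTM`, as the `(12, 2)` sizes suggest):

import paddle

class SimpleLSTM(paddle.nn.Layer):
    def __init__(self, input_size, hidden_size):
        super(SimpleLSTM, self).__init__()
        self.lstm = paddle.nn.LSTM(input_size, hidden_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        return out

net = SimpleLSTM(12, 2)
net.eval()  # no training step, so `drop_state` stays uninitialized
net = paddle.jit.to_static(
    net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
# The export path exercised by the new test_save_without_training case.
paddle.jit.save(net, 'simple_lstm')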
@@ -96,6 +97,9 @@ class TestLstm(unittest.TestCase):
             msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
                                                                 train_out))
 
+    def test_save_without_training(self):
+        self.test_save_in_eval(with_training=False)
+
 
 class LinearNet(nn.Layer):
     def __init__(self):
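One detail worth noting in the new test: unittest invokes test methods with no arguments, so the extra `with_training` parameter must carry a default for `test_save_in_eval` to remain runnable under discovery on its own. A sketch of how both paths execute (assumes it runs alongside `TestLstm` in the same test module):

import unittest

suite = unittest.TestSuite()
suite.addTest(TestLstm('test_save_in_eval'))           # default with_training=True
suite.addTest(TestLstm('test_save_without_training'))  # delegates with with_training=False
unittest.TextTestRunner().run(suite)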