未验证 提交 ffcc1175 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Fix Error when generating train_program in eval mode (#27975)

* Fix save in eval mode

* remove assert statement

* fix test_partial_program failed

* add more test

* modify back into _train_program
上级 05fd49e9
......@@ -82,6 +82,29 @@ class NestSequence(object):
return self.tolist()[item]
class LazyInitialized(object):
"""
Descriptor to implement lazy initialization of property.
"""
def __init__(self, function):
self.function = function
def __get__(self, instance, cls):
val = self.function(instance)
setattr(instance, self.function.__name__, val)
return val
def _change_is_test_status(program, is_test):
# change all `is_test` attributes
for block in program.blocks:
for op in block.ops:
if op.has_attr('is_test'):
op._set_attr('is_test', is_test)
return program
class PartialProgramLayer(layers.Layer):
"""
PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
......@@ -109,15 +132,30 @@ class PartialProgramLayer(layers.Layer):
self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else []
main_program = self._verify_program(main_program)
self._infer_program = self._clone_for_test(main_program)
self._train_program = self._append_backward_desc(main_program)
self._set_grad_type(self._params)
self._origin_main_program = self._verify_program(main_program)
self._inner_scope = core.Scope()
# Set default mode to train
self.training = True
@LazyInitialized
def _infer_program(self):
"""
Lazy initialized property of infer_program.
"""
return self._clone_for_test(self._origin_main_program)
@LazyInitialized
def _train_program(self):
"""
Lazy initialized property of train_program.
"""
train_program = self._append_backward_desc(self._origin_main_program)
# Note: Only set grad type once after initializing train program. So we
# put it here.
self._set_grad_type(self._params, train_program)
return train_program
def _verify_program(self, main_program):
"""
Verify that the program parameter is initialized, prune some unused params,
......@@ -132,7 +170,8 @@ class PartialProgramLayer(layers.Layer):
@switch_to_static_graph
def _append_backward_desc(self, main_program):
program = main_program.clone()
# make sure all status of is_test are False in train mode.
program = _change_is_test_status(main_program.clone(), is_test=False)
targets = []
for out in self._outputs.tolist():
if isinstance(out, framework.Variable):
......@@ -280,7 +319,7 @@ class PartialProgramLayer(layers.Layer):
return out_vars
def _set_grad_type(self, params):
def _set_grad_type(self, params, train_program):
# NOTE: if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
# set param grad VarBase by forward VarBase(LoDTensor)
......@@ -289,7 +328,7 @@ class PartialProgramLayer(layers.Layer):
# be user wanted result.
for param in params:
grad_name = param.name + core.grad_var_suffix()
grad_var = self._train_program.desc.block(0).find_var(
grad_var = train_program.desc.block(0).find_var(
cpt.to_bytes(grad_name))
# NOTE: cannot find var desc maybe no problem, such as in batch_norm
if grad_var is None:
......
......@@ -51,6 +51,32 @@ class TestLstm(unittest.TestCase):
msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
static_out))
def test_save_in_eval(self):
paddle.jit.ProgramTranslator().enable(True)
net = Net(12, 2)
# switch eval mode firstly
net.eval()
net = paddle.jit.to_static(
net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
paddle.jit.save(net, 'simple_lstm')
# load saved model
load_net = paddle.jit.load('simple_lstm')
x = paddle.randn((2, 10, 12))
dygraph_out = net(x)
static_out = load_net(x)
self.assertTrue(
np.allclose(dygraph_out.numpy(), static_out.numpy()),
msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
static_out))
# switch back into train mode.
net.train()
train_out = net(x)
self.assertTrue(
np.allclose(dygraph_out.numpy(), train_out.numpy()),
msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
train_out))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册