diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index 1004665ca15fbc2458c1626735f161c7f4904596..feb8b0f9c9a16e2d418b12be0397ea11c890dfe7 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -82,6 +82,29 @@ class NestSequence(object):
         return self.tolist()[item]
 
 
+class LazyInitialized(object):
+    """
+    Descriptor to implement lazy initialization of a property.
+    """
+
+    def __init__(self, function):
+        self.function = function
+
+    def __get__(self, instance, cls):
+        val = self.function(instance)
+        setattr(instance, self.function.__name__, val)
+        return val
+
+
+def _change_is_test_status(program, is_test):
+    # change the `is_test` attribute of all ops
+    for block in program.blocks:
+        for op in block.ops:
+            if op.has_attr('is_test'):
+                op._set_attr('is_test', is_test)
+    return program
+
+
 class PartialProgramLayer(layers.Layer):
     """
     PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
@@ -109,15 +132,30 @@ class PartialProgramLayer(layers.Layer):
         self._outputs = NestSequence(outputs, need_check=True)
         self._params = parameters if parameters is not None else []
 
-        main_program = self._verify_program(main_program)
-        self._infer_program = self._clone_for_test(main_program)
-        self._train_program = self._append_backward_desc(main_program)
-
-        self._set_grad_type(self._params)
+        self._origin_main_program = self._verify_program(main_program)
         self._inner_scope = core.Scope()
         # Set default mode to train
         self.training = True
 
+    @LazyInitialized
+    def _infer_program(self):
+        """
+        Lazily initialized property of infer_program.
+        """
+        return self._clone_for_test(self._origin_main_program)
+
+    @LazyInitialized
+    def _train_program(self):
+        """
+        Lazily initialized property of train_program.
+        """
+        train_program = self._append_backward_desc(self._origin_main_program)
+        # NOTE: The grad type must be set exactly once, after the train
+        # program is initialized, so we do it here.
+        self._set_grad_type(self._params, train_program)
+
+        return train_program
+
     def _verify_program(self, main_program):
         """
         Verify that the program parameter is initialized, prune some unused params,
@@ -132,7 +170,8 @@ class PartialProgramLayer(layers.Layer):
 
     @switch_to_static_graph
     def _append_backward_desc(self, main_program):
-        program = main_program.clone()
+        # make sure every `is_test` attribute is False in train mode.
+        program = _change_is_test_status(main_program.clone(), is_test=False)
         targets = []
         for out in self._outputs.tolist():
             if isinstance(out, framework.Variable):
@@ -280,7 +319,7 @@ class PartialProgramLayer(layers.Layer):
 
         return out_vars
 
-    def _set_grad_type(self, params):
+    def _set_grad_type(self, params, train_program):
         # NOTE: if user set sparse gradient mode, the param's gradient
         # will be SelectedRows, not LoDTensor. But tracer will just
         # set param grad VarBase by forward VarBase(LoDTensor)
@@ -289,7 +328,7 @@ class PartialProgramLayer(layers.Layer):
         # be user wanted result.
         for param in params:
             grad_name = param.name + core.grad_var_suffix()
-            grad_var = self._train_program.desc.block(0).find_var(
+            grad_var = train_program.desc.block(0).find_var(
                 cpt.to_bytes(grad_name))
             # NOTE: cannot find var desc maybe no problem, such as in batch_norm
             if grad_var is None:
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
index e83128f045d8b94e8ec335c5dcc6ad8ca07548e4..1ed06f24bd05da9e844f74bdc6212a65d6e3fefb 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_lstm.py
@@ -51,6 +51,32 @@ class TestLstm(unittest.TestCase):
             msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
                                                                 static_out))
 
+    def test_save_in_eval(self):
+        paddle.jit.ProgramTranslator().enable(True)
+        net = Net(12, 2)
+        # switch to eval mode first
+        net.eval()
+        net = paddle.jit.to_static(
+            net, input_spec=[paddle.static.InputSpec(shape=[-1, 10, 12])])
+        paddle.jit.save(net, 'simple_lstm')
+        # load the saved model
+        load_net = paddle.jit.load('simple_lstm')
+
+        x = paddle.randn((2, 10, 12))
+        dygraph_out = net(x)
+        static_out = load_net(x)
+        self.assertTrue(
+            np.allclose(dygraph_out.numpy(), static_out.numpy()),
+            msg='dygraph_out is {}\n static_out is \n{}'.format(dygraph_out,
+                                                                static_out))
+        # switch back to train mode.
+        net.train()
+        train_out = net(x)
+        self.assertTrue(
+            np.allclose(dygraph_out.numpy(), train_out.numpy()),
+            msg='dygraph_out is {}\n train_out is \n{}'.format(dygraph_out,
+                                                               train_out))
+
 
 if __name__ == "__main__":
     unittest.main()
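
Reviewer note: the core of this patch is the `LazyInitialized` non-data
descriptor. Because it defines `__get__` but no `__set__`, the first access
computes the value and then `setattr` shadows the descriptor with a plain
instance attribute, so the wrapped function runs at most once per instance.
That is what lets a net saved in eval mode skip building the backward
(train) program entirely. Below is a minimal standalone sketch of the
pattern; the `Model` class and its printout are illustrative only, not part
of the patch:

    class LazyInitialized(object):
        """Compute a value on first access, then cache it on the instance."""

        def __init__(self, function):
            self.function = function

        def __get__(self, instance, cls):
            val = self.function(instance)
            # A non-data descriptor loses to the instance dict, so this
            # setattr makes every later lookup bypass __get__ entirely.
            setattr(instance, self.function.__name__, val)
            return val


    class Model(object):  # hypothetical consumer, for illustration only
        @LazyInitialized
        def _train_program(self):
            print('building train program (runs once)')
            return 'train-program'


    m = Model()
    m._train_program  # prints the message and caches the result
    m._train_program  # served from the instance attribute; no rebuild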