diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index 0a9e66a5bb0b1b7f15928891f8eefcbc67ebffb5..05fce7bf837664eafd89319eb6cdd973b745605f 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -112,14 +112,14 @@ class PartialProgramLayer(layers.Layer):
         self._outputs = NestSequence(outputs, need_check=True)
         self._params = parameters if parameters is not None else []
 
-        self._infer_program = self._verify_program(main_program)
-        self._train_program = self._append_backward_desc()
-        # Switch infer or train by train() and eval()
-        self._trace_program = None
+        main_program = self._verify_program(main_program)
+        self._infer_program = self._clone_for_test(main_program)
+        self._train_program = self._append_backward_desc(main_program)
+
         self._set_grad_type(self._params)
         self._inner_scope = core.Scope()
         # Set default mode to train
-        self.train()
+        self.training = True
 
     def _verify_program(self, main_program):
         """
@@ -136,8 +136,8 @@
         return main_program
 
     @switch_to_static_graph
-    def _append_backward_desc(self):
-        program = self._infer_program.clone()
+    def _append_backward_desc(self, main_program):
+        program = main_program.clone()
         targets = []
         for out in self._outputs.tolist():
             if isinstance(out, framework.Variable):
@@ -165,15 +165,6 @@
 
         self._params = required_params
 
-    def train(self):
-        # self.training is inherited from layers.Layer
-        self.training = True
-        self._trace_program = self._train_program
-
-    def eval(self):
-        self.training = False
-        self._trace_program = self._infer_program
-
     def forward(self, inputs):
         in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
 
@@ -186,7 +177,7 @@
             outputs={'Out': valid_vars(out_vars),
                      'OutScope': tmp_scope_vec},
             attrs={
-                'global_block': self._trace_program.desc.block(0),
+                'global_block': self.program.desc.block(0),
                 'start_op_index': 0,
                 'end_op_index': self._infer_program.desc.block(0).op_size(),
                 'is_test': not self.training
@@ -195,6 +186,10 @@
         restored_nest_out = self._restore_out(out_vars)
         return self._remove_no_value(restored_nest_out)
 
+    @property
+    def program(self):
+        return self._train_program if self.training else self._infer_program
+
     def _prepare(self, inputs):
         """
         Prepare inputs, outputs, attrs.
@@ -253,6 +248,10 @@
 
         return outs
 
+    @switch_to_static_graph
+    def _clone_for_test(self, main_program):
+        return main_program.clone(for_test=True)
+
     def _is_no_value(self, var):
         if isinstance(var, core.VarBase):
             if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index 64fbb51f9a5f7a2937b5f7791cf0a004517bceab..6272f7369ec6db0cf7b3e5d82f689ddabf3e19ab 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -487,6 +487,8 @@ class ProgramTranslator(object):
         _, partial_program_layer = self._program_cache[function_spec]
 
         if args and isinstance(args[0], layers.Layer):
+            # Synchronize self.training attribute.
+            partial_program_layer.training = args[0].training
             args = args[1:]
         return partial_program_layer(args)
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py
index 2f67710649b05bc0dea38f126bfc87ef473c7ffe..3da60e955deee9b6d4c74ba5ff1a550ae135afdb 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py
@@ -16,7 +16,9 @@ from __future__ import print_function
 import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.layers.utils import flatten
-from paddle.fluid.dygraph import declarative
+from paddle.fluid.dygraph import declarative, ProgramTranslator
+
+from test_fetch_feed import Linear
 
 import unittest
 
@@ -121,5 +123,33 @@
         self.assertTrue(dy_var, st_var)
 
 
+class TestWithTrainAndEval(unittest.TestCase):
+    def test_switch_eval_and_train(self):
+        program_translator = ProgramTranslator()
+
+        with fluid.dygraph.guard():
+            linear_net = Linear()
+            x_data = np.random.random((4, 10)).astype('float32')
+            x = fluid.dygraph.to_variable(x_data)
+            linear_net(x)
+
+            _, partial_layer = program_translator.get_program_cache().last()[-1]
+            # check default mode is for training
+            self.assertEqual(partial_layer.program,
+                             partial_layer._train_program)
+
+            # switch to run test program after `eval()`
+            linear_net.eval()
+            linear_net(x)
+            self.assertEqual(partial_layer.program,
+                             partial_layer._infer_program)
+
+            # switch back into training
+            linear_net.train()
+            linear_net(x)
+            self.assertEqual(partial_layer.program,
+                             partial_layer._train_program)
+
+
 if __name__ == '__main__':
     unittest.main()