diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 58aac8e266fedd0c76a6629d43d8a92f9757b03d..d5d0e8ab88b869a4fd000c63ac29b9dc0b45c8e1 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -261,6 +261,25 @@ class StaticFunction(object): # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`. self._program_trans = ProgramTranslator() self._kwargs = kwargs + self._training = True + + def train(self): + if isinstance(self._class_instance, + layers.Layer) and self._class_instance.training == False: + raise RuntimeError( + "Failed to switch train mode. {} is a Layer's method, " + "please use Layer.train() to switch train mode.".format( + self.dygraph_function)) + self._training = True + + def eval(self): + if isinstance(self._class_instance, + layers.Layer) and self._class_instance.training == True: + raise RuntimeError( + "Failed to switch eval mode. {} is a Layer's method, " + "please use Layer.eval() to switch eval mode.".format( + self.dygraph_function)) + self._training = False def __get__(self, instance, owner): """ @@ -340,6 +359,8 @@ class StaticFunction(object): # 3. synchronize self.training attribute. if isinstance(self._class_instance, layers.Layer): partial_program_layer.training = self._class_instance.training + else: + partial_program_layer.training = self._training # 4. return outputs. 
try: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 6fef356326b81d00a0eca205586cc0d8247c1e5a..c08a8d350f8aa83eb2c7e2eae8726917c02bba4f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -297,5 +297,52 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase): self.program_translator.get_program(net.forward, self.x) +class SwitchModeNet(paddle.nn.Layer): + def __init__(self): + super(SwitchModeNet, self).__init__() + + @paddle.jit.to_static + def forward(self, x): + return x + 1 + + @paddle.jit.to_static + def foo(self): + return True + + +@paddle.jit.to_static +def switch_mode_function(): + return True + + +class TestFunctionTrainEvalMode(unittest.TestCase): + def test_switch_mode(self): + paddle.disable_static() + switch_mode_function.eval() + switch_mode_function() + self.assertEqual(switch_mode_function._training, False) + _, partial_layer = switch_mode_function.program_cache.last()[-1] + self.assertEqual(partial_layer.training, False) + + switch_mode_function.train() + switch_mode_function() + self.assertEqual(switch_mode_function._training, True) + _, partial_layer = switch_mode_function.program_cache.last()[-1] + self.assertEqual(partial_layer.training, True) + + def test_raise_error(self): + paddle.disable_static() + net = SwitchModeNet() + + self.assertEqual(net.training, True) + with self.assertRaises(RuntimeError): + net.forward.eval() + + net.eval() + self.assertEqual(net.training, False) + with self.assertRaises(RuntimeError): + net.foo.train() + + if __name__ == '__main__': unittest.main()