From eb602398eaa078e5d3a13d20e15e1e4dcb38e4ee Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Mon, 22 Nov 2021 17:50:19 +0800 Subject: [PATCH] [Dy2stat]Allow users to switch eval/train mode when using @to_static to decorate a function (#37383) * Allow users to switch eval/train mode when using @to_static to decorate a function * refine code for train() and eval() --- .../dygraph_to_static/program_translator.py | 21 +++++++++ .../test_program_translator.py | 47 +++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 58aac8e266f..d5d0e8ab88b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -261,6 +261,25 @@ class StaticFunction(object): # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`. self._program_trans = ProgramTranslator() self._kwargs = kwargs + self._training = True + + def train(self): + if isinstance(self._class_instance, + layers.Layer) and self._class_instance.training == False: + raise RuntimeError( + "Failed to switch train mode. {} is a Layer's method, " + "please use Layer.train() to switch train mode.".format( + self.dygraph_function)) + self._training = True + + def eval(self): + if isinstance(self._class_instance, + layers.Layer) and self._class_instance.training == True: + raise RuntimeError( + "Failed to switch eval mode. {} is a Layer's method, " + "please use Layer.eval() to switch eval mode.".format( + self.dygraph_function)) + self._training = False def __get__(self, instance, owner): """ @@ -340,6 +359,8 @@ class StaticFunction(object): # 3. synchronize self.training attribute. if isinstance(self._class_instance, layers.Layer): partial_program_layer.training = self._class_instance.training + else: + partial_program_layer.training = self._training # 4. return outputs. try: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 6fef356326b..c08a8d350f8 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -297,5 +297,52 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase): self.program_translator.get_program(net.forward, self.x) +class SwitchModeNet(paddle.nn.Layer): + def __init__(self): + super(SwitchModeNet, self).__init__() + + @paddle.jit.to_static + def forward(self, x): + return x + 1 + + @paddle.jit.to_static + def foo(self): + return True + + +@paddle.jit.to_static +def switch_mode_funciton(): + return True + + +class TestFunctionTrainEvalMode(unittest.TestCase): + def test_switch_mode(self): + paddle.disable_static() + switch_mode_funciton.eval() + switch_mode_funciton() + self.assertEqual(switch_mode_funciton._training, False) + _, partial_layer = switch_mode_funciton.program_cache.last()[-1] + self.assertEqual(partial_layer.training, False) + + switch_mode_funciton.train() + switch_mode_funciton() + self.assertEqual(switch_mode_funciton._training, True) + _, partial_layer = switch_mode_funciton.program_cache.last()[-1] + self.assertEqual(partial_layer.training, True) + + def test_raise_error(self): + paddle.disable_static() + net = SwitchModeNet() + + self.assertEqual(net.training, True) + with self.assertRaises(RuntimeError): + net.forward.eval() + + net.eval() + self.assertEqual(net.training, False) + with self.assertRaises(RuntimeError): + net.foo.train() + + if __name__ == '__main__': unittest.main() -- GitLab