From eed736dcde108a983030396f324142f636087b2c Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Tue, 23 Nov 2021 10:20:09 +0800 Subject: [PATCH] [Dy2stat]Allow users to switch eval/train mode when using @to_static to decorate a function (#37383) (#37432) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本PR之前使用@to_static装饰一个单独的function时,对于生成的Program无法切换train/eval模式,只能运行在train模式下。这也就导致动转静后用户多次调用function显存会一直增长。 本PR之后,使用@to_static装饰一个单独的function时,可以通过function.train()或者function.eval()的方式来切换train/eval模式。 --- .../dygraph_to_static/program_translator.py | 21 +++++++++ .../test_program_translator.py | 47 +++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 58aac8e266..d5d0e8ab88 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -261,6 +261,25 @@ class StaticFunction(object): # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`. self._program_trans = ProgramTranslator() self._kwargs = kwargs + self._training = True + + def train(self): + if isinstance(self._class_instance, + layers.Layer) and self._class_instance.training == False: + raise RuntimeError( + "Failed to switch train mode. {} is a Layer's method, " + "please use Layer.train() to switch train mode.".format( + self.dygraph_function)) + self._training = True + + def eval(self): + if isinstance(self._class_instance, + layers.Layer) and self._class_instance.training == True: + raise RuntimeError( + "Failed to switch eval mode. {} is a Layer's method, " + "please use Layer.eval() to switch eval mode.".format( + self.dygraph_function)) + self._training = False def __get__(self, instance, owner): """ @@ -340,6 +359,8 @@ class StaticFunction(object): # 3. synchronize self.training attribute. if isinstance(self._class_instance, layers.Layer): partial_program_layer.training = self._class_instance.training + else: + partial_program_layer.training = self._training # 4. return outputs. try: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 6fef356326..c08a8d350f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -297,5 +297,52 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase): self.program_translator.get_program(net.forward, self.x) +class SwitchModeNet(paddle.nn.Layer): + def __init__(self): + super(SwitchModeNet, self).__init__() + + @paddle.jit.to_static + def forward(self, x): + return x + 1 + + @paddle.jit.to_static + def foo(self): + return True + + +@paddle.jit.to_static +def switch_mode_funciton(): + return True + + +class TestFunctionTrainEvalMode(unittest.TestCase): + def test_switch_mode(self): + paddle.disable_static() + switch_mode_funciton.eval() + switch_mode_funciton() + self.assertEqual(switch_mode_funciton._training, False) + _, partial_layer = switch_mode_funciton.program_cache.last()[-1] + self.assertEqual(partial_layer.training, False) + + switch_mode_funciton.train() + switch_mode_funciton() + self.assertEqual(switch_mode_funciton._training, True) + _, partial_layer = switch_mode_funciton.program_cache.last()[-1] + self.assertEqual(partial_layer.training, True) + + def test_raise_error(self): + paddle.disable_static() + net = SwitchModeNet() + + self.assertEqual(net.training, True) + with self.assertRaises(RuntimeError): + net.forward.eval() + + net.eval() + self.assertEqual(net.training, False) + with self.assertRaises(RuntimeError): + net.foo.train() + + if __name__ == '__main__': unittest.main() -- GitLab