未验证 提交 eb602398 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]Allow users to switch eval/train mode when using @to_static to...

[Dy2stat]Allow users to switch eval/train mode when using @to_static to decorate a function (#37383)

* Allow users to switch eval/train mode when using @to_static to decorate a function

* refine code for train() and eval()
上级 e3503de8
...@@ -261,6 +261,25 @@ class StaticFunction(object): ...@@ -261,6 +261,25 @@ class StaticFunction(object):
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`. # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator() self._program_trans = ProgramTranslator()
self._kwargs = kwargs self._kwargs = kwargs
self._training = True
def train(self):
if isinstance(self._class_instance,
layers.Layer) and self._class_instance.training == False:
raise RuntimeError(
"Failed to switch train mode. {} is a Layer's method, "
"please use Layer.train() to switch train mode.".format(
self.dygraph_function))
self._training = True
def eval(self):
if isinstance(self._class_instance,
layers.Layer) and self._class_instance.training == True:
raise RuntimeError(
"Failed to switch eval mode. {} is a Layer's method, "
"please use Layer.eval() to switch eval mode.".format(
self.dygraph_function))
self._training = False
def __get__(self, instance, owner): def __get__(self, instance, owner):
""" """
...@@ -340,6 +359,8 @@ class StaticFunction(object): ...@@ -340,6 +359,8 @@ class StaticFunction(object):
# 3. synchronize self.training attribute. # 3. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer): if isinstance(self._class_instance, layers.Layer):
partial_program_layer.training = self._class_instance.training partial_program_layer.training = self._class_instance.training
else:
partial_program_layer.training = self._training
# 4. return outputs. # 4. return outputs.
try: try:
......
...@@ -297,5 +297,52 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase): ...@@ -297,5 +297,52 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase):
self.program_translator.get_program(net.forward, self.x) self.program_translator.get_program(net.forward, self.x)
class SwitchModeNet(paddle.nn.Layer):
def __init__(self):
super(SwitchModeNet, self).__init__()
@paddle.jit.to_static
def forward(self, x):
return x + 1
@paddle.jit.to_static
def foo(self):
return True
@paddle.jit.to_static
def switch_mode_funciton():
return True
class TestFunctionTrainEvalMode(unittest.TestCase):
def test_switch_mode(self):
paddle.disable_static()
switch_mode_funciton.eval()
switch_mode_funciton()
self.assertEqual(switch_mode_funciton._training, False)
_, partial_layer = switch_mode_funciton.program_cache.last()[-1]
self.assertEqual(partial_layer.training, False)
switch_mode_funciton.train()
switch_mode_funciton()
self.assertEqual(switch_mode_funciton._training, True)
_, partial_layer = switch_mode_funciton.program_cache.last()[-1]
self.assertEqual(partial_layer.training, True)
def test_raise_error(self):
paddle.disable_static()
net = SwitchModeNet()
self.assertEqual(net.training, True)
with self.assertRaises(RuntimeError):
net.forward.eval()
net.eval()
self.assertEqual(net.training, False)
with self.assertRaises(RuntimeError):
net.foo.train()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册