未验证 提交 8ec4af27 编写于 作者: A Aurelius84 提交者: GitHub

【Dy2stat】Fix incorrect is_test switching in PartialProgram (#25809)

* fix eval() sync problem

* add unittest

* modify according to reviewer comments

* fix failing unittest
上级 9b5a65b8
...@@ -112,14 +112,14 @@ class PartialProgramLayer(layers.Layer): ...@@ -112,14 +112,14 @@ class PartialProgramLayer(layers.Layer):
self._outputs = NestSequence(outputs, need_check=True) self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else [] self._params = parameters if parameters is not None else []
self._infer_program = self._verify_program(main_program) main_program = self._verify_program(main_program)
self._train_program = self._append_backward_desc() self._infer_program = self._clone_for_test(main_program)
# Switch infer or train by train() and eval() self._train_program = self._append_backward_desc(main_program)
self._trace_program = None
self._set_grad_type(self._params) self._set_grad_type(self._params)
self._inner_scope = core.Scope() self._inner_scope = core.Scope()
# Set default mode to train # Set default mode to train
self.train() self.training = True
def _verify_program(self, main_program): def _verify_program(self, main_program):
""" """
...@@ -136,8 +136,8 @@ class PartialProgramLayer(layers.Layer): ...@@ -136,8 +136,8 @@ class PartialProgramLayer(layers.Layer):
return main_program return main_program
@switch_to_static_graph @switch_to_static_graph
def _append_backward_desc(self): def _append_backward_desc(self, main_program):
program = self._infer_program.clone() program = main_program.clone()
targets = [] targets = []
for out in self._outputs.tolist(): for out in self._outputs.tolist():
if isinstance(out, framework.Variable): if isinstance(out, framework.Variable):
...@@ -165,15 +165,6 @@ class PartialProgramLayer(layers.Layer): ...@@ -165,15 +165,6 @@ class PartialProgramLayer(layers.Layer):
self._params = required_params self._params = required_params
def train(self):
# self.training is inherited from layers.Layer
self.training = True
self._trace_program = self._train_program
def eval(self):
self.training = False
self._trace_program = self._infer_program
def forward(self, inputs): def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs) in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
...@@ -186,7 +177,7 @@ class PartialProgramLayer(layers.Layer): ...@@ -186,7 +177,7 @@ class PartialProgramLayer(layers.Layer):
outputs={'Out': valid_vars(out_vars), outputs={'Out': valid_vars(out_vars),
'OutScope': tmp_scope_vec}, 'OutScope': tmp_scope_vec},
attrs={ attrs={
'global_block': self._trace_program.desc.block(0), 'global_block': self.program.desc.block(0),
'start_op_index': 0, 'start_op_index': 0,
'end_op_index': self._infer_program.desc.block(0).op_size(), 'end_op_index': self._infer_program.desc.block(0).op_size(),
'is_test': not self.training 'is_test': not self.training
...@@ -195,6 +186,10 @@ class PartialProgramLayer(layers.Layer): ...@@ -195,6 +186,10 @@ class PartialProgramLayer(layers.Layer):
restored_nest_out = self._restore_out(out_vars) restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out) return self._remove_no_value(restored_nest_out)
@property
def program(self):
return self._train_program if self.training else self._infer_program
def _prepare(self, inputs): def _prepare(self, inputs):
""" """
Prepare inputs, outputs, attrs. Prepare inputs, outputs, attrs.
...@@ -253,6 +248,10 @@ class PartialProgramLayer(layers.Layer): ...@@ -253,6 +248,10 @@ class PartialProgramLayer(layers.Layer):
return outs return outs
@switch_to_static_graph
def _clone_for_test(self, main_program):
return main_program.clone(for_test=True)
def _is_no_value(self, var): def _is_no_value(self, var):
if isinstance(var, core.VarBase): if isinstance(var, core.VarBase):
if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM: if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
......
...@@ -487,6 +487,8 @@ class ProgramTranslator(object): ...@@ -487,6 +487,8 @@ class ProgramTranslator(object):
_, partial_program_layer = self._program_cache[function_spec] _, partial_program_layer = self._program_cache[function_spec]
if args and isinstance(args[0], layers.Layer): if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute.
partial_program_layer.training = args[0].training
args = args[1:] args = args[1:]
return partial_program_layer(args) return partial_program_layer(args)
......
...@@ -16,7 +16,9 @@ from __future__ import print_function ...@@ -16,7 +16,9 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph import declarative from paddle.fluid.dygraph import declarative, ProgramTranslator
from test_fetch_feed import Linear
import unittest import unittest
...@@ -121,5 +123,33 @@ class TestWithNestedOutput(unittest.TestCase): ...@@ -121,5 +123,33 @@ class TestWithNestedOutput(unittest.TestCase):
self.assertTrue(dy_var, st_var) self.assertTrue(dy_var, st_var)
class TestWithTrainAndEval(unittest.TestCase):
def test_switch_eval_and_train(self):
program_translator = ProgramTranslator()
with fluid.dygraph.guard():
linear_net = Linear()
x_data = np.random.random((4, 10)).astype('float32')
x = fluid.dygraph.to_variable(x_data)
linear_net(x)
_, partial_layer = program_translator.get_program_cache().last()[-1]
# check default mode is for training
self.assertEqual(partial_layer.program,
partial_layer._train_program)
# switch to run test program after `eval()`
linear_net.eval()
linear_net(x)
self.assertEqual(partial_layer.program,
partial_layer._infer_program)
# switch back into training
linear_net.train()
linear_net(x)
self.assertEqual(partial_layer.program,
partial_layer._train_program)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册