未验证 提交 8ec4af27 编写于 作者: A Aurelius84 提交者: GitHub

【Dy2stat】Fix incorrect is_test switching in PartialProgram (#25809)

* fix eval() sync problem

* add unittest

* modify according to reviewer comments

* fix failing unittest
上级 9b5a65b8
......@@ -112,14 +112,14 @@ class PartialProgramLayer(layers.Layer):
self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else []
self._infer_program = self._verify_program(main_program)
self._train_program = self._append_backward_desc()
# Switch infer or train by train() and eval()
self._trace_program = None
main_program = self._verify_program(main_program)
self._infer_program = self._clone_for_test(main_program)
self._train_program = self._append_backward_desc(main_program)
self._set_grad_type(self._params)
self._inner_scope = core.Scope()
# Set default mode to train
self.train()
self.training = True
def _verify_program(self, main_program):
"""
......@@ -136,8 +136,8 @@ class PartialProgramLayer(layers.Layer):
return main_program
@switch_to_static_graph
def _append_backward_desc(self):
program = self._infer_program.clone()
def _append_backward_desc(self, main_program):
program = main_program.clone()
targets = []
for out in self._outputs.tolist():
if isinstance(out, framework.Variable):
......@@ -165,15 +165,6 @@ class PartialProgramLayer(layers.Layer):
self._params = required_params
def train(self):
    """Switch this layer to training mode.

    Sets ``self.training`` to True and points ``self._trace_program`` at
    ``self._train_program``.
    """
    # self.training is inherited from layers.Layer
    self.training = True
    self._trace_program = self._train_program
def eval(self):
    """Switch this layer to evaluation mode.

    Sets ``self.training`` to False and points ``self._trace_program`` at
    ``self._infer_program``.
    """
    self.training = False
    self._trace_program = self._infer_program
def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
......@@ -186,7 +177,7 @@ class PartialProgramLayer(layers.Layer):
outputs={'Out': valid_vars(out_vars),
'OutScope': tmp_scope_vec},
attrs={
'global_block': self._trace_program.desc.block(0),
'global_block': self.program.desc.block(0),
'start_op_index': 0,
'end_op_index': self._infer_program.desc.block(0).op_size(),
'is_test': not self.training
......@@ -195,6 +186,10 @@ class PartialProgramLayer(layers.Layer):
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
@property
def program(self):
    """Return the program matching the current mode.

    Returns:
        ``self._train_program`` when ``self.training`` is truthy,
        otherwise ``self._infer_program``.
    """
    # Evaluated on every access, so a change to ``self.training`` takes
    # effect without any separate switching step.
    if self.training:
        return self._train_program
    return self._infer_program
def _prepare(self, inputs):
"""
Prepare inputs, outputs, attrs.
......@@ -253,6 +248,10 @@ class PartialProgramLayer(layers.Layer):
return outs
@switch_to_static_graph
def _clone_for_test(self, main_program):
    """Return ``main_program.clone(for_test=True)``.

    Args:
        main_program: the program to clone.

    Returns:
        The clone produced with ``for_test=True``.
    """
    # NOTE(review): presumably the decorator makes the clone run in
    # static-graph mode -- confirm against switch_to_static_graph.
    return main_program.clone(for_test=True)
def _is_no_value(self, var):
if isinstance(var, core.VarBase):
if var.shape == [1] and var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
......
......@@ -487,6 +487,8 @@ class ProgramTranslator(object):
_, partial_program_layer = self._program_cache[function_spec]
if args and isinstance(args[0], layers.Layer):
# Synchronize self.training attribute.
partial_program_layer.training = args[0].training
args = args[1:]
return partial_program_layer(args)
......
......@@ -16,7 +16,9 @@ from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import declarative, ProgramTranslator
from test_fetch_feed import Linear
import unittest
......@@ -121,5 +123,33 @@ class TestWithNestedOutput(unittest.TestCase):
self.assertTrue(dy_var, st_var)
class TestWithTrainAndEval(unittest.TestCase):
    """Check that the cached partial program follows train()/eval()."""

    def test_switch_eval_and_train(self):
        """``partial_layer.program`` switches after each mode change."""
        program_translator = ProgramTranslator()

        with fluid.dygraph.guard():
            linear_net = Linear()
            x_data = np.random.random((4, 10)).astype('float32')
            x = fluid.dygraph.to_variable(x_data)
            linear_net(x)

            _, partial_layer = program_translator.get_program_cache().last()[
                -1]
            # check default mode is for training
            # assertIs: the check is about which program object is selected.
            self.assertIs(partial_layer.program,
                          partial_layer._train_program)

            # switch to run test program after `eval()`
            linear_net.eval()
            # Call the net again before checking the selected program.
            linear_net(x)
            self.assertIs(partial_layer.program,
                          partial_layer._infer_program)

            # switch back into training
            linear_net.train()
            linear_net(x)
            self.assertIs(partial_layer.program,
                          partial_layer._train_program)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册