From 61d7027776f7bd4a16cbb4931fd70599cab22186 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 16 Dec 2020 20:15:02 -0600 Subject: [PATCH] [Cherry-pick][bug fix] fix train eval set error in static mode (#29540) (#29571) Fix Layer train eval setting failed in static mode, more details please see #29540 --- python/paddle/fluid/dygraph/layers.py | 14 ++++++++++---- .../paddle/fluid/tests/unittests/test_layers.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index fe60c24ff3..fd824de3a1 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -132,8 +132,11 @@ class Layer(core.Layer): out = mylayer(x) """ - # global setting - framework._dygraph_tracer().train_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().train_mode() # Layer-level setting self.training = True for layer in self.sublayers(): @@ -170,8 +173,11 @@ class Layer(core.Layer): print(out) """ - # global setting - framework._dygraph_tracer().eval_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().eval_mode() # Layer-level setting self.training = False for layer in self.sublayers(): diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 8ae5264381..35ecbd6bf1 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase): self.assertFalse(net.weight.trainable) +class TestLayerTrainingAttribute(unittest.TestCase): + def test_set_train_eval_in_dynamic_mode(self): + with fluid.dygraph.guard(): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + def test_set_train_eval_in_static_mode(self): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + if __name__ == '__main__': paddle.enable_static() unittest.main() -- GitLab