From c1a26e2a05b2b10be3b235df165d7f779d2a87fb Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 11 Dec 2020 03:47:52 -0600 Subject: [PATCH] fix train eval set error in static mode (#29540) --- python/paddle/fluid/dygraph/layers.py | 14 ++++++++++---- .../paddle/fluid/tests/unittests/test_layers.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index ad3a20869ce..3275a2126ed 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -133,8 +133,11 @@ class Layer(core.Layer): out = mylayer(x) """ - # global setting - framework._dygraph_tracer().train_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().train_mode() # Layer-level setting self.training = True for layer in self.sublayers(): @@ -171,8 +174,11 @@ class Layer(core.Layer): print(out) """ - # global setting - framework._dygraph_tracer().eval_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().eval_mode() # Layer-level setting self.training = False for layer in self.sublayers(): diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 8ae5264381e..35ecbd6bf10 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase): self.assertFalse(net.weight.trainable) +class TestLayerTrainingAttribute(unittest.TestCase): + def test_set_train_eval_in_dynamic_mode(self): + with fluid.dygraph.guard(): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + def test_set_train_eval_in_static_mode(self): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + if __name__ == '__main__': paddle.enable_static() unittest.main() -- GitLab