diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index fe60c24ff36ec7d3abd7ed1ae54217c2a1f310c6..fd824de3a1e6a45df7e0ad5dfa596d1817e68dd6 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -132,8 +132,11 @@ class Layer(core.Layer): out = mylayer(x) """ - # global setting - framework._dygraph_tracer().train_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().train_mode() # Layer-level setting self.training = True for layer in self.sublayers(): @@ -170,8 +173,11 @@ class Layer(core.Layer): print(out) """ - # global setting - framework._dygraph_tracer().eval_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().eval_mode() # Layer-level setting self.training = False for layer in self.sublayers(): diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 8ae5264381e822062b2237507f94efc9b9daf15a..35ecbd6bf10c30823e0faa93a67906bcaa597b98 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase): self.assertFalse(net.weight.trainable) +class TestLayerTrainingAttribute(unittest.TestCase): + def test_set_train_eval_in_dynamic_mode(self): + with fluid.dygraph.guard(): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + def test_set_train_eval_in_static_mode(self): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + if __name__ == '__main__': paddle.enable_static() unittest.main()