diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index ad3a20869cede2b096f60d9f987cf971095f3e86..3275a2126eddee8044ef156320f4b829fb209695 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -133,8 +133,11 @@ class Layer(core.Layer): out = mylayer(x) """ - # global setting - framework._dygraph_tracer().train_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().train_mode() # Layer-level setting self.training = True for layer in self.sublayers(): @@ -171,8 +174,11 @@ class Layer(core.Layer): print(out) """ - # global setting - framework._dygraph_tracer().eval_mode() + # global setting in dygraph + # NOTE(chenweihang): nn.Layer also can be used in static mode, + # but _dygraph_tracer() can not be called in static mode + if in_dygraph_mode(): + framework._dygraph_tracer().eval_mode() # Layer-level setting self.training = False for layer in self.sublayers(): diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 8ae5264381e822062b2237507f94efc9b9daf15a..35ecbd6bf10c30823e0faa93a67906bcaa597b98 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase): self.assertFalse(net.weight.trainable) +class TestLayerTrainingAttribute(unittest.TestCase): + def test_set_train_eval_in_dynamic_mode(self): + with fluid.dygraph.guard(): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + def test_set_train_eval_in_static_mode(self): + net = paddle.nn.Dropout() + net.train() + self.assertTrue(net.training) + net.eval() + self.assertFalse(net.training) + + if __name__ == '__main__': paddle.enable_static() unittest.main()