未验证 提交 c1a26e2a 编写于 作者: C Chen Weihang 提交者: GitHub

fix train eval set error in static mode (#29540)

上级 b5d4a1f3
...@@ -133,7 +133,10 @@ class Layer(core.Layer): ...@@ -133,7 +133,10 @@ class Layer(core.Layer):
out = mylayer(x) out = mylayer(x)
""" """
# global setting # global setting in dygraph
# NOTE(chenweihang): nn.Layer also can be used in static mode,
# but _dygraph_tracer() can not be called in static mode
if in_dygraph_mode():
framework._dygraph_tracer().train_mode() framework._dygraph_tracer().train_mode()
# Layer-level setting # Layer-level setting
self.training = True self.training = True
...@@ -171,7 +174,10 @@ class Layer(core.Layer): ...@@ -171,7 +174,10 @@ class Layer(core.Layer):
print(out) print(out)
""" """
# global setting # global setting in dygraph
# NOTE(chenweihang): nn.Layer also can be used in static mode,
# but _dygraph_tracer() can not be called in static mode
if in_dygraph_mode():
framework._dygraph_tracer().eval_mode() framework._dygraph_tracer().eval_mode()
# Layer-level setting # Layer-level setting
self.training = False self.training = False
......
...@@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase): ...@@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase):
self.assertFalse(net.weight.trainable) self.assertFalse(net.weight.trainable)
class TestLayerTrainingAttribute(unittest.TestCase):
def test_set_train_eval_in_dynamic_mode(self):
with fluid.dygraph.guard():
net = paddle.nn.Dropout()
net.train()
self.assertTrue(net.training)
net.eval()
self.assertFalse(net.training)
def test_set_train_eval_in_static_mode(self):
net = paddle.nn.Dropout()
net.train()
self.assertTrue(net.training)
net.eval()
self.assertFalse(net.training)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册