未验证 提交 c1a26e2a 编写于 作者: C Chen Weihang 提交者: GitHub

fix train eval set error in static mode (#29540)

上级 b5d4a1f3
......@@ -133,8 +133,11 @@ class Layer(core.Layer):
out = mylayer(x)
"""
# global setting
framework._dygraph_tracer().train_mode()
# global setting in dygraph
# NOTE(chenweihang): nn.Layer also can be used in static mode,
# but _dygraph_tracer() can not be called in static mode
if in_dygraph_mode():
framework._dygraph_tracer().train_mode()
# Layer-level setting
self.training = True
for layer in self.sublayers():
......@@ -171,8 +174,11 @@ class Layer(core.Layer):
print(out)
"""
# global setting
framework._dygraph_tracer().eval_mode()
# global setting in dygraph
# NOTE(chenweihang): nn.Layer also can be used in static mode,
# but _dygraph_tracer() can not be called in static mode
if in_dygraph_mode():
framework._dygraph_tracer().eval_mode()
# Layer-level setting
self.training = False
for layer in self.sublayers():
......
......@@ -3701,6 +3701,23 @@ class TestLayerParameterTrainableSet(unittest.TestCase):
self.assertFalse(net.weight.trainable)
class TestLayerTrainingAttribute(unittest.TestCase):
def test_set_train_eval_in_dynamic_mode(self):
with fluid.dygraph.guard():
net = paddle.nn.Dropout()
net.train()
self.assertTrue(net.training)
net.eval()
self.assertFalse(net.training)
def test_set_train_eval_in_static_mode(self):
net = paddle.nn.Dropout()
net.train()
self.assertTrue(net.training)
net.eval()
self.assertFalse(net.training)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册