From f616d9a7bde1e38dc27974a8aee3a40cd79dd21c Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 15 Dec 2021 19:12:52 +0800 Subject: [PATCH] Fix bugs in Translated Layer when change mode from train/eval to eval/train (#38141) * fix bugs in Translated layer when change train/eval * fix python converage --- python/paddle/fluid/dygraph/io.py | 2 ++ .../tests/unittests/test_translated_layer.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 75a27f25696..3fd690cab05 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -1286,9 +1286,11 @@ class TranslatedLayer(layers.Layer): def train(self): self._is_test = False + self.training = True def eval(self): self._is_test = True + self.training = False def program(self, method_name='forward'): """ diff --git a/python/paddle/fluid/tests/unittests/test_translated_layer.py b/python/paddle/fluid/tests/unittests/test_translated_layer.py index bf1ed1f06c5..79652b37b77 100644 --- a/python/paddle/fluid/tests/unittests/test_translated_layer.py +++ b/python/paddle/fluid/tests/unittests/test_translated_layer.py @@ -48,6 +48,7 @@ class LinearNet(nn.Layer): def __init__(self): super(LinearNet, self).__init__() self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + self._dropout = paddle.nn.Dropout(p=0.5) @paddle.jit.to_static(input_spec=[ paddle.static.InputSpec( @@ -183,6 +184,20 @@ class TestTranslatedLayer(unittest.TestCase): for spec_x, spec_y in zip(expect_spec, actual_spec): self.assertEqual(spec_x, spec_y) + def test_layer_state(self): + # load + translated_layer = paddle.jit.load(self.model_path) + translated_layer.eval() + self.assertEqual(translated_layer.training, False) + for layer in translated_layer.sublayers(): + print("123") + self.assertEqual(layer.training, False) + + translated_layer.train() + self.assertEqual(translated_layer.training, True) + for layer in translated_layer.sublayers(): + self.assertEqual(layer.training, True) + if __name__ == '__main__': unittest.main() -- GitLab