diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 75a27f256962c9b897c5b325b1e4aada90e7c13b..3fd690cab058ae19cb26117c7665876284a925c8 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -1286,9 +1286,11 @@ class TranslatedLayer(layers.Layer): def train(self): self._is_test = False + self.training = True def eval(self): self._is_test = True + self.training = False def program(self, method_name='forward'): """ diff --git a/python/paddle/fluid/tests/unittests/test_translated_layer.py b/python/paddle/fluid/tests/unittests/test_translated_layer.py index bf1ed1f06c5721e7d5df1221b02cb4ea5bfe11ea..79652b37b7708992aa2595ce16687abffefaf4e1 100644 --- a/python/paddle/fluid/tests/unittests/test_translated_layer.py +++ b/python/paddle/fluid/tests/unittests/test_translated_layer.py @@ -48,6 +48,7 @@ class LinearNet(nn.Layer): def __init__(self): super(LinearNet, self).__init__() self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + self._dropout = paddle.nn.Dropout(p=0.5) @paddle.jit.to_static(input_spec=[ paddle.static.InputSpec( @@ -183,6 +184,20 @@ class TestTranslatedLayer(unittest.TestCase): for spec_x, spec_y in zip(expect_spec, actual_spec): self.assertEqual(spec_x, spec_y) + def test_layer_state(self): + # load + translated_layer = paddle.jit.load(self.model_path) + translated_layer.eval() + self.assertEqual(translated_layer.training, False) + for layer in translated_layer.sublayers(): + print("123") + self.assertEqual(layer.training, False) + + translated_layer.train() + self.assertEqual(translated_layer.training, True) + for layer in translated_layer.sublayers(): + self.assertEqual(layer.training, True) + if __name__ == '__main__': unittest.main()