From 2cb61405abcab502c07be750151ed0773175094e Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 23 May 2022 14:23:25 +0800 Subject: [PATCH] add is_train into the cache key (#42889) * add is_train into the cache key * fix unittest error * add unittest * remove import --- .../dygraph_to_static/program_translator.py | 27 ++++++--- .../dygraph_to_static/test_drop_path.py | 55 +++++++++++++++++++ .../dygraph_to_static/test_partial_program.py | 15 ++--- 3 files changed, 83 insertions(+), 14 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_drop_path.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index b860740f71b..2efb6965085 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -197,10 +197,12 @@ class CacheKey(object): def __hash__(self): error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)." with_hook = self.kwargs.get("with_hook", False) - return hash((id(self.function_spec), - make_hashable(self.input_args_with_spec, error_msg), - make_hashable(self.input_kwargs_with_spec, error_msg), - self._spec_names_id, self.class_instance, with_hook)) + is_train = self.kwargs.get("is_train", False) + return hash( + (id(self.function_spec), + make_hashable(self.input_args_with_spec, error_msg), + make_hashable(self.input_kwargs_with_spec, error_msg), + self._spec_names_id, self.class_instance, with_hook, is_train)) def __eq__(self, other): return (type(self) is type(other)) and hash(self) == hash(other) @@ -357,7 +359,7 @@ class StaticFunction(object): try: concrete_program, partial_program_layer = self.get_concrete_program( - *args, **kwargs) + *args, **kwargs, is_train=self._is_train_mode()) # 3. synchronize self.training attribute. if isinstance(self._class_instance, layers.Layer): @@ -383,6 +385,12 @@ class StaticFunction(object): " if you can't handle this {} yourself.".format(type(e))) raise e + def _is_train_mode(self): + if self._class_instance is not None: + return self._class_instance.training + else: + return self._training + def _call_dygraph_function(self, *args, **kwargs): """ Calls dygraph function directly and returns the outputs. @@ -415,6 +423,8 @@ class StaticFunction(object): """ with_hook = kwargs.get("with_hook", False) + is_train = kwargs.get("is_train", True) + if "is_train" in kwargs: kwargs.pop("is_train") if "with_hook" in kwargs: kwargs.pop("with_hook") # 1. unify args/kwargs and replace Tensor with InputSpec if len(args) != len(self._function_spec.args_name): @@ -430,7 +440,8 @@ class StaticFunction(object): input_kwargs_with_spec, self._class_instance, **self._kwargs, - with_hook=with_hook) + with_hook=with_hook, + is_train=is_train) # 3. check whether hit the cache or build a new program for the input arguments concrete_program, partial_program_layer = self._program_cache[cache_key] @@ -525,7 +536,9 @@ class StaticFunction(object): has_input_spec = (desired_input_spec is not None) if has_input_spec: concrete_program, _ = self.get_concrete_program( - *desired_input_spec, with_hook=with_hook) + *desired_input_spec, + with_hook=with_hook, + is_train=self._is_train_mode()) return concrete_program else: raise ValueError( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_drop_path.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_drop_path.py new file mode 100644 index 00000000000..7383c834ba9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_drop_path.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +import paddle + + +def drop_path(x, training=False): + if not training: + return x + else: + return 2 * x + + +class DropPath(paddle.nn.Layer): + def __init__(self): + super(DropPath, self).__init__() + + @paddle.jit.to_static + def forward(self, x): + return drop_path(x, self.training) + + +class TestTrainEval(unittest.TestCase): + def setUp(self): + self.model = DropPath() + + def tearDown(self): + pass + + def test_train_and_eval(self): + x = paddle.to_tensor([1, 2, 3]).astype("int64") + eval_out = x.numpy() + train_out = x.numpy() * 2 + self.model.train() + self.assertTrue(np.allclose(self.model(x).numpy(), train_out)) + self.model.eval() + self.assertTrue(np.allclose(self.model(x).numpy(), eval_out)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py index 427e4c22524..4f55dbd324c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program.py @@ -135,22 +135,23 @@ class TestWithTrainAndEval(unittest.TestCase): x = fluid.dygraph.to_variable(x_data) linear_net(x) - _, partial_layer = linear_net.forward.program_cache.last()[-1] + _, train_partial_layer = linear_net.forward.program_cache.last()[-1] # check default mode is for training - self.assertEqual(partial_layer.program, - partial_layer._train_program) + self.assertEqual(train_partial_layer.program, + train_partial_layer._train_program) # switch to run test program after `eval()` linear_net.eval() linear_net(x) - self.assertEqual(partial_layer.program, - partial_layer._infer_program) + _, eval_partial_layer = linear_net.forward.program_cache.last()[-1] + self.assertEqual(eval_partial_layer.program, + eval_partial_layer._infer_program) # switch back into training linear_net.train() linear_net(x) - self.assertEqual(partial_layer.program, - partial_layer._train_program) + self.assertEqual(train_partial_layer.program, + train_partial_layer._train_program) class TestWithNoGrad(unittest.TestCase): -- GitLab