未验证 提交 2cb61405 编写于 作者: X xiongkun 提交者: GitHub

add is_train into the cache key (#42889)

* add is_train into the cache key

* fix unittest error

* add unittest

* remove import
上级 fa6b3c9a
...@@ -197,10 +197,12 @@ class CacheKey(object): ...@@ -197,10 +197,12 @@ class CacheKey(object):
def __hash__(self): def __hash__(self):
error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)." error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)."
with_hook = self.kwargs.get("with_hook", False) with_hook = self.kwargs.get("with_hook", False)
return hash((id(self.function_spec), is_train = self.kwargs.get("is_train", False)
make_hashable(self.input_args_with_spec, error_msg), return hash(
make_hashable(self.input_kwargs_with_spec, error_msg), (id(self.function_spec),
self._spec_names_id, self.class_instance, with_hook)) make_hashable(self.input_args_with_spec, error_msg),
make_hashable(self.input_kwargs_with_spec, error_msg),
self._spec_names_id, self.class_instance, with_hook, is_train))
def __eq__(self, other): def __eq__(self, other):
return (type(self) is type(other)) and hash(self) == hash(other) return (type(self) is type(other)) and hash(self) == hash(other)
...@@ -357,7 +359,7 @@ class StaticFunction(object): ...@@ -357,7 +359,7 @@ class StaticFunction(object):
try: try:
concrete_program, partial_program_layer = self.get_concrete_program( concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs) *args, **kwargs, is_train=self._is_train_mode())
# 3. synchronize self.training attribute. # 3. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer): if isinstance(self._class_instance, layers.Layer):
...@@ -383,6 +385,12 @@ class StaticFunction(object): ...@@ -383,6 +385,12 @@ class StaticFunction(object):
" if you can't handle this {} yourself.".format(type(e))) " if you can't handle this {} yourself.".format(type(e)))
raise e raise e
def _is_train_mode(self):
if self._class_instance is not None:
return self._class_instance.training
else:
return self._training
def _call_dygraph_function(self, *args, **kwargs): def _call_dygraph_function(self, *args, **kwargs):
""" """
Calls dygraph function directly and returns the outputs. Calls dygraph function directly and returns the outputs.
...@@ -415,6 +423,8 @@ class StaticFunction(object): ...@@ -415,6 +423,8 @@ class StaticFunction(object):
""" """
with_hook = kwargs.get("with_hook", False) with_hook = kwargs.get("with_hook", False)
is_train = kwargs.get("is_train", True)
if "is_train" in kwargs: kwargs.pop("is_train")
if "with_hook" in kwargs: kwargs.pop("with_hook") if "with_hook" in kwargs: kwargs.pop("with_hook")
# 1. unify args/kwargs and replace Tensor with InputSpec # 1. unify args/kwargs and replace Tensor with InputSpec
if len(args) != len(self._function_spec.args_name): if len(args) != len(self._function_spec.args_name):
...@@ -430,7 +440,8 @@ class StaticFunction(object): ...@@ -430,7 +440,8 @@ class StaticFunction(object):
input_kwargs_with_spec, input_kwargs_with_spec,
self._class_instance, self._class_instance,
**self._kwargs, **self._kwargs,
with_hook=with_hook) with_hook=with_hook,
is_train=is_train)
# 3. check whether hit the cache or build a new program for the input arguments # 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key] concrete_program, partial_program_layer = self._program_cache[cache_key]
...@@ -525,7 +536,9 @@ class StaticFunction(object): ...@@ -525,7 +536,9 @@ class StaticFunction(object):
has_input_spec = (desired_input_spec is not None) has_input_spec = (desired_input_spec is not None)
if has_input_spec: if has_input_spec:
concrete_program, _ = self.get_concrete_program( concrete_program, _ = self.get_concrete_program(
*desired_input_spec, with_hook=with_hook) *desired_input_spec,
with_hook=with_hook,
is_train=self._is_train_mode())
return concrete_program return concrete_program
else: else:
raise ValueError( raise ValueError(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
def drop_path(x, training=False):
    """Return ``x`` unchanged in eval mode, or ``2 * x`` in training mode.

    Simplified stand-in for a stochastic-depth "drop path" op; its only
    purpose here is to make the layer's output depend on train/eval mode.

    Args:
        x: Input value (tensor or number supporting ``2 * x``).
        training (bool): Whether the caller is in training mode.

    Returns:
        ``x`` when ``training`` is falsy, otherwise ``2 * x``.
    """
    return 2 * x if training else x
class DropPath(paddle.nn.Layer):
    """Minimal to-static layer whose output depends on ``self.training``.

    ``forward`` routes through :func:`drop_path`, so switching the layer
    between ``train()`` and ``eval()`` changes the computed result —
    exercising mode-dependent program caching under ``@to_static``.
    """

    def __init__(self):
        super().__init__()

    @paddle.jit.to_static
    def forward(self, x):
        # Forward the layer's current mode so train()/eval() alters the output.
        return drop_path(x, self.training)
class TestTrainEval(unittest.TestCase):
    """Verify a ``to_static`` layer produces mode-specific outputs.

    The model doubles its input while training and is the identity in
    eval mode, so differing outputs prove the cached program tracks the
    train/eval state rather than being reused across modes.
    """

    def setUp(self):
        self.model = DropPath()

    def tearDown(self):
        pass

    def test_train_and_eval(self):
        inp = paddle.to_tensor([1, 2, 3]).astype("int64")
        # Expected results for each mode, captured before switching modes.
        expected_eval = inp.numpy()
        expected_train = inp.numpy() * 2
        # Training mode: drop_path doubles the input.
        self.model.train()
        self.assertTrue(np.allclose(self.model(inp).numpy(), expected_train))
        # Eval mode: drop_path passes the input through unchanged.
        self.model.eval()
        self.assertTrue(np.allclose(self.model(inp).numpy(), expected_eval))
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
...@@ -135,22 +135,23 @@ class TestWithTrainAndEval(unittest.TestCase): ...@@ -135,22 +135,23 @@ class TestWithTrainAndEval(unittest.TestCase):
x = fluid.dygraph.to_variable(x_data) x = fluid.dygraph.to_variable(x_data)
linear_net(x) linear_net(x)
_, partial_layer = linear_net.forward.program_cache.last()[-1] _, train_partial_layer = linear_net.forward.program_cache.last()[-1]
# check default mode is for training # check default mode is for training
self.assertEqual(partial_layer.program, self.assertEqual(train_partial_layer.program,
partial_layer._train_program) train_partial_layer._train_program)
# switch to run test program after `eval()` # switch to run test program after `eval()`
linear_net.eval() linear_net.eval()
linear_net(x) linear_net(x)
self.assertEqual(partial_layer.program, _, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
partial_layer._infer_program) self.assertEqual(eval_partial_layer.program,
eval_partial_layer._infer_program)
# switch back into training # switch back into training
linear_net.train() linear_net.train()
linear_net(x) linear_net(x)
self.assertEqual(partial_layer.program, self.assertEqual(train_partial_layer.program,
partial_layer._train_program) train_partial_layer._train_program)
class TestWithNoGrad(unittest.TestCase): class TestWithNoGrad(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册