Unverified commit 2cb61405, authored by xiongkun, committed via GitHub

add is_train into the cache key (#42889)

* add is_train into the cache key

* fix unittest error

* add unittest

* remove import
Parent commit: fa6b3c9a
......@@ -197,10 +197,12 @@ class CacheKey(object):
def __hash__(self):
error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)."
with_hook = self.kwargs.get("with_hook", False)
return hash((id(self.function_spec),
make_hashable(self.input_args_with_spec, error_msg),
make_hashable(self.input_kwargs_with_spec, error_msg),
self._spec_names_id, self.class_instance, with_hook))
is_train = self.kwargs.get("is_train", False)
return hash(
(id(self.function_spec),
make_hashable(self.input_args_with_spec, error_msg),
make_hashable(self.input_kwargs_with_spec, error_msg),
self._spec_names_id, self.class_instance, with_hook, is_train))
def __eq__(self, other):
return (type(self) is type(other)) and hash(self) == hash(other)
......@@ -357,7 +359,7 @@ class StaticFunction(object):
try:
concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs)
*args, **kwargs, is_train=self._is_train_mode())
# 3. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer):
......@@ -383,6 +385,12 @@ class StaticFunction(object):
" if you can't handle this {} yourself.".format(type(e)))
raise e
def _is_train_mode(self):
    """Return whether tracing should happen in train mode.

    When the decorated function is bound to a Layer instance, mirror
    that instance's ``training`` flag; otherwise fall back to this
    StaticFunction's own ``_training`` state.
    """
    instance = self._class_instance
    if instance is None:
        return self._training
    return instance.training
def _call_dygraph_function(self, *args, **kwargs):
"""
Calls dygraph function directly and returns the outputs.
......@@ -415,6 +423,8 @@ class StaticFunction(object):
"""
with_hook = kwargs.get("with_hook", False)
is_train = kwargs.get("is_train", True)
if "is_train" in kwargs: kwargs.pop("is_train")
if "with_hook" in kwargs: kwargs.pop("with_hook")
# 1. unify args/kwargs and replace Tensor with InputSpec
if len(args) != len(self._function_spec.args_name):
......@@ -430,7 +440,8 @@ class StaticFunction(object):
input_kwargs_with_spec,
self._class_instance,
**self._kwargs,
with_hook=with_hook)
with_hook=with_hook,
is_train=is_train)
# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key]
......@@ -525,7 +536,9 @@ class StaticFunction(object):
has_input_spec = (desired_input_spec is not None)
if has_input_spec:
concrete_program, _ = self.get_concrete_program(
*desired_input_spec, with_hook=with_hook)
*desired_input_spec,
with_hook=with_hook,
is_train=self._is_train_mode())
return concrete_program
else:
raise ValueError(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
def drop_path(x, training=False):
    """Toy drop-path: identity in eval mode, scale input by 2 in train mode."""
    return 2 * x if training else x
class DropPath(paddle.nn.Layer):
    """Minimal Layer whose to_static forward depends on ``self.training``.

    Used to verify that the static-graph cache distinguishes train and
    eval traces of the same function.
    """

    def __init__(self):
        super(DropPath, self).__init__()

    @paddle.jit.to_static
    def forward(self, x):
        # Delegates to the module-level helper so the traced program
        # changes with the layer's train/eval mode.
        return drop_path(x, self.training)
class TestTrainEval(unittest.TestCase):
    """Check that a to_static model retraces after train()/eval() switches."""

    def setUp(self):
        self.model = DropPath()

    def tearDown(self):
        pass

    def test_train_and_eval(self):
        x = paddle.to_tensor([1, 2, 3]).astype("int64")
        expected_eval = x.numpy()
        expected_train = x.numpy() * 2
        # Train mode: drop_path doubles the input.
        self.model.train()
        self.assertTrue(np.allclose(self.model(x).numpy(), expected_train))
        # Eval mode: drop_path is the identity.
        self.model.eval()
        self.assertTrue(np.allclose(self.model(x).numpy(), expected_eval))
if __name__ == "__main__":
unittest.main()
......@@ -135,22 +135,23 @@ class TestWithTrainAndEval(unittest.TestCase):
x = fluid.dygraph.to_variable(x_data)
linear_net(x)
_, partial_layer = linear_net.forward.program_cache.last()[-1]
_, train_partial_layer = linear_net.forward.program_cache.last()[-1]
# check default mode is for training
self.assertEqual(partial_layer.program,
partial_layer._train_program)
self.assertEqual(train_partial_layer.program,
train_partial_layer._train_program)
# switch to run test program after `eval()`
linear_net.eval()
linear_net(x)
self.assertEqual(partial_layer.program,
partial_layer._infer_program)
_, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(eval_partial_layer.program,
eval_partial_layer._infer_program)
# switch back into training
linear_net.train()
linear_net(x)
self.assertEqual(partial_layer.program,
partial_layer._train_program)
self.assertEqual(train_partial_layer.program,
train_partial_layer._train_program)
class TestWithNoGrad(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register.