From 16ef2b2e3b456257c945a7167b9ff787607c74b2 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sat, 30 Apr 2022 15:11:42 +0800 Subject: [PATCH] [Dy2Stat]Fix losting pre/post hook from outermost layer while jit.save (#42273) (#42388) * [Dy2Stat]Fix losting pre/post hook from outermost layer while jit.save * fix kwargs * fix unittest --- .../dygraph_to_static/program_translator.py | 94 ++++++++++++++++--- python/paddle/fluid/dygraph/jit.py | 25 ++++- .../dygraph_to_static/test_layer_hook.py | 90 ++++++++++++++++++ 3 files changed, 193 insertions(+), 16 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_layer_hook.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index bc1a0e30dd..b860740f71 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -196,10 +196,11 @@ class CacheKey(object): def __hash__(self): error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)." + with_hook = self.kwargs.get("with_hook", False) return hash((id(self.function_spec), make_hashable(self.input_args_with_spec, error_msg), make_hashable(self.input_kwargs_with_spec, error_msg), - self._spec_names_id, self.class_instance)) + self._spec_names_id, self.class_instance, with_hook)) def __eq__(self, other): return (type(self) is type(other)) and hash(self) == hash(other) @@ -413,6 +414,8 @@ class StaticFunction(object): Traced ConcreteProgram and executable translated Layer. """ + with_hook = kwargs.get("with_hook", False) + if "with_hook" in kwargs: kwargs.pop("with_hook") # 1. unify args/kwargs and replace Tensor with InputSpec if len(args) != len(self._function_spec.args_name): args, kwargs = self._function_spec.unified_args_and_kwargs(args, @@ -421,9 +424,13 @@ class StaticFunction(object): args, kwargs) # 2. generate cache key - cache_key = CacheKey(self._function_spec, input_args_with_spec, - input_kwargs_with_spec, self._class_instance, - **self._kwargs) + cache_key = CacheKey( + self._function_spec, + input_args_with_spec, + input_kwargs_with_spec, + self._class_instance, + **self._kwargs, + with_hook=with_hook) # 3. check whether hit the cache or build a new program for the input arguments concrete_program, partial_program_layer = self._program_cache[cache_key] @@ -480,11 +487,13 @@ class StaticFunction(object): """ return self.concrete_program_specify_input_spec(input_spec=None) - def concrete_program_specify_input_spec(self, input_spec=None): + def concrete_program_specify_input_spec(self, + input_spec=None, + with_hook=False): """ Returns recent ConcreteProgram instance of decorated function while specifying input_spec. If the self._function_spec already has - input_spce, it will check the compatibility of input input_spec and + input_spec, it will check the compatibility of input input_spec and the self._function_spec.input_spec. If input input_spec=None, then this method uses self._function_spec.input_spec @@ -516,12 +525,18 @@ class StaticFunction(object): has_input_spec = (desired_input_spec is not None) if has_input_spec: concrete_program, _ = self.get_concrete_program( - *desired_input_spec) + *desired_input_spec, with_hook=with_hook) return concrete_program else: raise ValueError( "No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n". format(self._function_spec)) + elif with_hook: + cache_key = self._program_cache._recent_cache_key + cache_key.kwargs["with_hook"] = True + concrete_program, _ = self._program_cache[cache_key] + return concrete_program + # If more than one programs have been cached, return the recent converted program by default. elif cached_program_len > 1: logging_utils.warn( @@ -588,6 +603,54 @@ def _verify_init_in_dynamic_mode(class_instance): class_instance)) +class HookHelper(object): + """ + Only For converting pre/post hooks operation in outermost layer while jit.save. + Because hooks in sublayer have been processed automatically. + """ + + def __init__(self, func, class_instance, with_hook=False): + self.func = func + self.class_instance = class_instance + self.with_hook = with_hook + self.need_apply_hook = with_hook and isinstance( + self.class_instance, + layers.Layer) and getattr(func, "__name__") == "forward" + + def apply_pre_hooks(self, inputs): + """ + Apply _forward_pre_hooks from outermost layer + """ + if not self.need_apply_hook: return inputs + + inputs = inputs[1:] + for forward_pre_hook in self.class_instance._forward_pre_hooks.values(): + hook_result = forward_pre_hook(self.class_instance, inputs) + if hook_result is not None: + if not isinstance(hook_result, tuple): + hook_result = (hook_result, ) + inputs = hook_result + + return [self.class_instance] + list(inputs) + + def apply_post_hooks(self, inputs, outputs): + """ + Apply _forward_post_hooks from outermost layer + """ + if not self.need_apply_hook: return outputs + + inputs = inputs[1:] + for forward_post_hook in self.class_instance._forward_post_hooks.values( + ): + hook_result = forward_post_hook(self.class_instance, inputs, + outputs) + if hook_result is not None: + outputs = hook_result + + inputs.insert(0, self.class_instance) + return outputs + + class ConcreteProgram(object): __slots__ = [ @@ -629,6 +692,9 @@ class ConcreteProgram(object): # Transforms dygraph function into static function and caches it. dygraph_function = func_spec.dygraph_function static_func = convert_to_static(dygraph_function) + # apply pre\post hook for outermost layer + hook_helper = HookHelper(dygraph_function, class_instance, + kwargs.get("with_hook", False)) main_program, startup_program = framework.Program(), framework.Program() # Note: The random seed should be synchronized into cached program @@ -642,12 +708,13 @@ class ConcreteProgram(object): with framework.program_guard(main_program, startup_program): with _switch_declarative_mode_guard_(is_declarative=True): # 1. Adds `fluid.data` layers for input if needed - inputs = func_spec.to_static_inputs_with_spec(input_spec, - main_program) + static_inputs = func_spec.to_static_inputs_with_spec( + input_spec, main_program) _kwargs = func_spec.to_static_inputs_with_spec( input_kwargs_spec, main_program) if class_instance: - inputs = tuple([class_instance] + list(inputs)) + static_inputs = tuple([class_instance] + list( + static_inputs)) # 2. Gets all ParamBases and buffered VarBases in the function all_parameters_and_buffers = _extract_indeed_params_buffers( @@ -658,10 +725,13 @@ class ConcreteProgram(object): class_instance, False)), param_guard( get_buffers(class_instance, False)): try: + # only for jit.save, do nothing while train and eval process + inputs = hook_helper.apply_pre_hooks(static_inputs) if _kwargs: outputs = static_func(*inputs, **_kwargs) else: outputs = static_func(*inputs) + outputs = hook_helper.apply_post_hooks(inputs, outputs) except BaseException as e: # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here. error.attach_error_data(e) @@ -679,7 +749,7 @@ class ConcreteProgram(object): main_program = update_op_callstack_with_origin_info(main_program) return ConcreteProgram( - inputs=inputs, + inputs=static_inputs, outputs=outputs, parameters=all_parameters_and_buffers, function=dygraph_function, @@ -709,6 +779,7 @@ class ProgramCache(object): self._caches = collections.OrderedDict() # trace mostly recent used program self._recent_key = None + self._recent_cache_key = None def _build_once(self, cache_key): concrete_program = ConcreteProgram.from_func_spec( @@ -724,6 +795,7 @@ class ProgramCache(object): raise ValueError('type(item) should be CacheKey, but received %s' % type_name(item)) item_id = hash(item) + self._recent_cache_key = item self._recent_key = item_id if item_id not in self._caches: self._caches[item_id] = self._build_once(item) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 7957b33bf1..e0e259215c 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -302,6 +302,7 @@ class _SaveLoadConfig(object): # If True, It will save inference program only, and do not save params of Program self._program_only = False + self.with_hook = False @property def output_spec(self): @@ -370,7 +371,7 @@ class _SaveLoadConfig(object): def _parse_save_configs(configs): - supported_configs = ['output_spec'] + supported_configs = ['output_spec', "with_hook"] # input check for key in configs: @@ -382,6 +383,7 @@ def _parse_save_configs(configs): # construct inner config inner_config = _SaveLoadConfig() inner_config.output_spec = configs.get('output_spec', None) + inner_config.with_hook = configs.get('with_hook', False) return inner_config @@ -454,11 +456,15 @@ def _get_input_var_names(inputs, input_spec): return result_list -def _get_output_vars(outputs, output_spec): +def _get_output_vars(outputs, output_spec, with_hook=False): name_no_exists_error = "The tensor `%s` does not exists. " \ "Please make sure the name of example Tensor " \ "in configs.output_spec is the output tensor of " \ "Layer.forward method." + if output_spec and with_hook: + raise RuntimeError( + "Currently not support specify output_spec while founding pre/post hooks in your outermost layer." + ) result_list = [] output_vars_dict = OrderedDict() for var in flatten(outputs): @@ -830,10 +836,16 @@ def save(layer, path, input_spec=None, **configs): # parse configs configs = _parse_save_configs(configs) + # whether outermost layer has pre/post hook, if does, we need also save + # these operators in program. + with_hook = configs.with_hook + scope = core.Scope() extra_var_info = dict() if isinstance(layer, Layer): functions = dir(inner_layer) + if inner_layer._forward_pre_hooks or inner_layer._forward_post_hooks: + with_hook = True else: # layer is function functions = [layer, ] @@ -842,7 +854,7 @@ def save(layer, path, input_spec=None, **configs): static_func = getattr(inner_layer, attr_func, None) if isinstance(static_func, StaticFunction): concrete_program = static_func.concrete_program_specify_input_spec( - inner_input_spec) + inner_input_spec, with_hook=with_hook) elif 'forward' == attr_func: # transform in jit.save, if input_spec is incomplete, declarative will throw error # inner_input_spec is list[InputSpec], it should be packed with same structure @@ -852,7 +864,8 @@ def save(layer, path, input_spec=None, **configs): inner_input_spec) static_forward = declarative( inner_layer.forward, input_spec=inner_input_spec) - concrete_program = static_forward.concrete_program + concrete_program = static_forward.concrete_program_specify_input_spec( + with_hook=with_hook) # the input_spec has been used in declarative, which is equal to # @declarative with input_spec and jit.save without input_spec, # avoid needless warning @@ -943,8 +956,10 @@ def save(layer, path, input_spec=None, **configs): # the rule is like [ Get input variables name ]. For output var, # we only support VarBase spec, and actually, we only need the # var name of output, and we don't recommended to use output_spec + # print(concrete_program.main_program) + # print(concrete_program.outputs, configs.output_spec) output_vars = _get_output_vars(concrete_program.outputs, - configs.output_spec) + configs.output_spec, with_hook) # 5. save inference model from paddle.fluid.io import save_inference_model diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_layer_hook.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_layer_hook.py new file mode 100644 index 0000000000..dcb41cfc6a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_layer_hook.py @@ -0,0 +1,90 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle + +import numpy as np + + +def forward_post_hook1(layer, input, output): + return output + output + + +def forward_pre_hook1(layer, input): + input_return = (input[0] * 2, ) + return input_return + + +class SimpleNet(paddle.nn.Layer): + def __init__(self, ): + super(SimpleNet, self).__init__() + self.fc1 = paddle.nn.Linear(10, 10) + # sublayer1 register post hook + self.fc1.register_forward_post_hook(forward_post_hook1) + + self.fc2 = paddle.nn.Linear(10, 10) + # sublayer2 register pre hook + self.fc2.register_forward_pre_hook(forward_pre_hook1) + + # register pre/post hook + self.register_forward_pre_hook(forward_pre_hook1) + self.register_forward_post_hook(forward_post_hook1) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + out = paddle.mean(x) + + return out + + +class TestNestLayerHook(unittest.TestCase): + def setUp(self): + paddle.seed(2022) + self.x = paddle.randn([4, 10]) + self.path = "./net_hook" + + def train_net(self, to_static=False): + paddle.seed(2022) + net = SimpleNet() + if to_static: + net = paddle.jit.to_static(net) + out = net(self.x) + + if to_static: + paddle.jit.save(net, self.path) + + return out.numpy()[0] + + def load_train(self): + net = paddle.jit.load(self.path) + out = net(self.x) + return out.numpy()[0] + + def test_hook(self): + dy_out = self.train_net(to_static=False) + st_out = self.train_net(to_static=True) + load_out = self.load_train() + print(st_out, dy_out, load_out) + self.assertTrue( + np.allclose(st_out, dy_out), + msg='dygraph_res is {}\nstatic_res is {}'.format(dy_out, st_out)) + self.assertTrue( + np.allclose(st_out, load_out), + msg='load_out is {}\nstatic_res is {}'.format(load_out, st_out)) + + +if __name__ == "__main__": + unittest.main() -- GitLab