Unverified · Commit 16ef2b2e authored by Aurelius84, committed by GitHub

[Dy2Stat] Fix losing pre/post hooks from the outermost layer while jit.save (#42273) (#42388)

* [Dy2Stat] Fix losing pre/post hooks from the outermost layer while jit.save

* fix kwargs

* fix unittest

Parent: 1e3d2e4a
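Background: before this fix, forward pre/post hooks registered on the outermost layer were silently dropped when tracing a program for paddle.jit.save; hooks on sublayers were already handled, since sublayer calls go through Layer.__call__ inside the traced forward. A minimal sketch of the affected scenario against Paddle 2.x (names here are illustrative, not taken from the patch):

import paddle

def double_output(layer, inputs, outputs):
    return outputs + outputs

net = paddle.jit.to_static(paddle.nn.Linear(4, 4))
# A hook on the outermost layer: before this fix, its ops were missing
# from the inference program written by paddle.jit.save.
net.register_forward_post_hook(double_output)

out = net(paddle.randn([2, 4]))        # trace once
paddle.jit.save(net, "./linear_hook")  # should now record the hook ops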
@@ -196,10 +196,11 @@ class CacheKey(object):
     def __hash__(self):
         error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)."
+        with_hook = self.kwargs.get("with_hook", False)
         return hash((id(self.function_spec),
                      make_hashable(self.input_args_with_spec, error_msg),
                      make_hashable(self.input_kwargs_with_spec, error_msg),
-                     self._spec_names_id, self.class_instance))
+                     self._spec_names_id, self.class_instance, with_hook))

     def __eq__(self, other):
         return (type(self) is type(other)) and hash(self) == hash(other)
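Why with_hook joins the hash: the same function can now yield two different traced programs, one with the outermost hook ops recorded and one without, so the cache key must keep them in separate slots. A self-contained sketch of the idea (a hypothetical Key class, not Paddle's CacheKey):

class Key(object):
    def __init__(self, spec_id, with_hook=False):
        self.spec_id = spec_id
        self.with_hook = with_hook

    def __hash__(self):
        return hash((self.spec_id, self.with_hook))

    def __eq__(self, other):
        # same equality-via-hash convention as CacheKey above
        return type(self) is type(other) and hash(self) == hash(other)

cache = {}
cache[Key(1, with_hook=False)] = "program without hook ops"
cache[Key(1, with_hook=True)] = "program with hook ops"
assert len(cache) == 2  # the two traces no longer collide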
@@ -413,6 +414,8 @@ class StaticFunction(object):
             Traced ConcreteProgram and executable translated Layer.
         """
+        with_hook = kwargs.get("with_hook", False)
+        if "with_hook" in kwargs: kwargs.pop("with_hook")
         # 1. unify args/kwargs and replace Tensor with InputSpec
         if len(args) != len(self._function_spec.args_name):
             args, kwargs = self._function_spec.unified_args_and_kwargs(args,
@@ -421,9 +424,13 @@ class StaticFunction(object):
             args, kwargs)

         # 2. generate cache key
-        cache_key = CacheKey(self._function_spec, input_args_with_spec,
-                             input_kwargs_with_spec, self._class_instance,
-                             **self._kwargs)
+        cache_key = CacheKey(
+            self._function_spec,
+            input_args_with_spec,
+            input_kwargs_with_spec,
+            self._class_instance,
+            **self._kwargs,
+            with_hook=with_hook)

         # 3. check whether hit the cache or build a new program for the input arguments
         concrete_program, partial_program_layer = self._program_cache[cache_key]
@@ -480,11 +487,13 @@ class StaticFunction(object):
         """
         return self.concrete_program_specify_input_spec(input_spec=None)

-    def concrete_program_specify_input_spec(self, input_spec=None):
+    def concrete_program_specify_input_spec(self,
+                                            input_spec=None,
+                                            with_hook=False):
         """
         Returns recent ConcreteProgram instance of decorated function while
         specifying input_spec. If the self._function_spec already has
-        input_spce, it will check the compatibility of input input_spec and
+        input_spec, it will check the compatibility of input input_spec and
         the self._function_spec.input_spec. If input input_spec=None, then
         this method uses self._function_spec.input_spec
@@ -516,12 +525,18 @@ class StaticFunction(object):
             has_input_spec = (desired_input_spec is not None)
             if has_input_spec:
                 concrete_program, _ = self.get_concrete_program(
-                    *desired_input_spec)
+                    *desired_input_spec, with_hook=with_hook)
                 return concrete_program
             else:
                 raise ValueError(
                     "No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".
                     format(self._function_spec))
+        elif with_hook:
+            cache_key = self._program_cache._recent_cache_key
+            cache_key.kwargs["with_hook"] = True
+            concrete_program, _ = self._program_cache[cache_key]
+            return concrete_program
         # If more than one programs have been cached, return the recent converted program by default.
         elif cached_program_len > 1:
             logging_utils.warn(
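The new elif with_hook: branch covers jit.save on a layer that was already traced by a normal forward call: it reuses the most recently used cache key, flips with_hook in its kwargs, and indexes the cache again. Because with_hook participates in __hash__, the mutated key misses and the program is retraced, this time with hook ops. A toy sketch of that mechanism (hypothetical names, mirroring ProgramCache below):

class RecentKey(object):
    def __init__(self):
        self.kwargs = {}

    def __hash__(self):
        return hash(self.kwargs.get("with_hook", False))

cache, recent_key = {}, RecentKey()
cache[hash(recent_key)] = "program traced without hooks"  # first forward call
recent_key.kwargs["with_hook"] = True                     # what the branch does
if hash(recent_key) not in cache:                         # hash changed: miss
    cache[hash(recent_key)] = "program retraced with hook ops"  # _build_once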
@@ -588,6 +603,54 @@ def _verify_init_in_dynamic_mode(class_instance):
                 class_instance))


+class HookHelper(object):
+    """
+    Only for converting pre/post hook operations of the outermost layer while
+    jit.save. Hooks in sublayers have already been processed automatically.
+    """
+
+    def __init__(self, func, class_instance, with_hook=False):
+        self.func = func
+        self.class_instance = class_instance
+        self.with_hook = with_hook
+        self.need_apply_hook = with_hook and isinstance(
+            self.class_instance,
+            layers.Layer) and getattr(func, "__name__") == "forward"
+
+    def apply_pre_hooks(self, inputs):
+        """
+        Apply _forward_pre_hooks from the outermost layer.
+        """
+        if not self.need_apply_hook:
+            return inputs
+
+        inputs = inputs[1:]
+        for forward_pre_hook in self.class_instance._forward_pre_hooks.values():
+            hook_result = forward_pre_hook(self.class_instance, inputs)
+            if hook_result is not None:
+                if not isinstance(hook_result, tuple):
+                    hook_result = (hook_result, )
+                inputs = hook_result
+
+        return [self.class_instance] + list(inputs)
+
+    def apply_post_hooks(self, inputs, outputs):
+        """
+        Apply _forward_post_hooks from the outermost layer.
+        """
+        if not self.need_apply_hook:
+            return outputs
+
+        inputs = inputs[1:]
+        for forward_post_hook in self.class_instance._forward_post_hooks.values():
+            hook_result = forward_post_hook(self.class_instance, inputs,
+                                            outputs)
+            if hook_result is not None:
+                outputs = hook_result
+
+        inputs.insert(0, self.class_instance)
+        return outputs
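HookHelper replays the Layer.__call__ hook contract at trace time: each pre hook may return replacement inputs (a non-tuple result is wrapped into a one-element tuple), and each post hook may return replacement outputs; the class_instance is peeled off the front of inputs before the hooks run and restored afterwards. A Paddle-free sketch of that contract (illustrative names):

def run_with_hooks(forward, pre_hooks, post_hooks, *inputs):
    layer = None  # stands in for the Layer instance passed to every hook
    for pre in pre_hooks:
        result = pre(layer, inputs)
        if result is not None:
            inputs = result if isinstance(result, tuple) else (result, )
    outputs = forward(*inputs)
    for post in post_hooks:
        result = post(layer, inputs, outputs)
        if result is not None:
            outputs = result
    return outputs

doubled = run_with_hooks(lambda x: x + 1,
                         [lambda layer, args: (args[0] * 2, )],
                         [lambda layer, args, out: out + out], 3)
assert doubled == 14  # ((3 * 2) + 1) * 2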
 class ConcreteProgram(object):

     __slots__ = [
@@ -629,6 +692,9 @@ class ConcreteProgram(object):
         # Transforms dygraph function into static function and caches it.
         dygraph_function = func_spec.dygraph_function
         static_func = convert_to_static(dygraph_function)
+        # apply pre/post hooks for the outermost layer
+        hook_helper = HookHelper(dygraph_function, class_instance,
+                                 kwargs.get("with_hook", False))

         main_program, startup_program = framework.Program(), framework.Program()
         # Note: The random seed should be synchronized into cached program
@@ -642,12 +708,13 @@ class ConcreteProgram(object):
         with framework.program_guard(main_program, startup_program):
             with _switch_declarative_mode_guard_(is_declarative=True):
                 # 1. Adds `fluid.data` layers for input if needed
-                inputs = func_spec.to_static_inputs_with_spec(input_spec,
-                                                              main_program)
+                static_inputs = func_spec.to_static_inputs_with_spec(
+                    input_spec, main_program)
                 _kwargs = func_spec.to_static_inputs_with_spec(
                     input_kwargs_spec, main_program)
                 if class_instance:
-                    inputs = tuple([class_instance] + list(inputs))
+                    static_inputs = tuple([class_instance] + list(
+                        static_inputs))

                 # 2. Gets all ParamBases and buffered VarBases in the function
                 all_parameters_and_buffers = _extract_indeed_params_buffers(
@@ -658,10 +725,13 @@ class ConcreteProgram(object):
                         class_instance, False)), param_guard(
                             get_buffers(class_instance, False)):
                     try:
+                        # only for jit.save: a no-op during train and eval
+                        inputs = hook_helper.apply_pre_hooks(static_inputs)
                         if _kwargs:
                             outputs = static_func(*inputs, **_kwargs)
                         else:
                             outputs = static_func(*inputs)
+                        outputs = hook_helper.apply_post_hooks(inputs, outputs)
                     except BaseException as e:
                         # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
                         error.attach_error_data(e)
@@ -679,7 +749,7 @@ class ConcreteProgram(object):
         main_program = update_op_callstack_with_origin_info(main_program)

         return ConcreteProgram(
-            inputs=inputs,
+            inputs=static_inputs,
             outputs=outputs,
             parameters=all_parameters_and_buffers,
             function=dygraph_function,
@@ -709,6 +779,7 @@ class ProgramCache(object):
         self._caches = collections.OrderedDict()
         # trace mostly recent used program
         self._recent_key = None
+        self._recent_cache_key = None

     def _build_once(self, cache_key):
         concrete_program = ConcreteProgram.from_func_spec(
@@ -724,6 +795,7 @@ class ProgramCache(object):
             raise ValueError('type(item) should be CacheKey, but received %s' %
                              type_name(item))
         item_id = hash(item)
+        self._recent_cache_key = item
         self._recent_key = item_id
         if item_id not in self._caches:
             self._caches[item_id] = self._build_once(item)
...
@@ -302,6 +302,7 @@ class _SaveLoadConfig(object):
         # If True, It will save inference program only, and do not save params of Program
         self._program_only = False
+        self.with_hook = False

     @property
     def output_spec(self):
@@ -370,7 +371,7 @@ class _SaveLoadConfig(object):

 def _parse_save_configs(configs):
-    supported_configs = ['output_spec']
+    supported_configs = ['output_spec', 'with_hook']

     # input check
     for key in configs:
@@ -382,6 +383,7 @@ def _parse_save_configs(configs):
     # construct inner config
     inner_config = _SaveLoadConfig()
     inner_config.output_spec = configs.get('output_spec', None)
+    inner_config.with_hook = configs.get('with_hook', False)

     return inner_config
@@ -454,11 +456,15 @@ def _get_input_var_names(inputs, input_spec):
     return result_list


-def _get_output_vars(outputs, output_spec):
+def _get_output_vars(outputs, output_spec, with_hook=False):
     name_no_exists_error = "The tensor `%s` does not exists. " \
         "Please make sure the name of example Tensor " \
         "in configs.output_spec is the output tensor of " \
         "Layer.forward method."
+    if output_spec and with_hook:
+        raise RuntimeError(
+            "Specifying output_spec is not supported while pre/post hooks are "
+            "registered on the outermost layer.")
     result_list = []
     output_vars_dict = OrderedDict()
     for var in flatten(outputs):
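with_hook is now accepted by paddle.jit.save as a **configs key alongside output_spec, and the two cannot be combined: when hooks are recorded, the patch raises the RuntimeError above rather than matching output_spec tensors against hook-modified outputs. A usage sketch (assuming Paddle 2.x; save also enables with_hook automatically when it finds hooks on the outermost layer, per the hunk below):

import paddle

net = paddle.jit.to_static(paddle.nn.Linear(10, 10))
net.register_forward_pre_hook(lambda layer, inputs: (inputs[0] * 2, ))
net(paddle.randn([4, 10]))                               # trace once
paddle.jit.save(net, "./net_with_hook", with_hook=True)  # explicit opt-in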
@@ -830,10 +836,16 @@ def save(layer, path, input_spec=None, **configs):

     # parse configs
     configs = _parse_save_configs(configs)
+    # whether the outermost layer has pre/post hooks; if it does, these
+    # operators must also be saved into the program
+    with_hook = configs.with_hook

     scope = core.Scope()
     extra_var_info = dict()
     if isinstance(layer, Layer):
         functions = dir(inner_layer)
+        if inner_layer._forward_pre_hooks or inner_layer._forward_post_hooks:
+            with_hook = True
     else:
         # layer is function
         functions = [layer, ]
@@ -842,7 +854,7 @@ def save(layer, path, input_spec=None, **configs):
         static_func = getattr(inner_layer, attr_func, None)
         if isinstance(static_func, StaticFunction):
             concrete_program = static_func.concrete_program_specify_input_spec(
-                inner_input_spec)
+                inner_input_spec, with_hook=with_hook)
         elif 'forward' == attr_func:
             # transform in jit.save, if input_spec is incomplete, declarative will throw error
             # inner_input_spec is list[InputSpec], it should be packed with same structure
@@ -852,7 +864,8 @@ def save(layer, path, input_spec=None, **configs):
                 inner_input_spec)
             static_forward = declarative(
                 inner_layer.forward, input_spec=inner_input_spec)
-            concrete_program = static_forward.concrete_program
+            concrete_program = static_forward.concrete_program_specify_input_spec(
+                with_hook=with_hook)
             # the input_spec has been used in declarative, which is equal to
             # @declarative with input_spec and jit.save without input_spec,
             # avoid needless warning
@@ -943,8 +956,10 @@ def save(layer, path, input_spec=None, **configs):
         # the rule is like [ Get input variables name ]. For output var,
         # we only support VarBase spec, and actually, we only need the
         # var name of output, and we don't recommended to use output_spec
-        output_vars = _get_output_vars(concrete_program.outputs,
-                                       configs.output_spec)
+        output_vars = _get_output_vars(concrete_program.outputs,
+                                       configs.output_spec, with_hook)

     # 5. save inference model
     from paddle.fluid.io import save_inference_model
...
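End to end, the hook ops now live inside the saved inference program itself, so a loaded TranslatedLayer reproduces the hooked outputs without any hook being re-registered. A load-side sketch (assuming a model saved as in the snippet above):

import paddle

loaded = paddle.jit.load("./net_with_hook")
y = loaded(paddle.randn([4, 10]))  # pre/post hook ops execute as graph ops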
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle


def forward_post_hook1(layer, input, output):
    return output + output


def forward_pre_hook1(layer, input):
    input_return = (input[0] * 2, )
    return input_return


class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = paddle.nn.Linear(10, 10)
        # sublayer fc1 registers a post hook
        self.fc1.register_forward_post_hook(forward_post_hook1)

        self.fc2 = paddle.nn.Linear(10, 10)
        # sublayer fc2 registers a pre hook
        self.fc2.register_forward_pre_hook(forward_pre_hook1)

        # the outermost layer registers both pre and post hooks
        self.register_forward_pre_hook(forward_pre_hook1)
        self.register_forward_post_hook(forward_post_hook1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        out = paddle.mean(x)
        return out


class TestNestLayerHook(unittest.TestCase):
    def setUp(self):
        paddle.seed(2022)
        self.x = paddle.randn([4, 10])
        self.path = "./net_hook"

    def train_net(self, to_static=False):
        paddle.seed(2022)
        net = SimpleNet()
        if to_static:
            net = paddle.jit.to_static(net)
        out = net(self.x)

        if to_static:
            paddle.jit.save(net, self.path)

        return out.numpy()[0]

    def load_train(self):
        net = paddle.jit.load(self.path)
        out = net(self.x)
        return out.numpy()[0]

    def test_hook(self):
        dy_out = self.train_net(to_static=False)
        st_out = self.train_net(to_static=True)
        load_out = self.load_train()
        self.assertTrue(
            np.allclose(st_out, dy_out),
            msg='dygraph_res is {}\nstatic_res is {}'.format(dy_out, st_out))
        self.assertTrue(
            np.allclose(st_out, load_out),
            msg='load_out is {}\nstatic_res is {}'.format(load_out, st_out))


if __name__ == "__main__":
    unittest.main()
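For intuition, the hooks in SimpleNet compose in this order: the outermost pre hook doubles x, fc1 runs, fc1's post hook doubles its output, fc2's pre hook doubles its input, fc2 runs, the mean is taken, and the outermost post hook doubles the result. The same dataflow in plain NumPy (a sketch with stand-in random weights, not the test's parameters):

import numpy as np

def simple_net_with_hooks(x, w1, b1, w2, b2):
    x = x * 2                # outermost forward_pre_hook1
    h = x @ w1 + b1          # fc1
    h = h + h                # fc1's forward_post_hook1
    h = (h * 2) @ w2 + b2    # fc2's forward_pre_hook1, then fc2
    return 2 * h.mean()      # paddle.mean, then outermost forward_post_hook1

rng = np.random.default_rng(2022)
x = rng.standard_normal((4, 10))
w1, b1 = rng.standard_normal((10, 10)), rng.standard_normal(10)
w2, b2 = rng.standard_normal((10, 10)), rng.standard_normal(10)
print(simple_net_with_hooks(x, w1, b1, w2, b2))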