未验证 提交 27cf7afb 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Fix losting pre/post hook from outermost layer while jit.save (#42273)

* [Dy2Stat]Fix losting pre/post hook from outermost layer while jit.save

* fix kwargs

* fix unittest
上级 7f14f78c
......@@ -196,10 +196,11 @@ class CacheKey(object):
def __hash__(self):
error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)."
with_hook = self.kwargs.get("with_hook", False)
return hash((id(self.function_spec),
make_hashable(self.input_args_with_spec, error_msg),
make_hashable(self.input_kwargs_with_spec, error_msg),
self._spec_names_id, self.class_instance))
self._spec_names_id, self.class_instance, with_hook))
def __eq__(self, other):
return (type(self) is type(other)) and hash(self) == hash(other)
......@@ -413,6 +414,8 @@ class StaticFunction(object):
Traced ConcreteProgram and executable translated Layer.
"""
with_hook = kwargs.get("with_hook", False)
if "with_hook" in kwargs: kwargs.pop("with_hook")
# 1. unify args/kwargs and replace Tensor with InputSpec
if len(args) != len(self._function_spec.args_name):
args, kwargs = self._function_spec.unified_args_and_kwargs(args,
......@@ -421,9 +424,13 @@ class StaticFunction(object):
args, kwargs)
# 2. generate cache key
cache_key = CacheKey(self._function_spec, input_args_with_spec,
input_kwargs_with_spec, self._class_instance,
**self._kwargs)
cache_key = CacheKey(
self._function_spec,
input_args_with_spec,
input_kwargs_with_spec,
self._class_instance,
**self._kwargs,
with_hook=with_hook)
# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key]
......@@ -480,11 +487,13 @@ class StaticFunction(object):
"""
return self.concrete_program_specify_input_spec(input_spec=None)
def concrete_program_specify_input_spec(self, input_spec=None):
def concrete_program_specify_input_spec(self,
input_spec=None,
with_hook=False):
"""
Returns recent ConcreteProgram instance of decorated function while
specifying input_spec. If the self._function_spec already has
input_spce, it will check the compatibility of input input_spec and
input_spec, it will check the compatibility of input input_spec and
the self._function_spec.input_spec. If input input_spec=None, then
this method uses self._function_spec.input_spec
......@@ -516,12 +525,18 @@ class StaticFunction(object):
has_input_spec = (desired_input_spec is not None)
if has_input_spec:
concrete_program, _ = self.get_concrete_program(
*desired_input_spec)
*desired_input_spec, with_hook=with_hook)
return concrete_program
else:
raise ValueError(
"No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".
format(self._function_spec))
elif with_hook:
cache_key = self._program_cache._recent_cache_key
cache_key.kwargs["with_hook"] = True
concrete_program, _ = self._program_cache[cache_key]
return concrete_program
# If more than one programs have been cached, return the recent converted program by default.
elif cached_program_len > 1:
logging_utils.warn(
......@@ -588,6 +603,54 @@ def _verify_init_in_dynamic_mode(class_instance):
class_instance))
class HookHelper(object):
"""
Only For converting pre/post hooks operation in outermost layer while jit.save.
Because hooks in sublayer have been processed automatically.
"""
def __init__(self, func, class_instance, with_hook=False):
self.func = func
self.class_instance = class_instance
self.with_hook = with_hook
self.need_apply_hook = with_hook and isinstance(
self.class_instance,
layers.Layer) and getattr(func, "__name__") == "forward"
def apply_pre_hooks(self, inputs):
"""
Apply _forward_pre_hooks from outermost layer
"""
if not self.need_apply_hook: return inputs
inputs = inputs[1:]
for forward_pre_hook in self.class_instance._forward_pre_hooks.values():
hook_result = forward_pre_hook(self.class_instance, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result
return [self.class_instance] + list(inputs)
def apply_post_hooks(self, inputs, outputs):
"""
Apply _forward_post_hooks from outermost layer
"""
if not self.need_apply_hook: return outputs
inputs = inputs[1:]
for forward_post_hook in self.class_instance._forward_post_hooks.values(
):
hook_result = forward_post_hook(self.class_instance, inputs,
outputs)
if hook_result is not None:
outputs = hook_result
inputs.insert(0, self.class_instance)
return outputs
class ConcreteProgram(object):
__slots__ = [
......@@ -629,6 +692,9 @@ class ConcreteProgram(object):
# Transforms dygraph function into static function and caches it.
dygraph_function = func_spec.dygraph_function
static_func = convert_to_static(dygraph_function)
# apply pre\post hook for outermost layer
hook_helper = HookHelper(dygraph_function, class_instance,
kwargs.get("with_hook", False))
main_program, startup_program = framework.Program(), framework.Program()
# Note: The random seed should be synchronized into cached program
......@@ -642,12 +708,13 @@ class ConcreteProgram(object):
with framework.program_guard(main_program, startup_program):
with _switch_declarative_mode_guard_(is_declarative=True):
# 1. Adds `fluid.data` layers for input if needed
inputs = func_spec.to_static_inputs_with_spec(input_spec,
main_program)
static_inputs = func_spec.to_static_inputs_with_spec(
input_spec, main_program)
_kwargs = func_spec.to_static_inputs_with_spec(
input_kwargs_spec, main_program)
if class_instance:
inputs = tuple([class_instance] + list(inputs))
static_inputs = tuple([class_instance] + list(
static_inputs))
# 2. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = _extract_indeed_params_buffers(
......@@ -658,10 +725,13 @@ class ConcreteProgram(object):
class_instance, False)), param_guard(
get_buffers(class_instance, False)):
try:
# only for jit.save, do nothing while train and eval process
inputs = hook_helper.apply_pre_hooks(static_inputs)
if _kwargs:
outputs = static_func(*inputs, **_kwargs)
else:
outputs = static_func(*inputs)
outputs = hook_helper.apply_post_hooks(inputs, outputs)
except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
error.attach_error_data(e)
......@@ -679,7 +749,7 @@ class ConcreteProgram(object):
main_program = update_op_callstack_with_origin_info(main_program)
return ConcreteProgram(
inputs=inputs,
inputs=static_inputs,
outputs=outputs,
parameters=all_parameters_and_buffers,
function=dygraph_function,
......@@ -709,6 +779,7 @@ class ProgramCache(object):
self._caches = collections.OrderedDict()
# trace mostly recent used program
self._recent_key = None
self._recent_cache_key = None
def _build_once(self, cache_key):
concrete_program = ConcreteProgram.from_func_spec(
......@@ -724,6 +795,7 @@ class ProgramCache(object):
raise ValueError('type(item) should be CacheKey, but received %s' %
type_name(item))
item_id = hash(item)
self._recent_cache_key = item
self._recent_key = item_id
if item_id not in self._caches:
self._caches[item_id] = self._build_once(item)
......
......@@ -302,6 +302,7 @@ class _SaveLoadConfig(object):
# If True, It will save inference program only, and do not save params of Program
self._program_only = False
self.with_hook = False
@property
def output_spec(self):
......@@ -370,7 +371,7 @@ class _SaveLoadConfig(object):
def _parse_save_configs(configs):
supported_configs = ['output_spec']
supported_configs = ['output_spec', "with_hook"]
# input check
for key in configs:
......@@ -382,6 +383,7 @@ def _parse_save_configs(configs):
# construct inner config
inner_config = _SaveLoadConfig()
inner_config.output_spec = configs.get('output_spec', None)
inner_config.with_hook = configs.get('with_hook', False)
return inner_config
......@@ -454,11 +456,15 @@ def _get_input_var_names(inputs, input_spec):
return result_list
def _get_output_vars(outputs, output_spec):
def _get_output_vars(outputs, output_spec, with_hook=False):
name_no_exists_error = "The tensor `%s` does not exists. " \
"Please make sure the name of example Tensor " \
"in configs.output_spec is the output tensor of " \
"Layer.forward method."
if output_spec and with_hook:
raise RuntimeError(
"Currently not support specify output_spec while founding pre/post hooks in your outermost layer."
)
result_list = []
output_vars_dict = OrderedDict()
for var in flatten(outputs):
......@@ -830,10 +836,16 @@ def save(layer, path, input_spec=None, **configs):
# parse configs
configs = _parse_save_configs(configs)
# whether outermost layer has pre/post hook, if does, we need also save
# these operators in program.
with_hook = configs.with_hook
scope = core.Scope()
extra_var_info = dict()
if isinstance(layer, Layer):
functions = dir(inner_layer)
if inner_layer._forward_pre_hooks or inner_layer._forward_post_hooks:
with_hook = True
else:
# layer is function
functions = [layer, ]
......@@ -842,7 +854,7 @@ def save(layer, path, input_spec=None, **configs):
static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func, StaticFunction):
concrete_program = static_func.concrete_program_specify_input_spec(
inner_input_spec)
inner_input_spec, with_hook=with_hook)
elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same structure
......@@ -852,7 +864,8 @@ def save(layer, path, input_spec=None, **configs):
inner_input_spec)
static_forward = declarative(
inner_layer.forward, input_spec=inner_input_spec)
concrete_program = static_forward.concrete_program
concrete_program = static_forward.concrete_program_specify_input_spec(
with_hook=with_hook)
# the input_spec has been used in declarative, which is equal to
# @declarative with input_spec and jit.save without input_spec,
# avoid needless warning
......@@ -943,8 +956,10 @@ def save(layer, path, input_spec=None, **configs):
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of output, and we don't recommended to use output_spec
# print(concrete_program.main_program)
# print(concrete_program.outputs, configs.output_spec)
output_vars = _get_output_vars(concrete_program.outputs,
configs.output_spec)
configs.output_spec, with_hook)
# 5. save inference model
from paddle.fluid.io import save_inference_model
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
def forward_post_hook1(layer, input, output):
return output + output
def forward_pre_hook1(layer, input):
input_return = (input[0] * 2, )
return input_return
class SimpleNet(paddle.nn.Layer):
def __init__(self, ):
super(SimpleNet, self).__init__()
self.fc1 = paddle.nn.Linear(10, 10)
# sublayer1 register post hook
self.fc1.register_forward_post_hook(forward_post_hook1)
self.fc2 = paddle.nn.Linear(10, 10)
# sublayer2 register pre hook
self.fc2.register_forward_pre_hook(forward_pre_hook1)
# register pre/post hook
self.register_forward_pre_hook(forward_pre_hook1)
self.register_forward_post_hook(forward_post_hook1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
out = paddle.mean(x)
return out
class TestNestLayerHook(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([4, 10])
self.path = "./net_hook"
def train_net(self, to_static=False):
paddle.seed(2022)
net = SimpleNet()
if to_static:
net = paddle.jit.to_static(net)
out = net(self.x)
if to_static:
paddle.jit.save(net, self.path)
return out.numpy()[0]
def load_train(self):
net = paddle.jit.load(self.path)
out = net(self.x)
return out.numpy()[0]
def test_hook(self):
dy_out = self.train_net(to_static=False)
st_out = self.train_net(to_static=True)
load_out = self.load_train()
print(st_out, dy_out, load_out)
self.assertTrue(
np.allclose(st_out, dy_out),
msg='dygraph_res is {}\nstatic_res is {}'.format(dy_out, st_out))
self.assertTrue(
np.allclose(st_out, load_out),
msg='load_out is {}\nstatic_res is {}'.format(load_out, st_out))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册