未验证 提交 e91f7c02 编写于 作者: M Ming-Xu Huang 提交者: GitHub

Jit pre save hook (#38186)

* Pre-save hooks of jit.save

1. Added pre_save_hooks features to jit.save.
2. Added related unittests

* Added jit pre_save_hooks functions's alias to paddle.jit and copyright.

* Make jit.save_pre_hook style be consisent with Paddle's rule.

* Fixed arguments passing bug in run_save_pre_hooks

* Added API Documents

* Move clear and run_pre_save_hooks as internal methonds only.

* Made register_save_pre_hook as an internal function.
上级 d3686471
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,6 +21,7 @@ import warnings
import functools
from collections import OrderedDict
import inspect
import threading
import six
import paddle
......@@ -525,6 +527,105 @@ def _build_load_path_and_config(path, config):
return model_path, config
_save_pre_hooks_lock = threading.Lock()
_save_pre_hooks = []
class HookRemoveHelper(object):
""" A HookRemoveHelper that can be used to remove hook. """
def __init__(self, hook):
self._hook = hook
def remove(self):
_remove_save_pre_hook(self._hook)
def _register_save_pre_hook(hook):
"""
Register a save pre-hook for `paddle.jit.save`.
This hook will be executed before `save` function has been invoked.
hook(layer, input_spec, configs) -> None
- layer (Layer|function): This argument is corresponding to `layer` in `paddle.jit.save`.
- input_spec (list or tuple[InputSpec|Tensor|Python built-in variable]): This argument is corresponding to `input_spec` in `paddle.jit.save`.
- configs (dict): This argument is corresponding to `configs` in `paddle.jit.save`.
Args:
hook(function): a function registered as a save pre-hook
Returns:
HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`.
Examples:
.. code-block:: python
import numpy as np
import paddle
IMAGE_SIZE = 256
CLASS_NUM = 10
class LinearNet(paddle.nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM)
def forward(self, x):
return self._linear(x)
saving_count = 0
def save_pre_hook(layer, input_spec, configs):
global saving_count
saving_count += 1
remove_handler = paddle.jit.register_save_pre_hook(save_pre_hook)
layer = LinearNet()
paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
# saving_count == 1
remove_handler.remove()
paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
# saving_count == 1
"""
global _save_pre_hooks_lock
global _save_pre_hooks
_save_pre_hooks_lock.acquire()
if hook not in _save_pre_hooks:
_save_pre_hooks.append(hook)
_save_pre_hooks_lock.release()
return HookRemoveHelper(hook)
def _clear_save_pre_hooks():
global _save_pre_hooks_lock
global _save_pre_hooks
_save_pre_hooks_lock.acquire()
_save_pre_hooks.clear()
_save_pre_hooks_lock.release()
def _remove_save_pre_hook(hook):
global _save_pre_hooks_lock
global _save_pre_hooks
_save_pre_hooks_lock.acquire()
if hook in _save_pre_hooks:
_save_pre_hooks.remove(hook)
_save_pre_hooks_lock.release()
def _run_save_pre_hooks(func):
def wrapper(layer, path, input_spec=None, **configs):
global _save_pre_hooks
for hook in _save_pre_hooks:
hook(layer, input_spec, configs)
func(layer, path, input_spec, **configs)
return wrapper
@_run_save_pre_hooks
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
"""
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
from paddle.fluid.dygraph.jit import _run_save_pre_hooks, _clear_save_pre_hooks, _register_save_pre_hook
_counter = 0
class TestPreSaveHooks(unittest.TestCase):
def test_pre_save_hook_functions(self):
def fake_func(*args, **kwgs):
global _counter
_counter += 1
remove_handler = _register_save_pre_hook(fake_func)
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1)
self.assertTrue(
paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func)
# Test of avoiding redundancy hanging
remove_handler = _register_save_pre_hook(fake_func)
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1)
self.assertTrue(
paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func)
remove_handler.remove()
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0)
remove_handler = _register_save_pre_hook(fake_func)
_clear_save_pre_hooks()
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0)
global _counter
_counter = 0
remove_handler = _register_save_pre_hook(fake_func)
func_with_hook = _run_save_pre_hooks(fake_func)
func_with_hook(None, None)
self.assertEqual(_counter, 2)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册