diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 2db9fb5d76a587cd44b061ce000b686d0499b445..4bfdc3c27fad628bba3fd16237c12d3ca43244d7 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -1,4 +1,5 @@ # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +21,7 @@ import warnings import functools from collections import OrderedDict import inspect +import threading import six import paddle @@ -525,6 +527,105 @@ def _build_load_path_and_config(path, config): return model_path, config +_save_pre_hooks_lock = threading.Lock() +_save_pre_hooks = [] + + +class HookRemoveHelper(object): + """ A HookRemoveHelper that can be used to remove hook. """ + + def __init__(self, hook): + self._hook = hook + + def remove(self): + _remove_save_pre_hook(self._hook) + + +def _register_save_pre_hook(hook): + """ + Register a save pre-hook for `paddle.jit.save`. + This hook will be executed before `save` function has been invoked. + + hook(layer, input_spec, configs) -> None + - layer (Layer|function): This argument is corresponding to `layer` in `paddle.jit.save`. + - input_spec (list or tuple[InputSpec|Tensor|Python built-in variable]): This argument is corresponding to `input_spec` in `paddle.jit.save`. + - configs (dict): This argument is corresponding to `configs` in `paddle.jit.save`. + + Args: + hook(function): a function registered as a save pre-hook + + Returns: + HookRemoveHelper: a HookRemoveHelper object that can be used to remove the added hook by calling `hook_remove_helper.remove()`. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + IMAGE_SIZE = 256 + CLASS_NUM = 10 + + class LinearNet(paddle.nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM) + + def forward(self, x): + return self._linear(x) + + saving_count = 0 + def save_pre_hook(layer, input_spec, configs): + global saving_count + saving_count += 1 + + remove_handler = paddle.jit.register_save_pre_hook(save_pre_hook) + + layer = LinearNet() + paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])]) + # saving_count == 1 + + remove_handler.remove() + paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])]) + # saving_count == 1 + """ + global _save_pre_hooks_lock + global _save_pre_hooks + _save_pre_hooks_lock.acquire() + if hook not in _save_pre_hooks: + _save_pre_hooks.append(hook) + _save_pre_hooks_lock.release() + return HookRemoveHelper(hook) + + +def _clear_save_pre_hooks(): + global _save_pre_hooks_lock + global _save_pre_hooks + _save_pre_hooks_lock.acquire() + _save_pre_hooks.clear() + _save_pre_hooks_lock.release() + + +def _remove_save_pre_hook(hook): + global _save_pre_hooks_lock + global _save_pre_hooks + _save_pre_hooks_lock.acquire() + if hook in _save_pre_hooks: + _save_pre_hooks.remove(hook) + _save_pre_hooks_lock.release() + + +def _run_save_pre_hooks(func): + def wrapper(layer, path, input_spec=None, **configs): + global _save_pre_hooks + for hook in _save_pre_hooks: + hook(layer, input_spec, configs) + func(layer, path, input_spec, **configs) + + return wrapper + + +@_run_save_pre_hooks @switch_to_static_graph def save(layer, path, input_spec=None, **configs): """ diff --git a/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py b/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..a938024e3c9b4e5f84e69afb99eb3e6bbdd18c01 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle +from paddle.fluid.dygraph.jit import _run_save_pre_hooks, _clear_save_pre_hooks, _register_save_pre_hook + +_counter = 0 + + +class TestPreSaveHooks(unittest.TestCase): + def test_pre_save_hook_functions(self): + def fake_func(*args, **kwgs): + global _counter + _counter += 1 + + remove_handler = _register_save_pre_hook(fake_func) + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1) + self.assertTrue( + paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func) + + # Test of avoiding redundancy hanging + remove_handler = _register_save_pre_hook(fake_func) + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 1) + self.assertTrue( + paddle.fluid.dygraph.jit._save_pre_hooks[0] is fake_func) + + remove_handler.remove() + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) + + remove_handler = _register_save_pre_hook(fake_func) + _clear_save_pre_hooks() + self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) + + global _counter + _counter = 0 + remove_handler = _register_save_pre_hook(fake_func) + func_with_hook = _run_save_pre_hooks(fake_func) + func_with_hook(None, None) + self.assertEqual(_counter, 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/jit/__init__.py b/python/paddle/jit/__init__.py index 576989e8e0d2aa019dc9ec7c7d69afa941f1dcb7..a2af493faca111b5bfb5ea5e78b44e34145a9568 100644 --- a/python/paddle/jit/__init__.py +++ b/python/paddle/jit/__init__.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.