From daea892c67e85da91906864de40ce9f6f1b893ae Mon Sep 17 00:00:00 2001 From: hjyp <53164956+Tomoko-hjf@users.noreply.github.com> Date: Tue, 10 Jan 2023 14:45:39 +0800 Subject: [PATCH] [Dy2St] Add ignore_module API (#49485) * Add ignore_module API * fix type of parameter * Add test case of ignore-module --- .../dygraph_to_static/test_ignore_module.py | 36 +++++++++++++++++++ python/paddle/jit/__init__.py | 2 ++ python/paddle/jit/api.py | 31 +++++++++++++++- .../paddle/jit/dy2static/convert_call_func.py | 11 ++++++ 4 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_ignore_module.py diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ignore_module.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ignore_module.py new file mode 100644 index 0000000000..d1ef9e0be2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ignore_module.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import astor +import scipy + +from paddle.jit import ignore_module +from paddle.jit.dy2static.convert_call_func import BUILTIN_LIKELY_MODULES + + +class TestIgnoreModule(unittest.TestCase): + def test_ignore_module(self): + modules = [scipy, astor] + ignore_module(modules) + self.assertEquals( + [scipy, astor], + BUILTIN_LIKELY_MODULES[-2:], + 'Failed to add modules that ignore transcription', + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/jit/__init__.py b/python/paddle/jit/__init__.py index fd5ca115c2..00f5f60943 100644 --- a/python/paddle/jit/__init__.py +++ b/python/paddle/jit/__init__.py @@ -17,6 +17,7 @@ from .api import save from .api import load from .api import to_static from .api import not_to_static +from .api import ignore_module from .dy2static.logging_utils import set_code_level, set_verbosity from . import dy2static @@ -27,6 +28,7 @@ __all__ = [ # noqa 'save', 'load', 'to_static', + 'ignore_module', 'ProgramTranslator', 'TranslatedLayer', 'set_code_level', diff --git a/python/paddle/jit/api.py b/python/paddle/jit/api.py index f55aeb5c1b..8898801640 100644 --- a/python/paddle/jit/api.py +++ b/python/paddle/jit/api.py @@ -24,7 +24,7 @@ import warnings from collections import OrderedDict import inspect import threading -from typing import Any +from typing import Any, List import paddle from paddle.fluid import core, dygraph @@ -43,6 +43,7 @@ from .dy2static import logging_utils from .dy2static.convert_call_func import ( ConversionOptions, CONVERSION_OPTIONS, + add_ignore_module, ) from .dy2static.program_translator import ( ProgramTranslator, @@ -192,6 +193,34 @@ def copy_decorator_attrs(original_func, decorated_obj): return decorated_obj +def ignore_module(modules: List[Any]): + """ + Adds modules that ignore transcription. + Builtin modules that have been ignored are collections, pdb, copy, inspect, re, numpy, logging, six + + Args: + modules (List[Any]): Ignored modules that you want to add + + Examples: + .. code-block:: python + + import scipy + import astor + + import paddle + from paddle.jit import ignore_module + + modules = [ + scipy, + astor + ] + + ignore_module(modules) + + """ + add_ignore_module(modules) + + def to_static( function=None, input_spec=None, build_strategy=None, property=False ): diff --git a/python/paddle/jit/dy2static/convert_call_func.py b/python/paddle/jit/dy2static/convert_call_func.py index dbcbbd260f..d4bb0513a6 100644 --- a/python/paddle/jit/dy2static/convert_call_func.py +++ b/python/paddle/jit/dy2static/convert_call_func.py @@ -21,6 +21,7 @@ import logging import pdb import re import types +from typing import Any, List import numpy @@ -100,6 +101,16 @@ def builtin_modules(): BUILTIN_LIKELY_MODULES = builtin_modules() +def add_ignore_module(modules: List[Any]): + """ + Adds modules that ignore transcription + """ + global BUILTIN_LIKELY_MODULES + for module in modules: + if module not in BUILTIN_LIKELY_MODULES: + BUILTIN_LIKELY_MODULES.append(module) + + def is_unsupported(func): """ Checks whether the func is supported by dygraph to static graph. -- GitLab