未验证 提交 daea892c 编写于 作者: H hjyp 提交者: GitHub

[Dy2St] Add ignore_module API (#49485)

* Add ignore_module API

* fix type of parameter

* Add test case of ignore-module
上级 451756fb
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import astor
import scipy
from paddle.jit import ignore_module
from paddle.jit.dy2static.convert_call_func import BUILTIN_LIKELY_MODULES
class TestIgnoreModule(unittest.TestCase):
def test_ignore_module(self):
modules = [scipy, astor]
ignore_module(modules)
self.assertEquals(
[scipy, astor],
BUILTIN_LIKELY_MODULES[-2:],
'Failed to add modules that ignore transcription',
)
if __name__ == '__main__':
unittest.main()
...@@ -17,6 +17,7 @@ from .api import save ...@@ -17,6 +17,7 @@ from .api import save
from .api import load from .api import load
from .api import to_static from .api import to_static
from .api import not_to_static from .api import not_to_static
from .api import ignore_module
from .dy2static.logging_utils import set_code_level, set_verbosity from .dy2static.logging_utils import set_code_level, set_verbosity
from . import dy2static from . import dy2static
...@@ -27,6 +28,7 @@ __all__ = [ # noqa ...@@ -27,6 +28,7 @@ __all__ = [ # noqa
'save', 'save',
'load', 'load',
'to_static', 'to_static',
'ignore_module',
'ProgramTranslator', 'ProgramTranslator',
'TranslatedLayer', 'TranslatedLayer',
'set_code_level', 'set_code_level',
......
...@@ -24,7 +24,7 @@ import warnings ...@@ -24,7 +24,7 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
import inspect import inspect
import threading import threading
from typing import Any from typing import Any, List
import paddle import paddle
from paddle.fluid import core, dygraph from paddle.fluid import core, dygraph
...@@ -43,6 +43,7 @@ from .dy2static import logging_utils ...@@ -43,6 +43,7 @@ from .dy2static import logging_utils
from .dy2static.convert_call_func import ( from .dy2static.convert_call_func import (
ConversionOptions, ConversionOptions,
CONVERSION_OPTIONS, CONVERSION_OPTIONS,
add_ignore_module,
) )
from .dy2static.program_translator import ( from .dy2static.program_translator import (
ProgramTranslator, ProgramTranslator,
...@@ -192,6 +193,34 @@ def copy_decorator_attrs(original_func, decorated_obj): ...@@ -192,6 +193,34 @@ def copy_decorator_attrs(original_func, decorated_obj):
return decorated_obj return decorated_obj
def ignore_module(modules: List[Any]):
"""
Adds modules that ignore transcription.
Builtin modules that have been ignored are collections, pdb, copy, inspect, re, numpy, logging, six
Args:
modules (List[Any]): Ignored modules that you want to add
Examples:
.. code-block:: python
import scipy
import astor
import paddle
from paddle.jit import ignore_module
modules = [
scipy,
astor
]
ignore_module(modules)
"""
add_ignore_module(modules)
def to_static( def to_static(
function=None, input_spec=None, build_strategy=None, property=False function=None, input_spec=None, build_strategy=None, property=False
): ):
......
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
import pdb import pdb
import re import re
import types import types
from typing import Any, List
import numpy import numpy
...@@ -100,6 +101,16 @@ def builtin_modules(): ...@@ -100,6 +101,16 @@ def builtin_modules():
BUILTIN_LIKELY_MODULES = builtin_modules() BUILTIN_LIKELY_MODULES = builtin_modules()
def add_ignore_module(modules: List[Any]):
"""
Adds modules that ignore transcription
"""
global BUILTIN_LIKELY_MODULES
for module in modules:
if module not in BUILTIN_LIKELY_MODULES:
BUILTIN_LIKELY_MODULES.append(module)
def is_unsupported(func): def is_unsupported(func):
""" """
Checks whether the func is supported by dygraph to static graph. Checks whether the func is supported by dygraph to static graph.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册