diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index b7d25e2a14b49166e3ea8ad5e6d63f75ef2517a7..7604be2d838eb669c0f1af1f3a4c53716ce2562f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogge from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func from paddle.fluid.dygraph.layers import Layer __all__ = ["convert_call"] @@ -74,11 +75,6 @@ def is_builtin_len(func): return False -def is_paddle_func(func): - m = inspect.getmodule(func) - return m is not None and m.__name__.startswith("paddle") - - def is_unsupported(func): """ Checks whether the func is supported by dygraph to static graph. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 3676958f15df501f7e1ba7d1cf6c17cd3d78abbf..9e61b8aa1ee4267712f18ed3a15fe8a5e80fdb77 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -30,6 +30,12 @@ import numpy as np from paddle.fluid import unique_name from paddle.fluid.data_feeder import convert_dtype +# Note(Aurelius): Do not forget the dot `.` to distinguish other +# module such as paddlenlp. +PADDLE_MODULE_PREFIX = 'paddle.' +DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph' +DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static' + class BaseNodeVisitor(gast.NodeVisitor): """ @@ -191,16 +197,21 @@ def is_api_in_module(node, module_prefix): def is_dygraph_api(node): # Note: A api in module dygraph_to_static is not a real dygraph api. - if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"): + if is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX): return False # TODO(liym27): A better way to determine whether it is a dygraph api. # Consider the decorator @dygraph_only - return is_api_in_module(node, "paddle.fluid.dygraph") + return is_api_in_module(node, DYGRAPH_MODULE_PREFIX) def is_paddle_api(node): - return is_api_in_module(node, "paddle") + return is_api_in_module(node, PADDLE_MODULE_PREFIX) + + +def is_paddle_func(func): + m = inspect.getmodule(func) + return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX) # Is numpy_api cannot reuse is_api_in_module because of numpy module problem @@ -1235,7 +1246,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs): len_specs = len(src_input_specs) if len_specs != len(desired_input_specs): # NOTE(chenweihang): if the input_spec of jit.save is a subset of - # input_spec of to_static, also compatible + # input_spec of to_static, also compatible for spec in src_input_specs: if spec not in desired_input_specs: return False diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py index 24b8833fec1925f4a52e7361aeb24818a580cca2..747e9f1c0dbd93be3300b81f34183095d54a6571 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py @@ -14,10 +14,12 @@ from __future__ import print_function +import types import unittest from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func from test_program_translator import get_source_code @@ -61,5 +63,14 @@ class TestSplitAssignTransformer(unittest.TestCase): self.assertEqual(answer, code) +class TestIsPaddle(unittest.TestCase): + def fake_module(self): + return types.ModuleType('paddlenlp') + + def test_func(self): + m = self.fake_module() + self.assertFalse(is_paddle_func(m)) + + if __name__ == '__main__': unittest.main()