未验证 提交 621bc4f7 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2static]Fix paddle prefix in is_paddle_api (#30569)

* add paddle.

* add unittest
上级 9dd71c74
...@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogge ...@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogge
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"] __all__ = ["convert_call"]
...@@ -74,11 +75,6 @@ def is_builtin_len(func): ...@@ -74,11 +75,6 @@ def is_builtin_len(func):
return False return False
def is_paddle_func(func):
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith("paddle")
def is_unsupported(func): def is_unsupported(func):
""" """
Checks whether the func is supported by dygraph to static graph. Checks whether the func is supported by dygraph to static graph.
......
...@@ -30,6 +30,12 @@ import numpy as np ...@@ -30,6 +30,12 @@ import numpy as np
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
# Note(Aurelius): Do not forget the dot `.` to distinguish other
# module such as paddlenlp.
PADDLE_MODULE_PREFIX = 'paddle.'
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static'
class BaseNodeVisitor(gast.NodeVisitor): class BaseNodeVisitor(gast.NodeVisitor):
""" """
...@@ -191,16 +197,21 @@ def is_api_in_module(node, module_prefix): ...@@ -191,16 +197,21 @@ def is_api_in_module(node, module_prefix):
def is_dygraph_api(node): def is_dygraph_api(node):
# Note: A api in module dygraph_to_static is not a real dygraph api. # Note: A api in module dygraph_to_static is not a real dygraph api.
if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"): if is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX):
return False return False
# TODO(liym27): A better way to determine whether it is a dygraph api. # TODO(liym27): A better way to determine whether it is a dygraph api.
# Consider the decorator @dygraph_only # Consider the decorator @dygraph_only
return is_api_in_module(node, "paddle.fluid.dygraph") return is_api_in_module(node, DYGRAPH_MODULE_PREFIX)
def is_paddle_api(node): def is_paddle_api(node):
return is_api_in_module(node, "paddle") return is_api_in_module(node, PADDLE_MODULE_PREFIX)
def is_paddle_func(func):
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem # Is numpy_api cannot reuse is_api_in_module because of numpy module problem
...@@ -1235,7 +1246,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs): ...@@ -1235,7 +1246,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
len_specs = len(src_input_specs) len_specs = len(src_input_specs)
if len_specs != len(desired_input_specs): if len_specs != len(desired_input_specs):
# NOTE(chenweihang): if the input_spec of jit.save is a subset of # NOTE(chenweihang): if the input_spec of jit.save is a subset of
# input_spec of to_static, also compatible # input_spec of to_static, also compatible
for spec in src_input_specs: for spec in src_input_specs:
if spec not in desired_input_specs: if spec not in desired_input_specs:
return False return False
......
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
from __future__ import print_function from __future__ import print_function
import types
import unittest import unittest
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func
from test_program_translator import get_source_code from test_program_translator import get_source_code
...@@ -61,5 +63,14 @@ class TestSplitAssignTransformer(unittest.TestCase): ...@@ -61,5 +63,14 @@ class TestSplitAssignTransformer(unittest.TestCase):
self.assertEqual(answer, code) self.assertEqual(answer, code)
class TestIsPaddle(unittest.TestCase):
def fake_module(self):
return types.ModuleType('paddlenlp')
def test_func(self):
m = self.fake_module()
self.assertFalse(is_paddle_func(m))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册