未验证 提交 621bc4f7 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2static]Fix paddle prefix in is_paddle_api (#30569)

* add paddle.

* add unittest
上级 9dd71c74
......@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogge
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func
from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"]
......@@ -74,11 +75,6 @@ def is_builtin_len(func):
return False
def is_paddle_func(func):
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith("paddle")
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
......
......@@ -30,6 +30,12 @@ import numpy as np
from paddle.fluid import unique_name
from paddle.fluid.data_feeder import convert_dtype
# Note(Aurelius): Do not forget the dot `.` to distinguish other
# module such as paddlenlp.
PADDLE_MODULE_PREFIX = 'paddle.'
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static'
class BaseNodeVisitor(gast.NodeVisitor):
"""
......@@ -191,16 +197,21 @@ def is_api_in_module(node, module_prefix):
def is_dygraph_api(node):
# Note: A api in module dygraph_to_static is not a real dygraph api.
if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
if is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX):
return False
# TODO(liym27): A better way to determine whether it is a dygraph api.
# Consider the decorator @dygraph_only
return is_api_in_module(node, "paddle.fluid.dygraph")
return is_api_in_module(node, DYGRAPH_MODULE_PREFIX)
def is_paddle_api(node):
return is_api_in_module(node, "paddle")
return is_api_in_module(node, PADDLE_MODULE_PREFIX)
def is_paddle_func(func):
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
......@@ -1235,7 +1246,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
len_specs = len(src_input_specs)
if len_specs != len(desired_input_specs):
# NOTE(chenweihang): if the input_spec of jit.save is a subset of
# input_spec of to_static, also compatible
# input_spec of to_static, also compatible
for spec in src_input_specs:
if spec not in desired_input_specs:
return False
......
......@@ -14,10 +14,12 @@
from __future__ import print_function
import types
import unittest
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func
from test_program_translator import get_source_code
......@@ -61,5 +63,14 @@ class TestSplitAssignTransformer(unittest.TestCase):
self.assertEqual(answer, code)
class TestIsPaddle(unittest.TestCase):
def fake_module(self):
return types.ModuleType('paddlenlp')
def test_func(self):
m = self.fake_module()
self.assertFalse(is_paddle_func(m))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册