# utils_helper.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect

import astor

from paddle.utils import gast

from .ast_utils import ast_to_source_code
from .logging_utils import warn


def index_in_list(array_list, item):
    """
    Return the position of the first occurrence of `item` in `array_list`.

    Args:
        array_list: Any object exposing a list-like `index` method.
        item: The value to look up.

    Returns:
        int: The index reported by `array_list.index(item)`, or -1 when that
        call raises ValueError (i.e. the item is not present).
    """
    try:
        position = array_list.index(item)
    except ValueError:
        # `index` signals "not found" with ValueError; map it to -1.
        position = -1
    return position


# Note(Aurelius): Keep the trailing dot `.` so that the prefix does not also
# match other modules whose names start with "paddle", such as paddlenlp.
PADDLE_MODULE_PREFIX = 'paddle.'
# Module-name prefixes that is_dygraph_api passes to is_api_in_module.
# NOTE(review): unlike PADDLE_MODULE_PREFIX, these two have no trailing dot,
# so a prefix match would also accept any module whose name merely begins
# with the same characters -- confirm this is intended.
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.jit.dy2static'
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'


def is_dygraph_api(node):
    """
    Tell whether the call `node` targets a dygraph API.

    The callee counts as a dygraph API when `is_api_in_module` matches it
    against DYGRAPH_MODULE_PREFIX and does not match it against
    DYGRAPH_TO_STATIC_MODULE_PREFIX.

    Args:
        node(gast.Call): The call node to classify.

    Returns:
        bool: False if the callee is in the dygraph_to_static module,
        otherwise the result of the DYGRAPH_MODULE_PREFIX check.
    """
    if not is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX):
        # TODO(liym27): A better way to determine whether it is a dygraph api.
        #  Consider the decorator @dygraph_only
        return is_api_in_module(node, DYGRAPH_MODULE_PREFIX)

    # Note: An api in module dygraph_to_static is not a real dygraph api.
    return False


def is_api_in_module(node, module_prefix):
    """
    Check whether the function called by `node` lives in a module whose name
    starts with `module_prefix`.

    The callee expression is turned back into source text and evaluated, and
    the resulting object is handed to `_is_api_in_module_helper`.

    Args:
        node(gast.Call): The call node whose callee is inspected.
        module_prefix(str): Prefix compared against the name of the module
            that owns the callee.

    Returns:
        bool: The helper's verdict, or False when evaluating the callee
        raised any exception.

    Raises:
        AssertionError: If `node` is not a `gast.Call`.
    """
    assert isinstance(
        node, gast.Call
    ), "Input non-Call node for is_api_in_module"

    # The callee of a call can itself be a call, e.g. `wrapper(func)(x)`.
    # Keep following `.func` until it is no longer a Call, so the name that
    # gets checked in that example is `wrapper`.
    func_node = node.func
    while isinstance(func_node, gast.Call):
        func_node = func_node.func

    func_str = astor.to_source(gast.gast_to_ast(func_node)).strip()
    try:
        # These imports are not referenced directly. They bind the names that
        # the eval() below may need to resolve inside `func_str`.
        import paddle  # noqa: F401
        import paddle.jit.dy2static as _jst  # noqa: F401
        from paddle import fluid  # noqa: F401
        from paddle import to_tensor  # noqa: F401
        from paddle.fluid import dygraph  # noqa: F401
        from paddle.fluid import layers  # noqa: F401
        from paddle.fluid.dygraph import to_variable  # noqa: F401

        # NOTE(review): eval() executes text regenerated from the AST, so the
        # code being analysed must be trusted. eval() also sees this
        # function's local variables; renaming them can change what resolves.
        return eval(
            "_is_api_in_module_helper({}, '{}')".format(func_str, module_prefix)
        )
    except Exception:
        # Best effort: a callee that cannot be evaluated here is treated as
        # not belonging to the module.
        return False


def _is_api_in_module_helper(obj, module_prefix):
    """
    Report whether `obj` belongs to a module whose name starts with
    `module_prefix`.

    Args:
        obj: Any Python object; its module is looked up with
            `inspect.getmodule`.
        module_prefix(str): Prefix compared against the module's `__name__`.

    Returns:
        bool: False when no module can be determined for `obj`, otherwise
        whether the module name starts with `module_prefix`.
    """
    owner = inspect.getmodule(obj)
    if owner is None:
        return False
    return owner.__name__.startswith(module_prefix)


# is_numpy_api cannot reuse is_api_in_module because of a numpy module problem
def is_numpy_api(node):
    """
    Tell whether the call `node` targets a numpy API.

    The callee is first evaluated and checked against the "numpy" module
    prefix. If that check is negative, the callee's source text is accepted
    when it starts with "numpy." or "np.".

    Args:
        node(gast.Call): The call node to classify.

    Returns:
        bool: True if either check succeeds. False if both fail or if
        evaluating the callee raised any exception.
    """
    assert isinstance(node, gast.Call), "Input non-Call node for is_numpy_api"
    func_str = astor.to_source(gast.gast_to_ast(node.func))
    try:
        # Imported only so that `np` can be resolved by the eval() below.
        import numpy as np  # noqa: F401

        module_result = eval(
            "_is_api_in_module_helper({}, '{}')".format(func_str, "numpy")
        )
    except Exception:
        return False

    if module_result:
        return module_result

    # BUG: np.random.uniform doesn't have module and cannot be analyzed
    # TODO: find a better way
    return func_str.startswith(("numpy.", "np."))


def is_paddle_api(node):
    """
    Tell whether the call `node` targets an API from a module matching
    PADDLE_MODULE_PREFIX.

    Args:
        node(gast.Call): The call node to classify.

    Returns:
        bool: The result of `is_api_in_module` for PADDLE_MODULE_PREFIX.
    """
    return is_api_in_module(node, module_prefix=PADDLE_MODULE_PREFIX)


class NodeVarType:
    """
    Integer codes for the kinds of Python variables the AST transformation
    distinguishes. Some variable types must be known at compile time, because
    (for example) a string and a tensor in an `if` clause may be converted
    differently when going from dygraph to static graph.
    """

    ERROR = -1  # Returns when static analysis gets error
    UNKNOWN = 0  # Reserve for AST nodes have not known the type
    STATEMENT = 1  # For nodes representing statement (non-variable type)
    CALLABLE = 2

    # python data types
    NONE = 100
    BOOLEAN = 101
    INT = 102
    FLOAT = 103
    STRING = 104
    TENSOR = 105
    NUMPY_NDARRAY = 106

    # python collections
    LIST = 200
    SET = 201
    DICT = 202

    PADDLE_DYGRAPH_API = 300
    PADDLE_CONTROL_IF = 301
    PADDLE_CONTROL_WHILE = 302
    PADDLE_CONTROL_FOR = 303
    # Paddle API may not be visible to get source code.
    # We use this enum value to denote the type return by a Paddle API
    PADDLE_RETURN_TYPES = 304

    # If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent.
    TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES}

    # Maps the source text of a type annotation to its type code.
    Annotation_map = {
        "Tensor": TENSOR,
        "paddle.Tensor": TENSOR,
        "int": INT,
        "float": FLOAT,
        "bool": BOOLEAN,
        "str": STRING,
    }

    @staticmethod
    def binary_op_output_type(in_type1, in_type2):
        """
        Derive the result type code of a binary operation from the type
        codes of its two operands.

        Args:
            in_type1: Type code of the first operand.
            in_type2: Type code of the second operand.

        Returns:
            The shared code when both operands agree; the other operand's
            code when one side is UNKNOWN; UNKNOWN when either side is not a
            supported type or when the two sides are a NUMPY_NDARRAY/TENSOR
            pair; otherwise the larger of the two codes.
        """
        if in_type1 == in_type2:
            return in_type1

        unknown = NodeVarType.UNKNOWN
        if in_type1 == unknown:
            return in_type2
        if in_type2 == unknown:
            return in_type1

        supported_types = (
            NodeVarType.BOOLEAN,
            NodeVarType.INT,
            NodeVarType.FLOAT,
            NodeVarType.NUMPY_NDARRAY,
            NodeVarType.TENSOR,
            NodeVarType.PADDLE_RETURN_TYPES,
        )
        if in_type1 not in supported_types or in_type2 not in supported_types:
            return unknown

        # The operands differ at this point, so both being in this pair means
        # one is an ndarray and the other a tensor.
        forbidden_types = (NodeVarType.NUMPY_NDARRAY, NodeVarType.TENSOR)
        both_forbidden = (
            in_type1 in forbidden_types and in_type2 in forbidden_types
        )
        if both_forbidden:
            return unknown
        return max(in_type1, in_type2)

    @staticmethod
    def type_from_annotation(annotation):
        """
        Translate a type annotation node into a type code.

        Args:
            annotation: The annotation AST node; it is rendered to source
                text with `ast_to_source_code` and looked up in
                `Annotation_map`.

        Returns:
            The mapped type code, or UNKNOWN (after emitting a warning) when
            the annotation text is not in `Annotation_map`.
        """
        annotation_str = ast_to_source_code(annotation).strip()
        known_annotations = NodeVarType.Annotation_map
        if annotation_str not in known_annotations:
            # raise warning if not found
            warn("Currently we don't support annotation: %s" % annotation_str)
            return NodeVarType.UNKNOWN

        return known_annotations[annotation_str]