# count_api_without_core_ops.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import collections
import functools
import hashlib
import importlib
import inspect
import platform
import pydoc
import sys

from paddle import _C_ops, _legacy_C_ops


# Names exported by a star-import of this module.
__all__ = [
    'get_apis_with_and_without_core_ops',
]
# APIs that should not be printed into API.spec.
# Both traversal callbacks in this file return immediately for these names.
omitted_list = [
    "paddle.fluid.LoDTensor.set",  # Do not know why it should be omitted
    "paddle.fluid.io.ComposeNotAligned",
    "paddle.fluid.io.ComposeNotAligned.__init__",
]


def md5(doc):
    """Return the hexadecimal MD5 digest of ``str(doc)`` encoded as UTF-8.

    Args:
        doc: Any object; it is converted with ``str()`` before hashing.

    Returns:
        The 32-character hex digest, or ``None`` when the text cannot be
        encoded. In that case the error is reported on ``sys.stderr``.
    """
    try:
        hashinst = hashlib.md5()
        hashinst.update(str(doc).encode('utf-8'))
        md5sum = hashinst.hexdigest()
    except UnicodeError as e:
        # str.encode() fails with UnicodeEncodeError (e.g. lone surrogates),
        # which a UnicodeDecodeError handler never catches. UnicodeError is
        # the common base class of both.
        md5sum = None
        print("Error({}) occurred when `md5({})`, discard it.".format(
            str(e), doc),
              file=sys.stderr)
    return md5sum

def split_with_and_without_core_ops(member, cur_name):
    """Sort an API into ``api_with_ops`` or ``api_without_ops`` by its source.

    Only members whose source contains ``append_op`` are recorded. Such a
    member goes to the module-level ``api_with_ops`` list when its source
    also mentions ``core.ops`` or ``_C_ops``, otherwise to
    ``api_without_ops``.

    Nothing is recorded for names in ``omitted_list``, members whose
    docstring contains ``:api_attr: Static Graph``, names containing one of
    the skipped keywords below, classes, or members whose source cannot be
    retrieved.

    Args:
        member: The object to inspect.
        cur_name: The dotted name under which ``member`` is recorded.
    """
    if cur_name in omitted_list:
        return

    # __doc__ is None for undocumented members; treat that as an empty
    # docstring instead of failing on None.find().
    doc = member.__doc__ or ''
    if ':api_attr: Static Graph' in doc:
        return

    skipped_keywords = (
        'ParamBase',
        'Parameter',
        'Variable',
        'control_flow',
        'contrib.mixed_precision',
    )
    if any(keyword in cur_name for keyword in skipped_keywords):
        return

    if inspect.isclass(member):
        return

    try:
        source = inspect.getsource(member)
    except Exception:
        # If getsource failed (pybind API or function inherit from father
        # class), just skip. Exception (not a bare except) keeps
        # KeyboardInterrupt and SystemExit working.
        return

    if 'append_op' not in source:
        return

    if 'core.ops' in source or '_C_ops' in source:
        api_with_ops.append(cur_name)
    else:
        api_without_ops.append(cur_name)

def get_md5_of_func(member, cur_name):
    """Store the MD5 of ``member``'s source code in ``func_dict``.

    The digest is stored under ``cur_name`` in the module-level
    ``func_dict``. Names in ``omitted_list``, classes, and members whose
    source cannot be retrieved are skipped.

    Args:
        member: The object whose source is hashed.
        cur_name: The dotted name used as the ``func_dict`` key.
    """
    if cur_name in omitted_list:
        return

    if inspect.isclass(member):
        return

    try:
        source = inspect.getsource(member)
    except Exception:
        # If getsource failed (pybind API or function inherit from father
        # class), just skip.
        return

    func_dict[cur_name] = md5(source)


def visit_member(parent_name, member, func):
    """Apply ``func`` to ``member`` and, for classes, to their public members.

    ``func`` is called as ``func(member, dotted_name)`` where the dotted
    name is ``parent_name`` joined with ``member.__name__``. For a class,
    every attribute reported by ``inspect.getmembers`` that has a
    ``__name__`` and is either public or ``__init__`` is visited
    recursively. Method descriptors and getset descriptors are ignored.

    Args:
        parent_name: Dotted name of the enclosing module or class.
        member: The object to visit; it must have a ``__name__``.
        func: Callback taking ``(member, dotted_name)``.

    Raises:
        RuntimeError: If ``member`` is not a class, a method descriptor, a
            callable, or a getset descriptor.
    """
    qualified_name = ".".join([parent_name, member.__name__])

    if inspect.isclass(member):
        func(member, qualified_name)
        for attr_name, attr in inspect.getmembers(member):
            if not hasattr(attr, '__name__'):
                continue
            if attr_name == "__init__" or not attr_name.startswith("_"):
                visit_member(qualified_name, attr, func)
        return

    # Checked before callable(): method descriptors are callable but must
    # not be reported.
    if inspect.ismethoddescriptor(member):
        return

    if callable(member):
        func(member, qualified_name)
        return

    if inspect.isgetsetdescriptor(member):
        return

    raise RuntimeError(
        "Unsupported generate signature of member, type {0}".format(
            str(type(member))))


def is_primitive(instance):
    """Tell whether ``instance`` consists only of int, float and str values.

    Args:
        instance: The value to check.

    Returns:
        True if ``instance`` is an ``int``, ``float`` or ``str``, or a
        ``list``/``tuple``/``set`` whose elements all satisfy this check
        recursively (an empty container counts). False otherwise.
    """
    primitive_types = (int, float, str)
    if isinstance(instance, primitive_types):
        return True
    if isinstance(instance, (list, tuple, set)):
        return all(is_primitive(item) for item in instance)
    return False


# Dotted names whose visit raised an exception inside visit_all_module.
ErrorSet = set()
# id() values of non-module members already handed to visit_member, used to
# avoid visiting the same object twice.
IdSet = set()
# Dotted names that visit_all_module skips without visiting.
skiplist = []
# Module objects that visit_all_module has already traversed.
visited_modules = set()


def visit_all_module(mod, func):
    """Recursively walk a paddle module and visit its public members.

    Only ``paddle`` and ``paddle.*`` modules are walked, and
    ``paddle.fluid.core*`` modules are excluded. Each module is walked once
    (tracked in ``visited_modules``), and each non-module object is visited
    once (tracked by ``id()`` in ``IdSet``). Member names come from
    ``dir(mod)`` plus ``mod.__all__`` when present. Names starting with an
    underscore or listed in ``skiplist`` are skipped.

    A failure while visiting one member does not stop the walk. The
    member's dotted name is added to ``ErrorSet`` instead.

    Args:
        mod: The module object to walk.
        func: Callback forwarded to ``visit_member``.
    """
    mod_name = mod.__name__
    if mod_name != 'paddle' and not mod_name.startswith('paddle.'):
        return

    if mod_name.startswith('paddle.fluid.core'):
        return

    if mod in visited_modules:
        return
    visited_modules.add(mod)

    member_names = dir(mod)
    if hasattr(mod, "__all__"):
        member_names += mod.__all__

    for member_name in member_names:
        if member_name.startswith('_'):
            continue
        cur_name = mod_name + '.' + member_name
        if cur_name in skiplist:
            continue
        try:
            instance = getattr(mod, member_name)
            if inspect.ismodule(instance):
                visit_all_module(instance, func)
            else:
                instance_id = id(instance)
                if instance_id in IdSet:
                    continue
                IdSet.add(instance_id)
                visit_member(mod.__name__, instance, func)
        except Exception:
            # Best effort: remember the failing name and keep walking.
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit stop the scan.
            ErrorSet.add(cur_name)
def get_apis_with_and_without_core_ops(modules):
    """Collect APIs that use ``append_op``, split by use of the C ops path.

    Args:
        modules: Iterable of importable module names to walk.

    Returns:
        A tuple ``(api_with_ops, api_without_ops)`` of lists of dotted API
        names, as filled in by ``split_with_and_without_core_ops``.
    """
    global api_with_ops, api_without_ops
    api_with_ops = []
    api_without_ops = []

    # Forget what earlier traversals visited. Otherwise every module and
    # object seen before is skipped and this call returns empty lists.
    visited_modules.clear()
    IdSet.clear()

    for module_name in modules:
        visit_all_module(importlib.import_module(module_name),
                         split_with_and_without_core_ops)
    return api_with_ops, api_without_ops
def get_api_source_desc(modules):
    """Map each visited API name to the MD5 of its source code.

    Args:
        modules: Iterable of importable module names to walk.

    Returns:
        An ``OrderedDict`` from dotted API name to MD5 hex digest, as
        filled in by ``get_md5_of_func``.
    """
    global func_dict
    func_dict = collections.OrderedDict()

    # Forget what earlier traversals visited. Otherwise every module and
    # object seen before is skipped and this call returns an empty dict.
    visited_modules.clear()
    IdSet.clear()

    for module_name in modules:
        visit_all_module(importlib.import_module(module_name),
                         get_md5_of_func)
    return func_dict
if __name__ == "__main__":
    # Command-line entry point: "<flag> <comma-separated modules>".
    # Both arguments are required, so check for them before reading
    # sys.argv[2]. A missing module list or an unknown flag prints the usage.
    if len(sys.argv) > 2 and sys.argv[1] in ('-c', '-p'):
        modules = sys.argv[2].split(",")
        if sys.argv[1] == '-c':
            api_with_ops, api_without_ops = get_apis_with_and_without_core_ops(
                modules)

            print('api_with_ops:', len(api_with_ops))
            print('\n'.join(api_with_ops))
            print('\n==============\n')
            print('api_without_ops:', len(api_without_ops))
            print('\n'.join(api_without_ops))
        else:
            func_dict = get_api_source_desc(modules)
            for name in func_dict:
                print(name, func_dict[name])

    else:
        print("""Usage:
            1. Count and list all operator-related APIs that contains append_op but not _legacy_C_ops.xx.
                python ./count_api_without_core_ops.py -c paddle
            2. Print api and the md5 of source code of the api.
                python ./count_api_without_core_ops.py -p paddle
            """)