# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import inspect import collections import sys import pydoc import hashlib import functools import platform from paddle import _C_ops, _legacy_C_ops __all__ = [ 'get_apis_with_and_without_core_ops', ] # APIs that should not be printed into API.spec omitted_list = [ "paddle.fluid.LoDTensor.set", # Do not know why it should be omitted "paddle.fluid.io.ComposeNotAligned", "paddle.fluid.io.ComposeNotAligned.__init__", ] def md5(doc): try: hashinst = hashlib.md5() hashinst.update(str(doc).encode('utf-8')) md5sum = hashinst.hexdigest() except UnicodeDecodeError as e: md5sum = None print("Error({}) occurred when `md5({})`, discard it.".format( str(e), doc), file=sys.stderr) return md5sum def split_with_and_without_core_ops(member, cur_name): if cur_name in omitted_list: return if member.__doc__.find(':api_attr: Static Graph') != -1: return if cur_name.find('ParamBase') != -1 or cur_name.find( 'Parameter') != -1 or cur_name.find( 'Variable') != -1 or cur_name.find( 'control_flow') != -1 or cur_name.find( 'contrib.mixed_precision') != -1: return if inspect.isclass(member): pass else: try: source = inspect.getsource(member) if source.find('append_op') != -1: if source.find('core.ops') != -1 or source.find('_C_ops') != -1: api_with_ops.append(cur_name) else: api_without_ops.append(cur_name) except: # If getsource failed (pybind API or function inherit from father class), just skip pass def get_md5_of_func(member, cur_name): if cur_name in omitted_list: return doc_md5 = md5(member.__doc__) if inspect.isclass(member): pass else: try: source = inspect.getsource(member) func_dict[cur_name] = md5(source) except: # If getsource failed (pybind API or function inherit from father class), just skip pass def visit_member(parent_name, member, func): cur_name = ".".join([parent_name, member.__name__]) if inspect.isclass(member): func(member, cur_name) for name, value in inspect.getmembers(member): if hasattr(value, '__name__') and (not name.startswith("_") or name == "__init__"): visit_member(cur_name, value, func) elif inspect.ismethoddescriptor(member): return elif callable(member): func(member, cur_name) elif inspect.isgetsetdescriptor(member): return else: raise RuntimeError( "Unsupported generate signature of member, type {0}".format( str(type(member)))) def is_primitive(instance): int_types = (int, ) pritimitive_types = int_types + (float, str) if isinstance(instance, pritimitive_types): return True elif isinstance(instance, (list, tuple, set)): for obj in instance: if not is_primitive(obj): return False return True else: return False ErrorSet = set() IdSet = set() skiplist = [] visited_modules = set() def visit_all_module(mod, func): mod_name = mod.__name__ if mod_name != 'paddle' and not mod_name.startswith('paddle.'): return if mod_name.startswith('paddle.fluid.core'): return if mod in visited_modules: return visited_modules.add(mod) member_names = dir(mod) if hasattr(mod, "__all__"): member_names += mod.__all__ for member_name in member_names: if member_name.startswith('_'): continue cur_name = mod_name + '.' + member_name if cur_name in skiplist: continue try: instance = getattr(mod, member_name) if inspect.ismodule(instance): visit_all_module(instance, func) else: instance_id = id(instance) if instance_id in IdSet: continue IdSet.add(instance_id) visit_member(mod.__name__, instance, func) except: if cur_name not in ErrorSet and cur_name not in skiplist: ErrorSet.add(cur_name) def get_apis_with_and_without_core_ops(modules): global api_with_ops, api_without_ops api_with_ops = [] api_without_ops = [] for m in modules: visit_all_module(importlib.import_module(m), split_with_and_without_core_ops) return api_with_ops, api_without_ops def get_api_source_desc(modules): global func_dict func_dict = collections.OrderedDict() for m in modules: visit_all_module(importlib.import_module(m), get_md5_of_func) return func_dict if __name__ == "__main__": if len(sys.argv) > 1: modules = sys.argv[2].split(",") if sys.argv[1] == '-c': api_with_ops, api_without_ops = get_apis_with_and_without_core_ops( modules) print('api_with_ops:', len(api_with_ops)) print('\n'.join(api_with_ops)) print('\n==============\n') print('api_without_ops:', len(api_without_ops)) print('\n'.join(api_without_ops)) if sys.argv[1] == '-p': func_dict = get_api_source_desc(modules) for name in func_dict: print(name, func_dict[name]) else: print("""Usage: 1. Count and list all operator-raleated APIs that contains append_op but not _legacy_C_ops.xx. python ./count_api_without_core_ops.py -c paddle 2. Print api and the md5 of source code of the api. python ./count_api_without_core_ops.py -p paddle """)