diff --git a/.gitignore b/.gitignore index 25ecd77e25de9576c7abb9d399d0c0cc5b8eec99..2c486ec96f1069850d7562dd2b5156dbd3b3913d 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ paddle/fluid/API_PR.spec paddle/fluid/eager/api/generated/* paddle/fluid/op_use_default_grad_maker_DEV.spec paddle/fluid/op_use_default_grad_maker_PR.spec +paddle/fluid/operators/ops_extra_info.h paddle/phi/api/backward/backward_api.h paddle/phi/api/backward/sparse_bw_api.h paddle/phi/api/include/api.h diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 2a1a6b4e78bd5e858c4f7d1d43c191aa23a09c3a..f50323cef216c4a6d17b206f18ae1215dc5eecb4 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -94,6 +94,14 @@ set(wrapped_infermeta_header_file set(wrapped_infermeta_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/infermeta/generated.cc) +# op extra info file +set(ops_extra_info_gen_file + ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/generator/ops_extra_info_gen.py) +set(api_compat_yaml_file + ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/api_compat.yaml) +set(ops_extra_info_file + ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/ops_extra_info.h) + if(NOT PYTHONINTERP_FOUND) find_package(PythonInterp REQUIRED) endif() @@ -211,6 +219,13 @@ else() message("remove ${generated_argument_mapping_path}") endif() +# generate ops extra info +execute_process( + COMMAND + ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --api_compat_yaml_path + ${api_compat_yaml_file} --ops_extra_info_path ${ops_extra_info_file}) +message("generate ${ops_extra_info_file}") + # generate forward api add_custom_command( OUTPUT ${api_header_file} ${api_source_file} diff --git a/paddle/phi/api/yaml/api_compat.yaml b/paddle/phi/api/yaml/api_compat.yaml index 17f1d545057c0864ea6e92696330cd9cb73c30aa..987876d7039284b12e05669eceb5fafd3ec26682 100644 --- a/paddle/phi/api/yaml/api_compat.yaml +++ b/paddle/phi/api/yaml/api_compat.yaml @@ -23,3 +23,12 @@ x : Input outputs : out : Out + +- api : conv2d + extra : + attrs : [bool use_cudnn = false, bool fuse_relu_before_depthwise_conv = false, bool use_mkldnn = false, + bool use_quantizer = false, str mkldnn_data_type = "float32", bool fuse_relu = false, + str fuse_activation = "", bool fuse_alpha = false, bool fuse_beta = false, bool use_addto = false, + bool fuse_residual_connection = false, float Scale_in = 1.0f, float Scale_out = 1.0f, + float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool force_fp32_output = false, + int workspace_size_MB = 512, bool exhaustive_search = false] diff --git a/paddle/phi/api/yaml/generator/generate_op.py b/paddle/phi/api/yaml/generator/generate_op.py index 627051365c3f7a07b3e0d28fc3f931dbc40b4168..e70042fb9d03315a096abc4242829e9bf071cdb7 100644 --- a/paddle/phi/api/yaml/generator/generate_op.py +++ b/paddle/phi/api/yaml/generator/generate_op.py @@ -76,6 +76,8 @@ def main(api_yaml_path, backward_yaml_path, api_compat_yaml_path, api_args_map = yaml.safe_load(f) # replace args name for OpMaker for api_args in api_args_map: + if api_args['api'] not in forward_api_dict: + continue forward_api_item = forward_api_dict[api_args['api']] has_backward = True if forward_api_item['backward'] else False if has_backward: diff --git a/paddle/phi/api/yaml/generator/ops_extra_info_gen.py b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..ef5afbf595b961578f4fcc4c09b337e4437e49d1 --- /dev/null +++ b/paddle/phi/api/yaml/generator/ops_extra_info_gen.py @@ -0,0 +1,110 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import yaml +import re +import argparse + + +def map_code_template(attrs_str): + return f""" +#include "paddle/fluid/framework/attribute.h" + +namespace paddle {{ +const static std::unordered_map extra_attrs_map = {{ +{attrs_str} +}}; + +}} // namespace paddle + +""" + + +ATTR_TYPE_STRING_MAP = { + 'bool': 'bool', + 'int': 'int', + 'int64_t': 'int64_t', + 'float': 'float', + 'double': 'double', + 'str': 'std::string', + 'int[]': 'std::vector', + 'int64_t[]': 'std::vector', + 'float[]': 'std::vector', + 'double[]': 'std::vector', + 'str[]': 'std::vector' +} + + +def parse_attr(attr_str): + result = re.search( + r"(?P[a-z[\]]+)\s+(?P[a-zA-Z0-9_]+)\s*=\s*(?P\S+)", + attr_str) + return ATTR_TYPE_STRING_MAP[result.group('attr_type')], result.group( + 'name'), result.group('default_val') + + +def generate_extra_info(api_compat_yaml_path, ops_extra_info_path): + compat_apis = [] + with open(api_compat_yaml_path, 'rt') as f: + compat_apis = yaml.safe_load(f) + + extra_map_str_list = [] + + for api_compat_args in compat_apis: + if 'extra' in api_compat_args: + extra_args_map = api_compat_args['extra'] + # TODO(chenweihang): add inputs and outputs + if 'attrs' in extra_args_map: + attr_map_list = [] + for attr in extra_args_map['attrs']: + attr_type, attr_name, default_val = parse_attr(attr) + if attr_type.startswith("std::vector"): + attr_map_list.append( + f"{{\"{attr_name}\", {attr_type}{default_val}}}") + else: + attr_map_list.append( + f"{{\"{attr_name}\", {attr_type}{{{default_val}}}}}" + ) + api_extra_attr_map = ", ".join(attr_map_list) + extra_map_str_list.append( + f"{{\"{api_compat_args['api']}\", {{ {api_extra_attr_map} }}}}" + ) + + ops_extra_info_file = open(ops_extra_info_path, 'w') + ops_extra_info_file.write(map_code_template(",\n".join(extra_map_str_list))) + ops_extra_info_file.close() + + +def main(): + parser = argparse.ArgumentParser( + description='Generate PaddlePaddle Extra Param Info for Op') + parser.add_argument('--api_compat_yaml_path', + help='path to api compat yaml file', + default='paddle/phi/api/yaml/api_compat.yaml') + + parser.add_argument('--ops_extra_info_path', + help='output of generated extra_prama_info code file', + default='paddle/fluid/operators/ops_extra_info.h') + + options = parser.parse_args() + + api_compat_yaml_path = options.api_compat_yaml_path + ops_extra_info_path = options.ops_extra_info_path + + generate_extra_info(api_compat_yaml_path, ops_extra_info_path) + + +if __name__ == '__main__': + main()