From 74a150fed63d08908e97e2708a0225e011e3f08c Mon Sep 17 00:00:00 2001 From: zyfncg Date: Sun, 13 Feb 2022 10:20:44 +0800 Subject: [PATCH] [Pten] Generate Wrapped InferMeta by Yaml (#39482) * generate wrapped_infer_meta * add test for wrapped_infer_meta * Update test_meta_fn_utils.cc * change the dir of generated file Co-authored-by: Chen Weihang Co-authored-by: Chen Weihang --- .gitignore | 1 + paddle/pten/api/lib/CMakeLists.txt | 17 ++ paddle/pten/core/infermeta_utils.h | 3 + paddle/pten/tests/core/CMakeLists.txt | 2 +- paddle/pten/tests/core/test_meta_fn_utils.cc | 15 ++ python/paddle/utils/code_gen/api.yaml | 6 - python/paddle/utils/code_gen/api_gen.py | 14 -- .../utils/code_gen/wrapped_infermeta_gen.py | 185 ++++++++++++++++++ 8 files changed, 222 insertions(+), 21 deletions(-) create mode 100644 python/paddle/utils/code_gen/wrapped_infermeta_gen.py diff --git a/.gitignore b/.gitignore index 77fe7a9b4cd..5018bf56c16 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ paddle/pten/api/lib/api.cc paddle/pten/api/backward/backward_api.h paddle/pten/api/lib/backward_api.cc paddle/pten/include/* +paddle/pten/infermeta/generated.* paddle/pten/extension.h paddle/fluid/eager/api/generated/* diff --git a/paddle/pten/api/lib/CMakeLists.txt b/paddle/pten/api/lib/CMakeLists.txt index 9be55572e20..2cf737eb8b1 100644 --- a/paddle/pten/api/lib/CMakeLists.txt +++ b/paddle/pten/api/lib/CMakeLists.txt @@ -34,6 +34,12 @@ set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/backward_api.cc) set(bw_api_header_file_tmp ${bw_api_header_file}.tmp) set(bw_api_source_file_tmp ${bw_api_source_file}.tmp) +# wrapped infermeta file +set(wrapped_infermeta_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py) +set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml) +set(wrapped_infermeta_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/infermeta/generated.h) +set(wrapped_infermeta_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/infermeta/generated.cc) + if (NOT PYTHON_EXECUTABLE) find_package(PythonInterp REQUIRED) endif() @@ -65,8 +71,19 @@ add_custom_command( DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base} VERBATIM) +# generate wrapped infermeta +add_custom_command( + OUTPUT ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file} + COMMAND ${PYTHON_EXECUTABLE} ${wrapped_infermeta_gen_file} + --api_yaml_path ${api_yaml_file} + --wrapped_infermeta_header_path ${wrapped_infermeta_header_file} + --wrapped_infermeta_source_path ${wrapped_infermeta_source_file} + DEPENDS ${api_yaml_file} ${wrapped_infermeta_gen_file} ${api_gen_base} + VERBATIM) + cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform) cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch) cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api) +cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten) diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h index 0c95b204abf..c95ae6b69f7 100644 --- a/paddle/pten/core/infermeta_utils.h +++ b/paddle/pten/core/infermeta_utils.h @@ -149,10 +149,13 @@ struct InferMetaFnImpl { PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( const std::vector&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&); diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index d9c8c86a240..32e6e0784da 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -7,7 +7,7 @@ cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor) cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos) cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context) -cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor infermeta infermeta_utils) +cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor wrapped_infermeta infermeta infermeta_utils) cc_test(test_ddim SRCS test_ddim.cc DEPS ddim) if(WITH_GPU) diff --git a/paddle/pten/tests/core/test_meta_fn_utils.cc b/paddle/pten/tests/core/test_meta_fn_utils.cc index 3cde1cfb5dc..f4edc3555bc 100644 --- a/paddle/pten/tests/core/test_meta_fn_utils.cc +++ b/paddle/pten/tests/core/test_meta_fn_utils.cc @@ -17,11 +17,26 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/infermeta/generated.h" #include "paddle/pten/infermeta/unary.h" namespace pten { namespace tests { +TEST(WrappedInferMeta, Scale) { + pten::DenseTensor dense_x; + dense_x.Resize(pten::framework::make_ddim({3, 4})); + + pten::MetaTensor meta_x(&dense_x); + pten::DenseTensor dense_out1; + pten::MetaTensor meta_out(&dense_out1); + pten::ScaleInferMeta(meta_x, 0, 0, false, &meta_out); + + EXPECT_EQ(dense_out1.dims().size(), dense_x.dims().size()); + EXPECT_EQ(dense_out1.dims()[0], dense_x.dims()[0]); + EXPECT_EQ(dense_out1.dims()[1], dense_x.dims()[1]); +} + TEST(MetaFnFactory, InferMetaFnExists) { pten::DenseTensor dense_x; dense_x.Resize(pten::framework::make_ddim({3, 4})); diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 22f67270452..6f64eaadc89 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -3,7 +3,6 @@ output : Tensor infer_meta : func : ElementwiseInferMeta - param : [x, y] kernel : func : add @@ -40,7 +39,6 @@ output : Tensor infer_meta : func : ElementwiseInferMeta - param : [x, y] kernel : func : divide @@ -126,7 +124,6 @@ output : Tensor infer_meta : func : ReduceInferMeta - param: [x, axis, keep_dim] kernel : func : mean @@ -135,7 +132,6 @@ output : Tensor infer_meta : func : ElementwiseInferMeta - param : [x, y] kernel : func : multiply @@ -174,7 +170,6 @@ output : Tensor infer_meta : func : ElementwiseInferMeta - param : [x, y] kernel : func : subtract @@ -183,7 +178,6 @@ output : Tensor infer_meta : func : SumInferMeta - param: [x, axis, dtype, keep_dim] kernel : func : sum param : [x, axis, dtype, keep_dim] diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 7039129a796..629d68230a1 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -60,14 +60,6 @@ class ForwardAPI(BaseAPI): return kernel_output, output_names, output_create - def gene_infer_meta_register(self): - if self.is_base_api: - return f""" -PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});""" - - else: - return '' - def header_include(): return """ @@ -91,7 +83,6 @@ def source_include(header_file_path): #include "paddle/pten/api/lib/data_transform.h" #include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/utils/storage.h" -#include "paddle/pten/core/infermeta_utils.h" #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/infermeta/binary.h" #include "paddle/pten/infermeta/multiary.h" @@ -136,21 +127,16 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): source_file.write(source_include(include_header_file)) source_file.write(namespace[0]) - infer_meta_register_code = '' - for api in apis: api_code = ForwardAPI(api) print(api_code.gene_api_declaration()) header_file.write(api_code.gene_api_declaration()) source_file.write(api_code.gene_api_code()) - infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register( - ) header_file.write(namespace[1]) source_file.write(namespace[1]) source_file.write(api_register()) - source_file.write(infer_meta_register_code) header_file.close() source_file.close() diff --git a/python/paddle/utils/code_gen/wrapped_infermeta_gen.py b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py new file mode 100644 index 00000000000..ad26062e6ba --- /dev/null +++ b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py @@ -0,0 +1,185 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import yaml +import argparse + +from api_base import BaseAPI + + +def get_wrapped_infermeta_name(api_name): + return api_name.capitalize() + 'InferMeta' + + +def gene_wrapped_infermeta_and_register(api): + if api.is_base_api: + register_code = f""" +PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});""" + + if api.infer_meta['param'] is not None: + tensor_type_map = { + 'const Tensor&': 'const MetaTensor&', + 'const std::vector&': 'const std::vector&', + 'Tensor': 'MetaTensor*', + 'std::vector': 'std::vector*', + } + wrapped_infermeta_name = get_wrapped_infermeta_name(api.api) + args = [] + check_args = [] + for input_name in api.inputs['names']: + args.append(tensor_type_map[api.inputs['input_info'][ + input_name]] + ' ' + input_name) + check_args.append(input_name) + for attr_name in api.attrs['names']: + args.append(api.attrs['attr_info'][attr_name][0] + ' ' + + attr_name) + check_args.append(attr_name) + for i, out_type in enumerate(api.outputs['types']): + args.append(tensor_type_map[out_type] + ' ' + api.outputs[ + 'names'][i]) + + if check_args == api.infer_meta['param']: + return '', '', register_code + + invoke_param = api.infer_meta['param'] + invoke_param.extend(api.outputs['names']) + + declare_code = f""" +void {wrapped_infermeta_name}({", ".join(args)}); +""" + + defind_code = f""" +void {wrapped_infermeta_name}({", ".join(args)}) {{ + {api.infer_meta['func']}({", ".join(invoke_param)}); +}} +""" + + register_code = f""" +PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});""" + + return declare_code, defind_code, register_code + else: + return '', '', register_code + else: + return '', '', '' + + +def gene_infermeta_register(api): + if api.is_base_api: + if api.infer_meta['param'] is None: + return f""" +PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});""" + + else: + return f""" +PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});""" + + else: + return '' + + +def header_include(): + return """ +#include "paddle/pten/core/meta_tensor.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" +""" + + +def source_include(header_file_path): + return f""" +#include "{header_file_path}" +#include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/infermeta/binary.h" +#include "paddle/pten/infermeta/multiary.h" +#include "paddle/pten/infermeta/nullary.h" +#include "paddle/pten/infermeta/unary.h" +""" + + +def api_namespace(): + return (""" +namespace pten { +""", """ +} // namespace pten +""") + + +def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path, + source_file_path): + + with open(api_yaml_path, 'r') as f: + apis = yaml.load(f, Loader=yaml.FullLoader) + header_file = open(header_file_path, 'w') + source_file = open(source_file_path, 'w') + + namespace = api_namespace() + + header_file.write("#pragma once\n") + header_file.write(header_include()) + header_file.write(namespace[0]) + + include_header_file = "paddle/pten/infermeta/generated.h" + source_file.write(source_include(include_header_file)) + source_file.write(namespace[0]) + + infermeta_register_code = '' + + for api in apis: + api_item = BaseAPI(api) + declare_code, defind_code, register_code = gene_wrapped_infermeta_and_register( + api_item) + header_file.write(declare_code) + source_file.write(defind_code) + infermeta_register_code = infermeta_register_code + register_code + + header_file.write(namespace[1]) + source_file.write(namespace[1]) + + source_file.write(infermeta_register_code) + + header_file.close() + source_file.close() + + +def main(): + parser = argparse.ArgumentParser( + description='Generate PaddlePaddle C++ API files') + parser.add_argument( + '--api_yaml_path', + help='path to api yaml file', + default='python/paddle/utils/code_gen/api.yaml') + parser.add_argument( + '--wrapped_infermeta_header_path', + help='output of generated wrapped_infermeta header code file', + default='paddle/pten/infermeta/generated.h') + + parser.add_argument( + '--wrapped_infermeta_source_path', + help='output of generated wrapped_infermeta source code file', + default='paddle/pten/infermeta/generated.cc') + + options = parser.parse_args() + + api_yaml_path = options.api_yaml_path + header_file_path = options.wrapped_infermeta_header_path + source_file_path = options.wrapped_infermeta_source_path + + generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path, + source_file_path) + + +if __name__ == '__main__': + main() -- GitLab