未验证 提交 74a150fe 编写于 作者: Z zyfncg 提交者: GitHub

[Pten] Generate Wrapped InferMeta by Yaml (#39482)

* generate wrapped_infer_meta

* add test for wrapped_infer_meta

* Update test_meta_fn_utils.cc

* change the dir of generated file
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: Chen Weihang <chenwhpro@163.com>
上级 bdeb479c
...@@ -9,6 +9,7 @@ paddle/pten/api/lib/api.cc ...@@ -9,6 +9,7 @@ paddle/pten/api/lib/api.cc
paddle/pten/api/backward/backward_api.h paddle/pten/api/backward/backward_api.h
paddle/pten/api/lib/backward_api.cc paddle/pten/api/lib/backward_api.cc
paddle/pten/include/* paddle/pten/include/*
paddle/pten/infermeta/generated.*
paddle/pten/extension.h paddle/pten/extension.h
paddle/fluid/eager/api/generated/* paddle/fluid/eager/api/generated/*
......
...@@ -34,6 +34,12 @@ set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/backward_api.cc) ...@@ -34,6 +34,12 @@ set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/backward_api.cc)
set(bw_api_header_file_tmp ${bw_api_header_file}.tmp) set(bw_api_header_file_tmp ${bw_api_header_file}.tmp)
set(bw_api_source_file_tmp ${bw_api_source_file}.tmp) set(bw_api_source_file_tmp ${bw_api_source_file}.tmp)
# wrapped infermeta file
# Paths for the yaml-driven InferMeta wrapper generator: the python generator
# script, its yaml input, and the header/source it emits.
# NOTE(review): the generated files are written into the source tree under
# ${CMAKE_SOURCE_DIR}; generated.* is gitignored elsewhere in this commit —
# confirm in-source generation is intended (binary dir is the usual choice).
set(wrapped_infermeta_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py)
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)
set(wrapped_infermeta_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/infermeta/generated.h)
set(wrapped_infermeta_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/infermeta/generated.cc)
if (NOT PYTHON_EXECUTABLE) if (NOT PYTHON_EXECUTABLE)
find_package(PythonInterp REQUIRED) find_package(PythonInterp REQUIRED)
endif() endif()
...@@ -65,8 +71,19 @@ add_custom_command( ...@@ -65,8 +71,19 @@ add_custom_command(
DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base} DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base}
VERBATIM) VERBATIM)
# generate wrapped infermeta
# Build rule producing generated.h/generated.cc from api.yaml via the python
# generator. Re-runs when the yaml, the generator script, or the shared
# api_gen_base helper changes; VERBATIM keeps argument quoting portable.
add_custom_command(
OUTPUT ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file}
COMMAND ${PYTHON_EXECUTABLE} ${wrapped_infermeta_gen_file}
--api_yaml_path ${api_yaml_file}
--wrapped_infermeta_header_path ${wrapped_infermeta_header_file}
--wrapped_infermeta_source_path ${wrapped_infermeta_source_file}
DEPENDS ${api_yaml_file} ${wrapped_infermeta_gen_file} ${api_gen_base}
VERBATIM)
cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform) cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform)
cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch) cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api) cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten)
...@@ -149,10 +149,13 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -149,10 +149,13 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&); const std::vector<int64_t>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&); PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
......
...@@ -7,7 +7,7 @@ cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor ...@@ -7,7 +7,7 @@ cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor
cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor) cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor)
cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos) cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos)
cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context) cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor infermeta infermeta_utils) cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor wrapped_infermeta infermeta infermeta_utils)
cc_test(test_ddim SRCS test_ddim.cc DEPS ddim) cc_test(test_ddim SRCS test_ddim.cc DEPS ddim)
if(WITH_GPU) if(WITH_GPU)
......
...@@ -17,11 +17,26 @@ limitations under the License. */ ...@@ -17,11 +17,26 @@ limitations under the License. */
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/infermeta_utils.h" #include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/generated.h"
#include "paddle/pten/infermeta/unary.h" #include "paddle/pten/infermeta/unary.h"
namespace pten { namespace pten {
namespace tests { namespace tests {
// Checks the yaml-generated wrapper pten::ScaleInferMeta: calling it with an
// input MetaTensor of shape {3, 4} should propagate those dims to the output.
TEST(WrappedInferMeta, Scale) {
pten::DenseTensor dense_x;
dense_x.Resize(pten::framework::make_ddim({3, 4}));
// MetaTensor is a lightweight view over the DenseTensor's meta fields.
pten::MetaTensor meta_x(&dense_x);
pten::DenseTensor dense_out1;
pten::MetaTensor meta_out(&dense_out1);
// scale/bias are zero here; only the shape inference result is asserted.
pten::ScaleInferMeta(meta_x, 0, 0, false, &meta_out);
EXPECT_EQ(dense_out1.dims().size(), dense_x.dims().size());
EXPECT_EQ(dense_out1.dims()[0], dense_x.dims()[0]);
EXPECT_EQ(dense_out1.dims()[1], dense_x.dims()[1]);
}
TEST(MetaFnFactory, InferMetaFnExists) { TEST(MetaFnFactory, InferMetaFnExists) {
pten::DenseTensor dense_x; pten::DenseTensor dense_x;
dense_x.Resize(pten::framework::make_ddim({3, 4})); dense_x.Resize(pten::framework::make_ddim({3, 4}));
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
param : [x, y]
kernel : kernel :
func : add func : add
...@@ -40,7 +39,6 @@ ...@@ -40,7 +39,6 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
param : [x, y]
kernel : kernel :
func : divide func : divide
...@@ -126,7 +124,6 @@ ...@@ -126,7 +124,6 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ReduceInferMeta func : ReduceInferMeta
param: [x, axis, keep_dim]
kernel : kernel :
func : mean func : mean
...@@ -135,7 +132,6 @@ ...@@ -135,7 +132,6 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
param : [x, y]
kernel : kernel :
func : multiply func : multiply
...@@ -174,7 +170,6 @@ ...@@ -174,7 +170,6 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : ElementwiseInferMeta func : ElementwiseInferMeta
param : [x, y]
kernel : kernel :
func : subtract func : subtract
...@@ -183,7 +178,6 @@ ...@@ -183,7 +178,6 @@
output : Tensor output : Tensor
infer_meta : infer_meta :
func : SumInferMeta func : SumInferMeta
param: [x, axis, dtype, keep_dim]
kernel : kernel :
func : sum func : sum
param : [x, axis, dtype, keep_dim] param : [x, axis, dtype, keep_dim]
......
...@@ -60,14 +60,6 @@ class ForwardAPI(BaseAPI): ...@@ -60,14 +60,6 @@ class ForwardAPI(BaseAPI):
return kernel_output, output_names, output_create return kernel_output, output_names, output_create
def gene_infer_meta_register(self):
if self.is_base_api:
return f"""
PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});"""
else:
return ''
def header_include(): def header_include():
return """ return """
...@@ -91,7 +83,6 @@ def source_include(header_file_path): ...@@ -91,7 +83,6 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/data_transform.h" #include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/binary.h" #include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h" #include "paddle/pten/infermeta/multiary.h"
...@@ -136,21 +127,16 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): ...@@ -136,21 +127,16 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file.write(source_include(include_header_file)) source_file.write(source_include(include_header_file))
source_file.write(namespace[0]) source_file.write(namespace[0])
infer_meta_register_code = ''
for api in apis: for api in apis:
api_code = ForwardAPI(api) api_code = ForwardAPI(api)
print(api_code.gene_api_declaration()) print(api_code.gene_api_declaration())
header_file.write(api_code.gene_api_declaration()) header_file.write(api_code.gene_api_declaration())
source_file.write(api_code.gene_api_code()) source_file.write(api_code.gene_api_code())
infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register(
)
header_file.write(namespace[1]) header_file.write(namespace[1])
source_file.write(namespace[1]) source_file.write(namespace[1])
source_file.write(api_register()) source_file.write(api_register())
source_file.write(infer_meta_register_code)
header_file.close() header_file.close()
source_file.close() source_file.close()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import argparse
from api_base import BaseAPI
def get_wrapped_infermeta_name(api_name):
    """Name of the generated wrapper, e.g. 'scale' -> 'ScaleInferMeta'."""
    return '{}InferMeta'.format(api_name.capitalize())
def gene_wrapped_infermeta_and_register(api):
    """Generate wrapper code for one api's InferMeta function.

    Args:
        api: a parsed BaseAPI item from api.yaml.

    Returns:
        (declare_code, define_code, register_code) strings:
        - declare/define: a ``XxxInferMeta`` wrapper whose signature matches
          the kernel arguments and forwards the subset listed in
          ``infer_meta.param`` (plus the outputs) to the real InferMeta
          function; both empty when no wrapper is needed.
        - register: the PT_REGISTER_INFER_META_FN(...) snippet, empty for
          non-base apis.
    """
    if not api.is_base_api:
        return '', '', ''

    # Default: register the original InferMeta function directly.
    register_code = f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});"""

    if api.infer_meta['param'] is None:
        return '', '', register_code

    # Map api-level Tensor types to the MetaTensor types InferMeta uses.
    tensor_type_map = {
        'const Tensor&': 'const MetaTensor&',
        'const std::vector<Tensor>&': 'const std::vector<MetaTensor>&',
        'Tensor': 'MetaTensor*',
        'std::vector<Tensor>': 'std::vector<MetaTensor>*',
    }
    wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
    args = []
    check_args = []
    for input_name in api.inputs['names']:
        args.append(tensor_type_map[api.inputs['input_info'][input_name]] +
                    ' ' + input_name)
        check_args.append(input_name)
    for attr_name in api.attrs['names']:
        args.append(api.attrs['attr_info'][attr_name][0] + ' ' + attr_name)
        check_args.append(attr_name)
    for i, out_type in enumerate(api.outputs['types']):
        args.append(tensor_type_map[out_type] + ' ' + api.outputs['names'][i])

    # The InferMeta function already takes exactly the kernel arguments:
    # no wrapper needed, register the original function.
    if check_args == api.infer_meta['param']:
        return '', '', register_code

    # Copy before extending: the original code extended
    # api.infer_meta['param'] in place, mutating the parsed api config.
    invoke_param = list(api.infer_meta['param'])
    invoke_param.extend(api.outputs['names'])

    declare_code = f"""
void {wrapped_infermeta_name}({", ".join(args)});
"""
    define_code = f"""
void {wrapped_infermeta_name}({", ".join(args)}) {{
  {api.infer_meta['func']}({", ".join(invoke_param)});
}}
"""
    # NOTE(review): declared under the api name but registered under the
    # kernel func name — assumed identical for base apis; confirm.
    register_code = f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});"""
    return declare_code, define_code, register_code
def gene_infermeta_register(api):
    """Return the PT_REGISTER_INFER_META_FN snippet for one api.

    Registers the original InferMeta function when no param remapping is
    declared, otherwise the generated wrapper; empty for non-base apis.
    """
    if not api.is_base_api:
        return ''
    if api.infer_meta['param'] is None:
        meta_fn_name = api.infer_meta['func']
    else:
        meta_fn_name = get_wrapped_infermeta_name(api.kernel['func'])
    return f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{meta_fn_name});"""
def header_include():
    """Includes required by the generated wrapped-infermeta header file."""
    headers = (
        "paddle/pten/core/meta_tensor.h",
        "paddle/pten/common/scalar.h",
        "paddle/pten/common/scalar_array.h",
    )
    return '\n' + ''.join(f'#include "{h}"\n' for h in headers)
def source_include(header_file_path):
    """Includes required by the generated wrapped-infermeta source file."""
    paths = (
        header_file_path,
        "paddle/pten/core/infermeta_utils.h",
        "paddle/pten/infermeta/binary.h",
        "paddle/pten/infermeta/multiary.h",
        "paddle/pten/infermeta/nullary.h",
        "paddle/pten/infermeta/unary.h",
    )
    return '\n' + ''.join(f'#include "{p}"\n' for p in paths)
def api_namespace():
    """Return the (opening, closing) pten namespace snippets."""
    ns_open = """
namespace pten {
"""
    ns_close = """
} // namespace pten
"""
    return ns_open, ns_close
def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path,
                                            source_file_path):
    """Parse api.yaml and write the wrapped-infermeta header/source files.

    Args:
        api_yaml_path: path to the api yaml description file.
        header_file_path: output path for the generated header.
        source_file_path: output path for the generated source.

    Side effects: overwrites both output files. Registration snippets are
    accumulated and emitted at the end of the source file, after the
    namespace close (PT_REGISTER_INFER_META_FN expands at global scope).
    """
    # Context managers ensure the handles are closed even if a malformed
    # api entry raises partway through generation (the original code used
    # bare open()/close() and leaked on exception).
    with open(api_yaml_path, 'r') as f:
        apis = yaml.load(f, Loader=yaml.FullLoader)

    namespace = api_namespace()

    with open(header_file_path, 'w') as header_file, \
            open(source_file_path, 'w') as source_file:
        header_file.write("#pragma once\n")
        header_file.write(header_include())
        header_file.write(namespace[0])

        include_header_file = "paddle/pten/infermeta/generated.h"
        source_file.write(source_include(include_header_file))
        source_file.write(namespace[0])

        infermeta_register_code = ''
        for api in apis:
            api_item = BaseAPI(api)
            declare_code, define_code, register_code = gene_wrapped_infermeta_and_register(
                api_item)
            header_file.write(declare_code)
            source_file.write(define_code)
            infermeta_register_code = infermeta_register_code + register_code

        header_file.write(namespace[1])
        source_file.write(namespace[1])
        source_file.write(infermeta_register_code)
def main():
    """CLI entry point: parse arguments and run the generator."""
    parser = argparse.ArgumentParser(
        description='Generate PaddlePaddle C++ API files')
    parser.add_argument(
        '--api_yaml_path',
        help='path to api yaml file',
        default='python/paddle/utils/code_gen/api.yaml')
    parser.add_argument(
        '--wrapped_infermeta_header_path',
        help='output of generated wrapped_infermeta header code file',
        default='paddle/pten/infermeta/generated.h')
    parser.add_argument(
        '--wrapped_infermeta_source_path',
        help='output of generated wrapped_infermeta source code file',
        default='paddle/pten/infermeta/generated.cc')
    options = parser.parse_args()

    generate_wrapped_infermeta_and_register(
        options.api_yaml_path, options.wrapped_infermeta_header_path,
        options.wrapped_infermeta_source_path)


if __name__ == '__main__':
    main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册