未验证 提交 74a150fe 编写于 作者: Z zyfncg 提交者: GitHub

[Pten] Generate Wrapped InferMeta by Yaml (#39482)

* generate wrapped_infer_meta

* add test for wrapped_infer_meta

* Update test_meta_fn_utils.cc

* change the dir of generated file
Co-authored-by: NChen Weihang <chenweihang@baidu.com>
Co-authored-by: NChen Weihang <chenwhpro@163.com>
上级 bdeb479c
......@@ -9,6 +9,7 @@ paddle/pten/api/lib/api.cc
paddle/pten/api/backward/backward_api.h
paddle/pten/api/lib/backward_api.cc
paddle/pten/include/*
paddle/pten/infermeta/generated.*
paddle/pten/extension.h
paddle/fluid/eager/api/generated/*
......
......@@ -34,6 +34,12 @@ set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/backward_api.cc)
set(bw_api_header_file_tmp ${bw_api_header_file}.tmp)
set(bw_api_source_file_tmp ${bw_api_source_file}.tmp)
# wrapped infermeta file
set(wrapped_infermeta_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py)
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)
set(wrapped_infermeta_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/infermeta/generated.h)
set(wrapped_infermeta_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/infermeta/generated.cc)
if (NOT PYTHON_EXECUTABLE)
find_package(PythonInterp REQUIRED)
endif()
......@@ -65,8 +71,19 @@ add_custom_command(
DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base}
VERBATIM)
# generate wrapped infermeta
add_custom_command(
OUTPUT ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file}
COMMAND ${PYTHON_EXECUTABLE} ${wrapped_infermeta_gen_file}
--api_yaml_path ${api_yaml_file}
--wrapped_infermeta_header_path ${wrapped_infermeta_header_file}
--wrapped_infermeta_source_path ${wrapped_infermeta_source_file}
DEPENDS ${api_yaml_file} ${wrapped_infermeta_gen_file} ${api_gen_base}
VERBATIM)
cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform)
cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_bw_function_api SRCS ${bw_api_source_file} DEPS pten_tensor pten kernel_dispatch backward_infermeta pten_data_transform pten_function_api)
cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS pten)
......@@ -149,10 +149,13 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(double);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
......
......@@ -7,7 +7,7 @@ cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor
cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor)
cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos)
cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor infermeta infermeta_utils)
cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor wrapped_infermeta infermeta infermeta_utils)
cc_test(test_ddim SRCS test_ddim.cc DEPS ddim)
if(WITH_GPU)
......
......@@ -17,11 +17,26 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/generated.h"
#include "paddle/pten/infermeta/unary.h"
namespace pten {
namespace tests {
TEST(WrappedInferMeta, Scale) {
pten::DenseTensor dense_x;
dense_x.Resize(pten::framework::make_ddim({3, 4}));
pten::MetaTensor meta_x(&dense_x);
pten::DenseTensor dense_out1;
pten::MetaTensor meta_out(&dense_out1);
pten::ScaleInferMeta(meta_x, 0, 0, false, &meta_out);
EXPECT_EQ(dense_out1.dims().size(), dense_x.dims().size());
EXPECT_EQ(dense_out1.dims()[0], dense_x.dims()[0]);
EXPECT_EQ(dense_out1.dims()[1], dense_x.dims()[1]);
}
TEST(MetaFnFactory, InferMetaFnExists) {
pten::DenseTensor dense_x;
dense_x.Resize(pten::framework::make_ddim({3, 4}));
......
......@@ -3,7 +3,6 @@
output : Tensor
infer_meta :
func : ElementwiseInferMeta
param : [x, y]
kernel :
func : add
......@@ -40,7 +39,6 @@
output : Tensor
infer_meta :
func : ElementwiseInferMeta
param : [x, y]
kernel :
func : divide
......@@ -126,7 +124,6 @@
output : Tensor
infer_meta :
func : ReduceInferMeta
param: [x, axis, keep_dim]
kernel :
func : mean
......@@ -135,7 +132,6 @@
output : Tensor
infer_meta :
func : ElementwiseInferMeta
param : [x, y]
kernel :
func : multiply
......@@ -174,7 +170,6 @@
output : Tensor
infer_meta :
func : ElementwiseInferMeta
param : [x, y]
kernel :
func : subtract
......@@ -183,7 +178,6 @@
output : Tensor
infer_meta :
func : SumInferMeta
param: [x, axis, dtype, keep_dim]
kernel :
func : sum
param : [x, axis, dtype, keep_dim]
......
......@@ -60,14 +60,6 @@ class ForwardAPI(BaseAPI):
return kernel_output, output_names, output_create
def gene_infer_meta_register(self):
if self.is_base_api:
return f"""
PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});"""
else:
return ''
def header_include():
return """
......@@ -91,7 +83,6 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h"
......@@ -136,21 +127,16 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file.write(source_include(include_header_file))
source_file.write(namespace[0])
infer_meta_register_code = ''
for api in apis:
api_code = ForwardAPI(api)
print(api_code.gene_api_declaration())
header_file.write(api_code.gene_api_declaration())
source_file.write(api_code.gene_api_code())
infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register(
)
header_file.write(namespace[1])
source_file.write(namespace[1])
source_file.write(api_register())
source_file.write(infer_meta_register_code)
header_file.close()
source_file.close()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import argparse
from api_base import BaseAPI
def get_wrapped_infermeta_name(api_name):
return api_name.capitalize() + 'InferMeta'
def gene_wrapped_infermeta_and_register(api):
if api.is_base_api:
register_code = f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});"""
if api.infer_meta['param'] is not None:
tensor_type_map = {
'const Tensor&': 'const MetaTensor&',
'const std::vector<Tensor>&': 'const std::vector<MetaTensor>&',
'Tensor': 'MetaTensor*',
'std::vector<Tensor>': 'std::vector<MetaTensor>*',
}
wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
args = []
check_args = []
for input_name in api.inputs['names']:
args.append(tensor_type_map[api.inputs['input_info'][
input_name]] + ' ' + input_name)
check_args.append(input_name)
for attr_name in api.attrs['names']:
args.append(api.attrs['attr_info'][attr_name][0] + ' ' +
attr_name)
check_args.append(attr_name)
for i, out_type in enumerate(api.outputs['types']):
args.append(tensor_type_map[out_type] + ' ' + api.outputs[
'names'][i])
if check_args == api.infer_meta['param']:
return '', '', register_code
invoke_param = api.infer_meta['param']
invoke_param.extend(api.outputs['names'])
declare_code = f"""
void {wrapped_infermeta_name}({", ".join(args)});
"""
defind_code = f"""
void {wrapped_infermeta_name}({", ".join(args)}) {{
{api.infer_meta['func']}({", ".join(invoke_param)});
}}
"""
register_code = f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});"""
return declare_code, defind_code, register_code
else:
return '', '', register_code
else:
return '', '', ''
def gene_infermeta_register(api):
if api.is_base_api:
if api.infer_meta['param'] is None:
return f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{api.infer_meta['func']});"""
else:
return f"""
PT_REGISTER_INFER_META_FN({api.kernel['func']}, pten::{get_wrapped_infermeta_name(api.kernel['func'])});"""
else:
return ''
def header_include():
return """
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
"""
def source_include(header_file_path):
return f"""
#include "{header_file_path}"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h"
#include "paddle/pten/infermeta/nullary.h"
#include "paddle/pten/infermeta/unary.h"
"""
def api_namespace():
return ("""
namespace pten {
""", """
} // namespace pten
""")
def generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path,
source_file_path):
with open(api_yaml_path, 'r') as f:
apis = yaml.load(f, Loader=yaml.FullLoader)
header_file = open(header_file_path, 'w')
source_file = open(source_file_path, 'w')
namespace = api_namespace()
header_file.write("#pragma once\n")
header_file.write(header_include())
header_file.write(namespace[0])
include_header_file = "paddle/pten/infermeta/generated.h"
source_file.write(source_include(include_header_file))
source_file.write(namespace[0])
infermeta_register_code = ''
for api in apis:
api_item = BaseAPI(api)
declare_code, defind_code, register_code = gene_wrapped_infermeta_and_register(
api_item)
header_file.write(declare_code)
source_file.write(defind_code)
infermeta_register_code = infermeta_register_code + register_code
header_file.write(namespace[1])
source_file.write(namespace[1])
source_file.write(infermeta_register_code)
header_file.close()
source_file.close()
def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle C++ API files')
parser.add_argument(
'--api_yaml_path',
help='path to api yaml file',
default='python/paddle/utils/code_gen/api.yaml')
parser.add_argument(
'--wrapped_infermeta_header_path',
help='output of generated wrapped_infermeta header code file',
default='paddle/pten/infermeta/generated.h')
parser.add_argument(
'--wrapped_infermeta_source_path',
help='output of generated wrapped_infermeta source code file',
default='paddle/pten/infermeta/generated.cc')
options = parser.parse_args()
api_yaml_path = options.api_yaml_path
header_file_path = options.wrapped_infermeta_header_path
source_file_path = options.wrapped_infermeta_source_path
generate_wrapped_infermeta_and_register(api_yaml_path, header_file_path,
source_file_path)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册