Unverified commit fc5fa0de, authored by zyfncg, committed by GitHub

Auto-generate kernel signature in C++ API (#39281)

Parent 543f3dea
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
// This header is used to cast kernel functions from void* back to their
// original function-pointer types. Currently it is written by hand;
// it may be generated automatically in the future.
namespace pten {
using DeviceContext = paddle::platform::DeviceContext;
using add_kernel = void (*)(const DeviceContext&,
                            const DenseTensor&,
                            const DenseTensor&,
                            DenseTensor*);

using cast_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             DataType,
                             DenseTensor*);

using concat_kernel = void (*)(const DeviceContext&,
                               const std::vector<DenseTensor>&,
                               const Scalar&,
                               DenseTensor*);

using divide_kernel = void (*)(const DeviceContext&,
                               const DenseTensor&,
                               const DenseTensor&,
                               DenseTensor*);

using dot_kernel = void (*)(const DeviceContext&,
                            const DenseTensor&,
                            const DenseTensor&,
                            DenseTensor*);

using flatten_kernel =
    void (*)(const DeviceContext&, const DenseTensor&, int, int, DenseTensor*);

using empty_kernel = void (*)(const DeviceContext&,
                              const ScalarArray&,
                              DenseTensor*);

using empty_like_kernel = void (*)(const DeviceContext&, DenseTensor*);

using full_kernel = void (*)(const DeviceContext&,
                             const ScalarArray&,
                             const Scalar&,
                             DenseTensor*);

using full_like_kernel = void (*)(const DeviceContext&,
                                  const Scalar&,
                                  DenseTensor*);

using matmul_kernel = void (*)(const DeviceContext&,
                               const DenseTensor&,
                               const DenseTensor&,
                               bool,
                               bool,
                               DenseTensor*);

using mean_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             const std::vector<int64_t>&,
                             bool,
                             DenseTensor*);

using multiply_kernel = void (*)(const DeviceContext&,
                                 const DenseTensor&,
                                 const DenseTensor&,
                                 DenseTensor*);

using reshape_kernel = void (*)(const DeviceContext&,
                                const DenseTensor&,
                                const ScalarArray&,
                                DenseTensor*);

using scale_kernel = void (*)(const DeviceContext&,
                              const DenseTensor&,
                              const Scalar&,
                              float,
                              bool,
                              DenseTensor*);

using sum_kernel = void (*)(const DeviceContext&,
                            const DenseTensor&,
                            const std::vector<int64_t>&,
                            DataType,
                            bool,
                            DenseTensor*);

using subtract_kernel = void (*)(const DeviceContext&,
                                 const DenseTensor&,
                                 const DenseTensor&,
                                 DenseTensor*);

using conj_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             DenseTensor*);

/* -------------- Grad Kernel ----------------- */

using matmul_grad_kernel = void (*)(const DeviceContext&,
                                    const DenseTensor&,
                                    const DenseTensor&,
                                    const DenseTensor&,
                                    bool,
                                    bool,
                                    DenseTensor*,
                                    DenseTensor*);
} // namespace pten
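The cast this header's comment describes, recovering a typed function pointer from a type-erased void*, can be sketched in miniature. The following self-contained Python example demonstrates the same pattern with ctypes, using libm's cos as a stand-in kernel; it is an illustration of the idea only, not PaddlePaddle code:

import ctypes
import ctypes.util

# Stand-in kernel: libm's `double cos(double)`. The library lookup is
# platform-dependent and may need adjusting on some systems.
libm = ctypes.CDLL(ctypes.util.find_library("m"))

# 1. Erase the type: keep only a raw address, as a kernel registry would.
raw_addr = ctypes.cast(libm.cos, ctypes.c_void_p).value

# 2. Recover the callable through the matching "signature"
#    (cf. the add_kernel alias above).
cos_signature = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
kernel_fn = cos_signature(raw_addr)

print(kernel_fn(0.0))  # 1.0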
@@ -27,6 +27,7 @@ class API:
         # args:
         #   inputs:
         #     names : [], list of input names
+        #     input_info : {input_name : type}
         #   attrs:
         #     names : [], list of attribute names
         #     attr_info : { attr_name : (type, default_values)}
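For orientation, a hypothetical parsed args structure matching the comment above might look like this for a scale-like API. The field values are illustrative, not taken from the generator's actual YAML parsing:

# Hypothetical example of the structure documented above, for a
# scale-like API `scale(x, scale, bias, bias_after_scale)`.
args = {
    'inputs': {
        'names': ['x'],
        'input_info': {'x': 'const Tensor&'},  # the newly recorded type info
    },
    'attrs': {
        'names': ['scale', 'bias', 'bias_after_scale'],
        'attr_info': {
            'scale': ('const Scalar&', None),
            'bias': ('float', '0.0'),
            'bias_after_scale': ('bool', 'true'),
        },
    },
}
print(args['inputs']['input_info']['x'])  # const Tensor&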
@@ -91,8 +92,8 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
     def gene_api_code(self):
         if self.is_base_api:
-            input_tensors, kernel_args = gen_utils.get_kernel_args(
-                self.args['inputs']['names'], self.args['attrs'],
+            input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
+                self.args['inputs'], self.args['attrs'], self.out_type_list,
                 self.kernel['param'])
             outputs_args, output_create = self.gene_output(self.out_type_list)
             return f"""
@@ -103,8 +104,8 @@ PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
 {input_tensors}
 {gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
 {output_create}
-  auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.api}_kernel>();
+  using kernel_signature = {kernel_signature};
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
   (*kernel_fn)({kernel_args}, {outputs_args});
   return out;
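To make the template change concrete, here is a rough sketch of what the two new lines render to for an add-like API. The kernel_signature string is what get_kernel_args now returns (see the last diff below); the exact contents of kernel_args, including the leading *dev_ctx, are an assumption here:

# Sketch: rendering the two new template lines for an `add`-like API.
# `kernel_args` beginning with *dev_ctx is assumed, not shown in the diff.
kernel_signature = ("void(*)(const platform::DeviceContext&, "
                    "const pten::DenseTensor&, const pten::DenseTensor&, "
                    "pten::DenseTensor*)")
kernel_args = "*dev_ctx, *dense_x, *dense_y"
outputs_args = "dense_out"

print(f"""
  using kernel_signature = {kernel_signature};
  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
  (*kernel_fn)({kernel_args}, {outputs_args});""")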
@@ -136,7 +137,6 @@ def source_include(header_file_path):
 #include "glog/logging.h"
-#include "paddle/pten/api/include/kernel_signature.h"
 #include "paddle/pten/api/lib/api_registry.h"
 #include "paddle/pten/api/lib/api_utils.h"
 #include "paddle/pten/api/lib/kernel_dispatch.h"
......
@@ -108,7 +108,6 @@ class BackwardAPI:
         output_create = ""
         if len(output_type_list) == 1:
             return_type = output_type_list[0]
-            kernel_output = 'dense_out'
             output_create = f"""
 {self.return_type} out;
@@ -116,11 +115,17 @@
         elif len(output_type_list) > 1:
             output_create = f"""
-{self.return_type} out;"""
+{self.return_type} out({len(output_type_list)});"""

             for i, out_type_item in enumerate(output_type_list):
                 kernel_output = kernel_output + f'dense_out_{i}, '
-                get_out_code = f'&out[{i}][0]' if out_type_item == 'Tensor' else f'&out[{i}]'
+                if out_type_item == 'Tensor':
+                    get_out_code = f'&out[{i}][0]'
+                    output_create = output_create + f"""
+    out[{i}].emplace_back();"""
+                else:
+                    get_out_code = f'&out[{i}]'

                 output_create = output_create + f"""
     auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});"""
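Run in isolation, the multi-output branch above builds an output_create string like the one printed by this sketch; the return type and the two-Tensor output list (as in a matmul_grad-like backward API) are assumed for illustration:

# Sketch: output_create produced by the branch above for
# output_type_list = ['Tensor', 'Tensor'].
return_type = "std::vector<std::vector<Tensor>>"  # assumed return type
output_type_list = ['Tensor', 'Tensor']

output_create = f"""
{return_type} out({len(output_type_list)});"""
for i, out_type_item in enumerate(output_type_list):
    if out_type_item == 'Tensor':
        get_out_code = f'&out[{i}][0]'
        output_create += f"""
    out[{i}].emplace_back();"""
    else:
        get_out_code = f'&out[{i}]'
    output_create += f"""
    auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});"""

print(output_create)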
@@ -134,8 +139,8 @@ class BackwardAPI:
     def gene_api_code(self):
         if self.is_base_api:
-            input_tensors, kernel_args = gen_utils.get_kernel_args(
-                self.args['inputs']['names'], self.args['attrs'],
+            input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
+                self.args['inputs'], self.args['attrs'], self.output_type_list,
                 self.kernel['param'])
             outputs_args, output_create = self.gene_output(
                 self.output_type_list)
@@ -149,7 +154,8 @@ class BackwardAPI:
 {gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
 {output_create}
-  auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.backward_api}_kernel>();
+  using kernel_signature = {kernel_signature};
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
   (*kernel_fn)({kernel_args}, {outputs_args});
   return out;
@@ -197,7 +203,6 @@ def source_include(header_file_path):
 #include "glog/logging.h"
-#include "paddle/pten/api/include/kernel_signature.h"
 #include "paddle/pten/api/lib/api_registry.h"
 #include "paddle/pten/api/lib/api_utils.h"
 #include "paddle/pten/api/lib/kernel_dispatch.h"
......
@@ -287,7 +287,21 @@ def gene_infer_meta(input_names, attr_names, infer_meta) -> str:
     """

-def get_kernel_args(input_names, attrs, kernel_param):
+def get_kernel_args(inputs, attrs, out_type_list, kernel_param):
+    input_trans_map = {
+        'const Tensor&': 'const pten::DenseTensor&',
+        'const Tensor &': 'const pten::DenseTensor&',
+        'const std::vector<Tensor>&': 'const std::vector<pten::DenseTensor>&',
+        'const std::vector<Tensor> &': 'const std::vector<pten::DenseTensor>&'
+    }
+    out_trans_map = {
+        'Tensor': 'pten::DenseTensor*',
+        'std::vector<Tensor>': 'std::vector<pten::DenseTensor*>&'
+    }
+    input_names = inputs['names']
+    input_infos = inputs['input_info']
+
+    kernel_args_type_list = ['const platform::DeviceContext&']
+
     input_tensor_code = ""
     for input_name in input_names:
         # set input code
@@ -302,15 +316,26 @@ def get_kernel_args(input_names, attrs, kernel_param):
     for param in kernel_param:
         if param in input_names:
             kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
+            kernel_args_type_list.append(input_trans_map[input_infos[param]])
         elif param in attr_names:
             # set attr for kernel_context
             if 'ScalarArray' in attrs['attr_info'][param][0]:
+                kernel_args_type_list.append('const pten::ScalarArray&')
                 param = 'pten::ScalarArray(' + param + ')'
             elif 'Scalar' in attrs['attr_info'][param][0]:
+                kernel_args_type_list.append('const pten::Scalar&')
                 param = 'pten::Scalar(' + param + ')'
+            else:
+                kernel_args_type_list.append(attrs['attr_info'][param][0])
             kernel_args = kernel_args + param + ", "
         elif isinstance(param, bool):
             kernel_args = kernel_args + str(param).lower() + ", "
         else:
             kernel_args = kernel_args + str(param) + ", "
-    return input_tensor_code, kernel_args[:-2]
+
+    for out_type in out_type_list:
+        kernel_args_type_list.append(out_trans_map[out_type])
+
+    kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"
+
+    return input_tensor_code, kernel_args[:-2], kernel_signature
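Pulling the new pieces together, here is a trimmed, runnable sketch of just the signature-building logic added above. The scale-like sample data is illustrative, and the code-generation parts (PREFIX_TENSOR_NAME handling, kernel_args assembly) are omitted:

# Self-contained sketch of the signature construction in get_kernel_args.
def build_kernel_signature(inputs, attrs, out_type_list, kernel_param):
    input_trans_map = {
        'const Tensor&': 'const pten::DenseTensor&',
        'const std::vector<Tensor>&': 'const std::vector<pten::DenseTensor>&',
    }
    out_trans_map = {
        'Tensor': 'pten::DenseTensor*',
        'std::vector<Tensor>': 'std::vector<pten::DenseTensor*>&',
    }
    # Every kernel takes the device context first.
    type_list = ['const platform::DeviceContext&']
    for param in kernel_param:
        if param in inputs['names']:
            type_list.append(input_trans_map[inputs['input_info'][param]])
        elif param in attrs['attr_info']:
            attr_type = attrs['attr_info'][param][0]
            if 'ScalarArray' in attr_type:
                type_list.append('const pten::ScalarArray&')
            elif 'Scalar' in attr_type:
                type_list.append('const pten::Scalar&')
            else:
                type_list.append(attr_type)
    # Output pointers come last.
    type_list += [out_trans_map[t] for t in out_type_list]
    return "void(*)(" + ", ".join(type_list) + ")"

inputs = {'names': ['x'], 'input_info': {'x': 'const Tensor&'}}
attrs = {'attr_info': {'scale': ('const Scalar&', None),
                       'bias': ('float', '0.0'),
                       'bias_after_scale': ('bool', 'true')}}
print(build_kernel_signature(inputs, attrs, ['Tensor'],
                             ['x', 'scale', 'bias', 'bias_after_scale']))
# -> void(*)(const platform::DeviceContext&, const pten::DenseTensor&,
#            const pten::Scalar&, float, bool, pten::DenseTensor*)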