未验证 提交 7b70b792 编写于 作者: Z zyfncg 提交者: GitHub

【Pten】Refactor C++ API code-gen (#39408)

* refactor C++ API code-gen

* fix windows problem of C++ API
上级 224bc511
......@@ -16,7 +16,7 @@ cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor)
set(api_gen_utils ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/gen_utils.py)
set(api_gen_base ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_base.py)
# forward api file
set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py)
......@@ -49,7 +49,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file}
COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_utils}
DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base}
VERBATIM)
# generate backward api
......@@ -62,7 +62,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${bw_api_header_file_tmp} ${bw_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${bw_api_source_file_tmp} ${bw_api_source_file}
COMMENT "copy_if_different ${bw_api_header_file} ${bw_api_source_file}"
DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_utils}
DEPENDS ${bw_api_yaml_file} ${bw_api_gen_file} ${api_gen_base}
VERBATIM)
cc_library(pten_data_transform SRCS data_transform.cc DEPS pten_tensor transfer_layout_kernel cast_kernel data_device_transform)
......
此差异已折叠。
......@@ -16,64 +16,19 @@ import os
import yaml
import argparse
import gen_utils
from api_base import BaseAPI
class API:
class ForwardAPI(BaseAPI):
prefix_tensor_name = 'dense_'
def __init__(self, api_item_yaml):
self.api = api_item_yaml['api']
# args:
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
self.args = gen_utils.parse_args(self.api, api_item_yaml['args'])
self.out_type_list, _ = gen_utils.parse_output(self.api,
api_item_yaml['output'])
self.return_type = self.out_type_list[0] if len(
self.out_type_list) == 1 else "std::tuple<" + ",".join(
self.out_type_list) + ">"
self.is_base_api = True
if 'invoke' in api_item_yaml:
self.is_base_api = False
self.invoke = api_item_yaml['invoke']
else:
self.kernel = api_item_yaml['kernel']
if 'backend' not in self.kernel or len(self.kernel['backend']) == 0:
self.kernel['backend'] = None
if 'layout' not in self.kernel or len(self.kernel['layout']) == 0:
self.kernel['layout'] = None
if 'data_type' not in self.kernel or len(self.kernel[
'data_type']) == 0:
self.kernel['data_type'] = None
if 'param' not in self.kernel:
self.kernel['param'] = None
self.infer_meta = api_item_yaml['infer_meta']
if 'param' not in self.infer_meta:
self.infer_meta['param'] = None
self.data_transform = {
'skip_transform': [],
'support_trans_dtype': []
}
if 'data_transform' in api_item_yaml:
if 'skip_transform' in api_item_yaml['data_transform']:
self.data_transform['skip_transform'] = api_item_yaml[
'data_transform']['skip_transform']
if 'support_trans_dtype' in api_item_yaml['data_transform']:
self.data_transform['support_trans_dtype'] = api_item_yaml[
'data_transform']['support_trans_dtype']
def gene_api_declaration(self):
return f"""
PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
"""
super(ForwardAPI, self).__init__(api_item_yaml)
def get_return_type(self, out_type_list):
return out_type_list[0] if len(
out_type_list) == 1 else "std::tuple<" + ",".join(
out_type_list) + ">"
def gene_output(self, output_type_list):
kernel_output = ""
......@@ -84,12 +39,12 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
kernel_output = 'dense_out'
output_names.append('dense_out')
output_create = f"""
{self.return_type} out;
{self.outputs['return_type']} out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""
{self.outputs['return_type']} out;"""
for i in range(len(output_type_list)):
kernel_output = kernel_output + f'dense_out_{i}, '
......@@ -105,36 +60,6 @@ PADDLE_API {self.return_type} {self.api}({self.args['args_declare']});
return kernel_output, output_names, output_create
def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.out_type_list,
self.kernel['param'], self.data_transform)
outputs_args, output_names, output_create = self.gene_output(
self.out_type_list)
return f"""
PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
{gen_utils.gene_kernel_select(self.api, self.args['inputs']['names'], self.args['attrs'], self.kernel)}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors}
{output_create}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], output_names, self.infer_meta)}
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});
return out;
}}
"""
else:
return f"""
PADDLE_API {self.return_type} {self.api}({self.args["args_define"]}) {{
return {self.invoke};
}}
"""
def header_include():
return """
......@@ -203,7 +128,7 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file.write(namespace[0])
for api in apis:
api_code = API(api)
api_code = ForwardAPI(api)
print(api_code.gene_api_declaration())
header_file.write(api_code.gene_api_declaration())
source_file.write(api_code.gene_api_code())
......
......@@ -17,52 +17,16 @@ import yaml
import argparse
import re
import gen_utils
from api_base import BaseAPI
class BackwardAPI:
class BackwardAPI(BaseAPI):
def __init__(self, backward_item_yaml):
self.backward_api = backward_item_yaml['backward_api']
self.args, self.output_type_list, self.return_comment = self.parse_and_check_args(
backward_item_yaml['forward'], backward_item_yaml['args'],
backward_item_yaml['output'])
self.return_type = self.output_type_list[0] if len(
self.output_type_list) == 1 else "std::vector<std::vector<Tensor>>"
self.is_base_api = True
if 'invoke' in backward_item_yaml:
self.is_base_api = False
self.invoke = backward_item_yaml['invoke']
else:
self.kernel = backward_item_yaml['kernel']
if 'backend' not in self.kernel or len(self.kernel['backend']) == 0:
self.kernel['backend'] = None
if 'layout' not in self.kernel or len(self.kernel['layout']) == 0:
self.kernel['layout'] = None
if 'data_type' not in self.kernel or len(self.kernel[
'data_type']) == 0:
self.kernel['data_type'] = None
if 'param' not in self.kernel or len(self.kernel['param']) == 0:
self.kernel['param'] = None
self.infer_meta = backward_item_yaml['infer_meta']
if 'param' not in self.infer_meta or len(self.infer_meta[
'param']) == 0:
self.infer_meta['param'] = None
self.data_transform = {
'skip_transform': [],
'support_trans_dtype': []
}
if 'data_transform' in backward_item_yaml:
if 'skip_transform' in backward_item_yaml['data_transform']:
self.data_transform['skip_transform'] = backward_item_yaml[
'data_transform']['skip_transform']
if 'support_trans_dtype' in backward_item_yaml[
'data_transform']:
self.data_transform[
'support_trans_dtype'] = backward_item_yaml[
'data_transform']['support_trans_dtype']
super(BackwardAPI, self).__init__(backward_item_yaml)
self.check_args(backward_item_yaml['forward'])
def get_api_name(self, api_item_yaml):
return api_item_yaml['backward_api']
def parse_forward_config(self, forward_config):
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
......@@ -71,51 +35,39 @@ class BackwardAPI:
forward_config)
api = result.group('api')
outputs = [item.strip() for item in result.group('outputs').split(',')]
forward_args = gen_utils.parse_args(api, result.group('args'))
fw_inputs, fw_attrs, _, = self.parse_input_and_attr(
api, result.group('args'))
return api, forward_args['inputs'], forward_args['attrs'], outputs
return api, fw_inputs, fw_attrs, outputs
def parse_and_check_args(self, forward_config, args_config, output_config):
def check_args(self, forward_config):
# parse the forward and backward config
_, fw_inputs, fw_attrs, fw_outputs = self.parse_forward_config(
forward_config)
bw_args = gen_utils.parse_args(self.backward_api, args_config)
# check the inputs of backward
for input in bw_args['inputs']['names']:
for input in self.inputs['names']:
if input not in fw_inputs and input not in fw_outputs:
if input.endswith('_grad'):
original_name = input[:-5]
assert original_name in fw_outputs, \
f"{self.backward_api} : Input Tensor error: the input tensor({input}) of backward should be an input or output or grad of output in forward api. \
Please check the forward of {self.backward_api} in yaml."
f"{self.api} : Input Tensor error: the input tensor({input}) of backward should be an input or output or grad of output in forward api. \
Please check the forward of {self.api} in yaml."
# check the attributes of backward
for attr in bw_args['attrs']['names']:
assert attr in fw_attrs['names'] and bw_args['attrs']['attr_info'][attr][0] == fw_attrs['attr_info'][attr][0], \
f"{self.backward_api} : Attribute error: The attribute({attr}) of backward isn't consistent with forward api. \
Please check the args of {self.backward_api} in yaml."
for attr in self.attrs['names']:
assert attr in fw_attrs['names'] and self.attrs['attr_info'][attr][0] == fw_attrs['attr_info'][attr][0], \
f"{self.api} : Attribute error: The attribute({attr}) of backward isn't consistent with forward api. \
Please check the args of {self.api} in yaml."
# check the output of backward
out_type_list, return_comment = gen_utils.parse_output(
self.backward_api, output_config)
assert len(out_type_list) <= len(fw_inputs['names']), \
f"{self.backward_api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \
Please check the output of {self.backward_api} in yaml."
return bw_args, out_type_list, return_comment
def gene_api_declaration(self):
if self.return_comment:
return f"""
// {self.return_comment}
{self.return_type} {self.backward_api}({self.args['args_declare']});
"""
assert len(self.outputs['types']) <= len(fw_inputs['names']), \
f"{self.api} : Output error: The number of ouputs should be less then the number of inputs of forward api. \
Please check the output of {self.api} in yaml."
else:
return f"""
{self.return_type} {self.backward_api}({self.args['args_declare']});
"""
def get_return_type(self, out_type_list):
return out_type_list[0] if len(
out_type_list) == 1 else "std::vector<std::vector<Tensor>>"
def gene_output(self, output_type_list):
kernel_output = ""
......@@ -126,12 +78,12 @@ class BackwardAPI:
kernel_output = 'dense_out'
output_names.append('dense_out')
output_create = f"""
{self.return_type} out;
{self.outputs['return_type']} out;
auto dense_out = SetKernelOutput(kernel_backend, &out);"""
elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out({len(output_type_list)});"""
{self.outputs['return_type']} out({len(output_type_list)});"""
for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, '
......@@ -150,58 +102,10 @@ class BackwardAPI:
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
self.backward_api))
self.api))
return kernel_output, output_names, output_create
def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.output_type_list,
self.kernel['param'], self.data_transform)
outputs_args, output_names, output_create = self.gene_output(
self.output_type_list)
return f"""
// {self.return_comment}
{self.return_type} {self.backward_api}({self.args["args_define"]}) {{
{gen_utils.gene_kernel_select(self.backward_api, self.args['inputs']['names'], self.args['attrs'], self.kernel)}
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
{input_tensors}
{output_create}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], output_names, self.infer_meta)}
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});
return out;
}}
"""
else:
inveke_func_name = self.invoke.split('(')[0].strip()
if inveke_func_name in self.args['attrs']['names']:
# Adjust the param whose name is same with api invoked.
pattern = '\W' + inveke_func_name + '[^A-Za-z0-9_(]'
def adjust_name(matched):
matched_str = matched.group()
return matched_str[0:-1] + '_val' + matched_str[-1]
invoke_code = re.sub(pattern, adjust_name, self.invoke)
params_code = re.sub(pattern, adjust_name,
self.args["args_define"])
else:
invoke_code = self.invoke
params_code = self.args["args_define"]
return f"""
// {self.return_comment}
{self.return_type} {self.backward_api}({params_code}) {{
return {invoke_code};
}}
"""
def header_include():
return """
......@@ -263,7 +167,6 @@ def generate_backward_api(backward_yaml_path, header_file_path,
for bw_api in bw_apis:
bw_api = BackwardAPI(bw_api)
# print(api_code.gene_api_declaration())
header_file.write(bw_api.gene_api_declaration())
source_file.write(bw_api.gene_api_code())
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
PREFIX_TENSOR_NAME = 'dense_'
PREFIX_META_TENSOR_NAME = 'meta_'
def parse_args(api_name, args_str):
"""
Returns:
{ inputs : {
names : [] // list of input names
input_info : { input_name : type }
}
attrs: {
names : [] // list of attribute names
attr_info : { attr_name : (type, default_value)}
}
args_declare : "str" // str of funtion params with default value. Example: (..., bool flag=false)
args_define : "str" // str of funtion params without default value. Example: (..., bool flag)
}
"""
inputs = {'names': [], 'input_info': {}}
attrs = {'names': [], 'attr_info': {}}
args_str = args_str.strip()
assert args_str.startswith('(') and args_str.endswith(')'), \
f"Args declaration should start with '(' and end with ')', please check the args of {api_name} in yaml."
args_str = args_str[1:-1]
args_list = args_str.split(',')
input_types = [
'const Tensor&', 'const Tensor &', 'const std::vector<Tensor>&',
'const std::vector<Tensor> &'
]
attr_types = ['const Scalar&', 'const Scalar &', 'const ScalarArray&', 'const ScalarArray &', \
'int', 'int32_t', 'int64_t', 'size_t', 'float', 'double', 'bool', \
'const std::vector<int64_t>&', 'Backend', 'DataLayout', 'DataType']
args_declare_str = ""
args_define_str = ""
for item in args_list:
item = item.strip()
# match the input tensor
has_input = False
for in_type in input_types:
if item.startswith(in_type):
input_name = item[len(in_type):].strip()
assert len(input_name) > 0, \
f"The input tensor name should not be empty. Please check the args of {api_name} in yaml."
assert len(attrs['names']) == 0, \
f"The input Tensor should appear before attributes. please check the position of {api_name}:input({input_name}) in yaml"
inputs['names'].append(input_name)
inputs['input_info'][input_name] = in_type
args_declare_str = args_declare_str + in_type + ' ' + input_name + ', '
args_define_str = args_define_str + in_type + ' ' + input_name + ', '
has_input = True
break
if has_input:
continue
# match the attribute
for attr_type in attr_types:
if item.startswith(attr_type):
attr_name = item[len(attr_type):].strip()
assert len(attr_name) > 0, \
f"The attribute name should not be empty. Please check the args of {api_name} in yaml."
default_value = None
if '=' in attr_name:
attr_infos = attr_name.split('=')
attr_name = attr_infos[0].strip()
default_value = attr_infos[1].strip()
default_value_str = "" if default_value is None else '=' + default_value
args_declare_str = args_declare_str + attr_type + ' ' + attr_name + default_value_str + ', '
args_define_str = args_define_str + attr_type + ' ' + attr_name + ', '
attrs['names'].append(attr_name)
attrs['attr_info'][attr_name] = (attr_type, default_value)
break
args = {
'inputs': inputs,
'attrs': attrs,
'args_declare': args_declare_str[:-2],
'args_define': args_define_str[:-2]
}
return args
def parse_output(api_name, output_config):
def parse_output_item(output_item):
alllowd_output_types = ['Tensor', 'std::vector<Tensor>']
if re.search(r'\(\w*\)', output_item):
result = re.search(
r"(?P<out_type>[a-zA-Z0-9_<>]+)\s*\((?P<name>\w+)\)",
output_item)
out_type = result.group('out_type')
assert out_type in alllowd_output_types, \
f"{api_name} : Output type error: the output type only support Tensor and std::vector<Tensor>, \
but now is {out_type}."
return out_type, result.group('name')
else:
if output_item.strip() in alllowd_output_types:
return output_item.strip(), 'out'
else:
raise ValueError(
"{} : Output type error: the output type only support Tensor and std::vector<Tensor>, \
but now is {}.".format(api_name, out_type))
temp_list = output_config.split(',')
if len(temp_list) == 1:
out_type, out_name = parse_output_item(temp_list[0])
return [out_type], out_name
else:
out_type_list = []
out_name_list = []
for output_item in temp_list:
out_type, out_name = parse_output_item(output_item)
out_type_list.append(out_type)
out_name_list.append(out_name)
return out_type_list, ", ".join(out_name_list)
def gene_kernel_select(api, input_names, attrs, kernel) -> str:
kernel_key_item_init = """
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
"""
# Check the tensor options
attr_backend_count = 0
attr_layout_count = 0
attr_data_type_count = 0
for attr_name in attrs['names']:
if attrs['attr_info'][attr_name][0] == 'Backend':
assert kernel['backend'] is not None, \
f"{api} api: When there is a parameter with 'Backend' type in attributes, you must set backend of kernel manually."
attr_backend_count = attr_backend_count + 1
if attrs['attr_info'][attr_name][0] == 'DataLayout':
assert kernel['layout'] is not None, \
f"{api} api: When there is a parameter with 'DataLayout' type in attributes, you must set layout of kernel manually."
attr_layout_count = attr_layout_count + 1
if attrs['attr_info'][attr_name][0] == 'DataType':
assert kernel['data_type'] is not None, \
f"{api} api: When there is a parameter with 'DataType' type in attributes, you must set data_type of kernel manually."
attr_data_type_count = attr_data_type_count + 1
# preprocess kernel configures
kernel_select_code = ""
if kernel['backend'] is not None:
if '>' in kernel['backend']:
vars_list = kernel['backend'].split('>')
assert len(
vars_list
) == 2, f"{api} api: The number of params to set backend with '>' only allows 2, but received {len(vars_list)}."
assert (vars_list[0].strip() in attrs['names']) and (attrs['attr_info'][vars_list[0].strip()][0] == 'Backend'), \
f"{api} api: When use '>' to set kernel backend, the first param should be a attribute with Backend type."
kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackendWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
else:
args_str = ""
for ele in kernel['backend'].split(','):
args_str = args_str + ele.strip() + ', '
kernel_select_code = kernel_select_code + f"""
kernel_backend = ParseBackend({args_str[:-2]});
"""
if kernel['layout'] is not None:
if '>' in kernel['layout']:
vars_list = kernel['layout'].split('>')
assert len(
vars_list
) == 2, f"{api} api: The number of params to set layout with '>' only allows 2, but received {len(vars_list)}."
assert vars_list[0].strip() in attrs['names'] and attrs['attr_info'][vars_list[0].strip()][0] == 'DataLayout', \
f"{api} api: When use '>' to set kernel layout, the first param should be a attribute with DataLayout type."
kernel_select_code = kernel_select_code + f"""
kernel_layout = ParseLayoutWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
else:
vars_list = kernel['layout'].split(',')
assert len(
vars_list
) == 1, f"{api} api: The number of params to set layout must be 1, but received {len(vars_list)}."
kernel_select_code = kernel_select_code + f"""
kernel_layout = ParseLayout({vars_list[0].strip()});
"""
if kernel['data_type'] is not None:
if '>' in kernel['data_type']:
vars_list = kernel['data_type'].split('>')
assert len(
vars_list
) == 2, f"{api} api: The number of params to set data_type with '>' only allows 2, but received {len(vars_list)}."
assert vars_list[0].strip() in attrs['names'] and attrs['attr_info'][vars_list[0].strip()][0] == 'DataType', \
f"{api} api: When use '>' to set kernel data_type, the first param should be a attribute with DataType type."
kernel_select_code = kernel_select_code + f"""
kernel_data_type = ParseDataTypeWithInputOrder({vars_list[0].strip()}, {vars_list[1].strip()});
"""
else:
vars_list = kernel['data_type'].split(',')
assert len(
vars_list
) == 1, f"{api} api: The number of params to set data_type only allows 2, but received {len(vars_list)}."
kernel_select_code = kernel_select_code + f"""
kernel_data_type = ParseDataType({vars_list[0].strip()});
"""
if len(input_names) == 0:
assert attr_backend_count > 0 and attr_layout_count > 0 and attr_data_type_count > 0, \
f"{api} api: When there is no input tensor, the args must have 'Backend', 'DataLayout' and 'DataType'."
kernel_select_args = ""
for input_name in input_names:
kernel_select_args = kernel_select_args + input_name + ", "
if len(kernel_select_args) > 2:
kernel_select_args = kernel_select_args[:-2]
kernel_select_code = kernel_key_item_init + kernel_select_code
if len(input_names) > 0:
kernel_select_code = kernel_select_code + f"""
if (kernel_backend == Backend::UNDEFINED
|| kernel_layout == DataLayout::UNDEFINED
|| kernel_data_type == DataType::UNDEFINED ) {{
auto kernel_key_set = ParseKernelKeyByInputArgs({kernel_select_args});
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {{
kernel_backend = kernel_key.backend();
}}
if (kernel_layout == DataLayout::UNDEFINED) {{
kernel_layout = kernel_key.layout();
}}
if (kernel_data_type == DataType::UNDEFINED) {{
kernel_data_type = kernel_key.dtype();
}}
}}"""
kernel_select_code = kernel_select_code + f"""
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"{kernel['func']}", {{kernel_backend, kernel_layout, kernel_data_type}});
VLOG(6) << "{api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
VLOG(6) << "{api} API kernel: " << kernel;"""
return kernel_select_code
def gene_infer_meta(input_names, attr_names, output_names, infer_meta) -> str:
infer_meta_params = infer_meta['param'] + output_names if infer_meta[
'param'] is not None else input_names + attr_names + output_names
# generate meta tensors
meta_tensor_code = ""
param_code = ""
for param in infer_meta_params:
if param in input_names:
param_code = param_code + "MakeMetaTensor(*" + PREFIX_TENSOR_NAME + param + "), "
elif param in output_names:
meta_tensor_code = meta_tensor_code + " pten::MetaTensor " + param.replace(
PREFIX_TENSOR_NAME,
PREFIX_META_TENSOR_NAME) + "(" + param + ");\n"
param_code = param_code + "&" + param.replace(
PREFIX_TENSOR_NAME, PREFIX_META_TENSOR_NAME) + ", "
elif param in attr_names:
param_code = param_code + param + ", "
elif isinstance(param, str):
param_code = param_code + "\"" + param + "\", "
elif isinstance(param, bool):
param_code = param_code + str(param).lower() + ", "
else:
param_code = param_code + str(param) + ", "
param_code = param_code[:-2]
return f"""{meta_tensor_code}
pten::{infer_meta['func']}({param_code});
"""
def get_kernel_args(inputs, attrs, out_type_list, kernel_param, data_transform):
input_trans_map = {
'const Tensor&': 'const pten::DenseTensor&',
'const Tensor &': 'const pten::DenseTensor&',
'const std::vector<Tensor>&': 'const std::vector<pten::DenseTensor>&',
'const std::vector<Tensor> &': 'const std::vector<pten::DenseTensor>&'
}
out_trans_map = {
'Tensor': 'pten::DenseTensor*',
'std::vector<Tensor>': 'std::vector<pten::DenseTensor*>&'
}
input_names = inputs['names']
input_infos = inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']
input_tensor_code = ""
for input_name in input_names:
# set input code
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
attr_names = attrs['names']
if kernel_param is None:
kernel_param = input_names + attr_names
input_tensor_code = ""
for i, input_name in enumerate(input_names):
# set input code
if input_name in kernel_param:
trans_flag = "{}"
if input_name in data_transform['skip_transform']:
trans_flag = "{true}"
elif input_name in data_transform['support_trans_dtype']:
trans_flag = "{false, true}"
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
else:
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
kernel_args = "*dev_ctx, "
for param in kernel_param:
if param in input_names:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_args_type_list.append(input_trans_map[input_infos[param]])
elif param in attr_names:
# set attr for kernel_context
if 'ScalarArray' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::ScalarArray&')
param = 'pten::ScalarArray(' + param + ')'
elif 'Scalar' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::Scalar&')
param = 'pten::Scalar(' + param + ')'
else:
kernel_args_type_list.append(attrs['attr_info'][param][0])
kernel_args = kernel_args + param + ", "
elif isinstance(param, bool):
kernel_args = kernel_args + str(param).lower() + ", "
else:
kernel_args = kernel_args + str(param) + ", "
for out_type in out_type_list:
kernel_args_type_list.append(out_trans_map[out_type])
kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"
return input_tensor_code, kernel_args[:-2], kernel_signature
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册