未验证 提交 bcc5ce44 编写于 作者: W WangZhen 提交者: GitHub

Add more gen api (#56291)

上级 04b6035d
......@@ -14,6 +14,7 @@
import argparse
import os
import re
import yaml
from op_gen import OpCompatParser, OpInfoParser, to_pascal_case
......@@ -27,6 +28,7 @@ H_FILE_TEMPLATE = """
#include "paddle/ir/core/value.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
{body}
......@@ -62,18 +64,38 @@ API_IMPL_TEMPLATE = """
{in_combine}
{compute_op}
{out_slice}
{out_combine}
{return_result}
}}
"""
COMBINE_OP_TEMPLATE = """auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""
COMPUTE_OP_TEMPLATE = """paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""
API_LIST = ['add_n', 'mean', 'sum', 'divide', 'full', 'tanh_grad', 'mean_grad']
COMBINE_OP_TEMPLATE = """
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""
SLICE_OP_TEMPLATE = """
auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::SliceOp>({in_name});"""
COMPUTE_OP_TEMPLATE = """
paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""
API_LIST = [
'add_n',
'mean',
'sum',
'divide',
'full',
'tanh_grad',
'mean_grad',
'concat',
'add',
'multiply',
'elementwise_pow',
'scale',
'reshape',
'expand',
'tile',
'add_grad',
]
OP_RESULT = 'ir::OpResult'
VECTOR_TYPE = 'ir::VectorType'
......@@ -86,6 +108,7 @@ class CodeGen:
def __init__(self) -> None:
self._type_map = {
'paddle::dialect::DenseTensorType': 'ir::OpResult',
'paddle::dialect::SelectedRowsType': 'ir::OpResult',
'ir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<ir::OpResult>',
}
......@@ -126,6 +149,8 @@ class CodeGen:
name_list, type_list, default_value_list
):
if with_default and default_value is not None:
if type in ['float', 'double']:
default_value = default_value.strip('"')
ret.append(
'{type} {name} = {default_value}'.format(
type=type, name=name, default_value=default_value
......@@ -140,9 +165,19 @@ class CodeGen:
attrs = self._gen_api_attrs(op_info, with_default_attr)
return (inputs + ', ' + attrs).strip(', ')
def _gen_ret_type(self, op_info):
type_list = op_info.output_type_list
assert len(type_list) >= 1
if len(type_list) > 1:
return 'std::tuple<{}>'.format(
', '.join([self._type_map[type] for type in type_list])
)
elif len(type_list) == 1:
return self._type_map[type_list[0]]
def _gen_one_declare(self, op_info, op_name):
return API_DECLARE_TEMPLATE.format(
ret_type=OP_RESULT,
ret_type=self._gen_ret_type(op_info),
api_name=op_name,
args=self._gen_api_args(op_info, True),
)
......@@ -205,33 +240,51 @@ class CodeGen:
op_inst_name,
)
def _gen_out_slice(self):
return ''
def _gen_out_slice_and_ret_list(self, op_info, op_inst_name):
name_list = op_info.output_name_list
type_list = op_info.output_type_list
def _gen_out_combine(self):
return ''
slice_op_str = ''
ret_list = []
for i, (name, type) in enumerate(zip(name_list, type_list)):
if VECTOR_TYPE in type:
slice_op_name = f'{name}_slice_op'
slice_op_str += SLICE_OP_TEMPLATE.format(
op_name=slice_op_name, in_name=f'{op_inst_name}.result({i})'
)
ret_list.append(f'{slice_op_name}.outputs()')
else:
ret_list.append(f'{op_inst_name}.result({i})')
return slice_op_str, ret_list
def _gen_return_result(self, op_info, op_inst_name):
output_name_list = op_info.output_name_list
assert len(output_name_list) == 1
return f'return {op_inst_name}.result(0);'
def _gen_return_result(self, ret_list):
assert len(ret_list) >= 1
if len(ret_list) > 1:
return 'return std::make_tuple({});'.format(', '.join(ret_list))
else:
return f'return {ret_list[0]};'
def _gen_one_impl(self, op_info, op_name):
in_combine, in_combine_op_list = self._gen_in_combine(op_info)
compute_op, op_inst_name = self._gen_compute_op(
op_info, op_name, in_combine_op_list
)
out_slice, ret_list = self._gen_out_slice_and_ret_list(
op_info, op_inst_name
)
return API_IMPL_TEMPLATE.format(
ret_type=OP_RESULT,
ret = API_IMPL_TEMPLATE.format(
ret_type=self._gen_ret_type(op_info),
api_name=op_name,
args=self._gen_api_args(op_info, False),
in_combine=in_combine,
compute_op=compute_op,
out_slice=self._gen_out_slice(),
out_combine=self._gen_out_combine(),
return_result=self._gen_return_result(op_info, op_inst_name),
).replace(' \n', '')
out_slice=out_slice,
return_result=self._gen_return_result(ret_list),
)
ret = re.sub(r' +\n', '', ret)
return ret
def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
impl_str = ''
......
......@@ -259,7 +259,7 @@ class OpInfoParser:
'bool': ['ir::BoolAttribute', 'bool'],
'bool[]': [
'ir::ArrayAttribute<ir::BoolAttribute>',
'const std::vecot<bool>&',
'const std::vector<bool>&',
],
'str': ['ir::StrAttribute', 'const std::string&'],
'str[]': [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册