diff --git a/paddle/fluid/ir/dialect/op_generator/api_gen.py b/paddle/fluid/ir/dialect/op_generator/api_gen.py index 4dbe7d33540c90d4e7b3e92c19a0a41c619627aa..7680ddfb1228ce1a923207588d38f3ff579c4d85 100644 --- a/paddle/fluid/ir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/api_gen.py @@ -14,6 +14,7 @@ import argparse import os +import re import yaml from op_gen import OpCompatParser, OpInfoParser, to_pascal_case @@ -27,6 +28,7 @@ H_FILE_TEMPLATE = """ #include "paddle/ir/core/value.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/common/scalar.h" {body} @@ -62,18 +64,38 @@ API_IMPL_TEMPLATE = """ {in_combine} {compute_op} {out_slice} - {out_combine} {return_result} }} """ -COMBINE_OP_TEMPLATE = """auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" - -COMPUTE_OP_TEMPLATE = """paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build({args});""" - - -API_LIST = ['add_n', 'mean', 'sum', 'divide', 'full', 'tanh_grad', 'mean_grad'] +COMBINE_OP_TEMPLATE = """ + auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" + +SLICE_OP_TEMPLATE = """ + auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" + +COMPUTE_OP_TEMPLATE = """ + paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build({args});""" + +API_LIST = [ + 'add_n', + 'mean', + 'sum', + 'divide', + 'full', + 'tanh_grad', + 'mean_grad', + 'concat', + 'add', + 'multiply', + 'elementwise_pow', + 'scale', + 'reshape', + 'expand', + 'tile', + 'add_grad', +] OP_RESULT = 'ir::OpResult' VECTOR_TYPE = 'ir::VectorType' @@ -86,6 +108,7 @@ class CodeGen: def __init__(self) -> None: self._type_map = { 'paddle::dialect::DenseTensorType': 'ir::OpResult', + 'paddle::dialect::SelectedRowsType': 'ir::OpResult', 'ir::VectorType': 'std::vector', } @@ -126,6 +149,8 @@ class CodeGen: name_list, type_list, default_value_list ): if with_default and default_value is not None: + if type in ['float', 'double']: + default_value = default_value.strip('"') ret.append( '{type} {name} = {default_value}'.format( type=type, name=name, default_value=default_value @@ -140,9 +165,19 @@ class CodeGen: attrs = self._gen_api_attrs(op_info, with_default_attr) return (inputs + ', ' + attrs).strip(', ') + def _gen_ret_type(self, op_info): + type_list = op_info.output_type_list + assert len(type_list) >= 1 + if len(type_list) > 1: + return 'std::tuple<{}>'.format( + ', '.join([self._type_map[type] for type in type_list]) + ) + elif len(type_list) == 1: + return self._type_map[type_list[0]] + def _gen_one_declare(self, op_info, op_name): return API_DECLARE_TEMPLATE.format( - ret_type=OP_RESULT, + ret_type=self._gen_ret_type(op_info), api_name=op_name, args=self._gen_api_args(op_info, True), ) @@ -205,33 +240,51 @@ class CodeGen: op_inst_name, ) - def _gen_out_slice(self): - return '' + def _gen_out_slice_and_ret_list(self, op_info, op_inst_name): + name_list = op_info.output_name_list + type_list = op_info.output_type_list - def _gen_out_combine(self): - return '' + slice_op_str = '' + ret_list = [] + for i, (name, type) in enumerate(zip(name_list, type_list)): + if VECTOR_TYPE in type: + slice_op_name = f'{name}_slice_op' + slice_op_str += SLICE_OP_TEMPLATE.format( + op_name=slice_op_name, in_name=f'{op_inst_name}.result({i})' + ) + ret_list.append(f'{slice_op_name}.outputs()') + else: + ret_list.append(f'{op_inst_name}.result({i})') + return slice_op_str, ret_list - def _gen_return_result(self, op_info, op_inst_name): - output_name_list = op_info.output_name_list - assert len(output_name_list) == 1 - return f'return {op_inst_name}.result(0);' + def _gen_return_result(self, ret_list): + assert len(ret_list) >= 1 + if len(ret_list) > 1: + return 'return std::make_tuple({});'.format(', '.join(ret_list)) + else: + return f'return {ret_list[0]};' def _gen_one_impl(self, op_info, op_name): in_combine, in_combine_op_list = self._gen_in_combine(op_info) compute_op, op_inst_name = self._gen_compute_op( op_info, op_name, in_combine_op_list ) + out_slice, ret_list = self._gen_out_slice_and_ret_list( + op_info, op_inst_name + ) - return API_IMPL_TEMPLATE.format( - ret_type=OP_RESULT, + ret = API_IMPL_TEMPLATE.format( + ret_type=self._gen_ret_type(op_info), api_name=op_name, args=self._gen_api_args(op_info, False), in_combine=in_combine, compute_op=compute_op, - out_slice=self._gen_out_slice(), - out_combine=self._gen_out_combine(), - return_result=self._gen_return_result(op_info, op_inst_name), - ).replace(' \n', '') + out_slice=out_slice, + return_result=self._gen_return_result(ret_list), + ) + + ret = re.sub(r' +\n', '', ret) + return ret def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path): impl_str = '' diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index a204d64b00f48cc5133ba2732e45ce5612c40c31..5bbb5c80c0693d04da01988382be076719fde17c 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -259,7 +259,7 @@ class OpInfoParser: 'bool': ['ir::BoolAttribute', 'bool'], 'bool[]': [ 'ir::ArrayAttribute', - 'const std::vecot&', + 'const std::vector&', ], 'str': ['ir::StrAttribute', 'const std::string&'], 'str[]': [