diff --git a/paddle/fluid/ir/dialect/op_generator/api_gen.py b/paddle/fluid/ir/dialect/op_generator/api_gen.py index 8432ebf2fade9706b72e3fe78d9ee9fa19d06e29..b9ece187beea5c9b6d7bc7b3148bfd72bfaf8ba7 100644 --- a/paddle/fluid/ir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/api_gen.py @@ -64,7 +64,7 @@ API_IMPL_TEMPLATE = """ {ret_type} {api_name}({args}){{ {in_combine} {compute_op} - {out_slice} + {out_split} {return_result} }} @@ -73,34 +73,15 @@ API_IMPL_TEMPLATE = """ COMBINE_OP_TEMPLATE = """ auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" -SLICE_OP_TEMPLATE = """ - auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" +SPLIT_OP_TEMPLATE = """ + auto {op_name} = APIBuilder::Instance().GetBuilder()->Build({in_name});""" COMPUTE_OP_TEMPLATE = """ paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build({args});""" -API_LIST = [ - 'add_n', - 'mean', - 'sum', - 'divide', - 'full', - 'tanh_grad', - 'mean_grad', - 'concat', - 'add', - 'multiply', - 'elementwise_pow', - 'scale', - 'reshape', - 'expand', - 'tile', - 'add_grad', - 'divide_grad', - 'sum_grad', -] OP_RESULT = 'ir::OpResult' VECTOR_TYPE = 'ir::VectorType' +PD_MANUAL_OP_LIST = ['add_n'] def get_op_class_name(op_name): @@ -142,56 +123,70 @@ class CodeGen: ret.append(f'{self._type_map[type]} {name}') return ', '.join(ret) - def _gen_api_attrs(self, op_info, with_default): + def _gen_api_attrs(self, op_info, with_default, is_mutable_attr): name_list = op_info.attribute_name_list type_list = op_info.attribute_build_arg_type_list default_value_list = op_info.attribute_default_value_list + mutable_name_list = op_info.mutable_attribute_name_list assert len(name_list) == len(type_list) == len(default_value_list) - ret = [] + no_mutable_attr = [] + mutable_attr = [] for name, type, default_value in zip( name_list, type_list, default_value_list ): + if is_mutable_attr and name in mutable_name_list: + mutable_attr.append(f'{OP_RESULT} {name}') + continue if with_default and default_value is not None: if type in ['float', 'double']: default_value = default_value.strip('"') - ret.append( + no_mutable_attr.append( '{type} {name} = {default_value}'.format( type=type, name=name, default_value=default_value ) ) else: - ret.append(f'{type} {name}') - return ', '.join(ret) + no_mutable_attr.append(f'{type} {name}') + return ', '.join(mutable_attr + no_mutable_attr) - def _gen_api_args(self, op_info, with_default_attr): + def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr): inputs = self._gen_api_inputs(op_info) - attrs = self._gen_api_attrs(op_info, with_default_attr) + attrs = self._gen_api_attrs(op_info, with_default_attr, is_mutable_attr) return (inputs + ', ' + attrs).strip(', ') def _gen_ret_type(self, op_info): type_list = op_info.output_type_list - assert len(type_list) >= 1 if len(type_list) > 1: return 'std::tuple<{}>'.format( ', '.join([self._type_map[type] for type in type_list]) ) elif len(type_list) == 1: return self._type_map[type_list[0]] + elif len(type_list) == 0: + return 'void' - def _gen_one_declare(self, op_info, op_name): + def _gen_one_declare(self, op_info, op_name, is_mutable_attr): return API_DECLARE_TEMPLATE.format( ret_type=self._gen_ret_type(op_info), api_name=op_name, - args=self._gen_api_args(op_info, True), + args=self._gen_api_args(op_info, True, is_mutable_attr), ) def _gen_h_file(self, op_info_items, namespaces, h_file_path): declare_str = '' for op_info in op_info_items: for op_name in op_info.op_phi_name: - if op_name not in API_LIST: + # NOTE:When infer_meta_func is None, the Build() function generated in pd_op + # is wrong, so temporarily skip the automatic generation of these APIs + if ( + op_info.infer_meta_func is None + and op_name not in PD_MANUAL_OP_LIST + ): continue - declare_str += self._gen_one_declare(op_info, op_name) + declare_str += self._gen_one_declare(op_info, op_name, False) + if len(op_info.mutable_attribute_name_list) > 0: + declare_str += self._gen_one_declare(op_info, op_name, True) + body = declare_str for namespace in reversed(namespaces): body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body) @@ -218,9 +213,13 @@ class CodeGen: combine_op_list.append(None) return combine_op, combine_op_list - def _gen_compute_op_args(self, op_info, in_combine_op_list): + def _gen_compute_op_args( + self, op_info, in_combine_op_list, is_mutable_attr + ): input_name_list = op_info.input_name_list - attribute_name_list = op_info.attribute_name_list + all_attr_list = op_info.attribute_name_list + no_mutable_attr_list = op_info.non_mutable_attribute_name_list + mutable_attr_list = op_info.mutable_attribute_name_list assert len(input_name_list) == len(in_combine_op_list) ret = [] for input_name, combine_op in zip(input_name_list, in_combine_op_list): @@ -228,61 +227,69 @@ class CodeGen: ret.append(input_name) else: ret.append(f'{combine_op}.out()') - ret += list(attribute_name_list) + if is_mutable_attr: + ret += list(mutable_attr_list + no_mutable_attr_list) + else: + ret += list(all_attr_list) return ', '.join(ret) - def _gen_compute_op(self, op_info, op_name, in_combine_op_list): + def _gen_compute_op( + self, op_info, op_name, in_combine_op_list, is_mutable_attr + ): op_class_name = to_pascal_case(op_name) + 'Op' op_inst_name = op_name + '_op' return ( COMPUTE_OP_TEMPLATE.format( op_class_name=op_class_name, op_inst_name=op_inst_name, - args=self._gen_compute_op_args(op_info, in_combine_op_list), + args=self._gen_compute_op_args( + op_info, in_combine_op_list, is_mutable_attr + ), ), op_inst_name, ) - def _gen_out_slice_and_ret_list(self, op_info, op_inst_name): + def _gen_out_split_and_ret_list(self, op_info, op_inst_name): name_list = op_info.output_name_list type_list = op_info.output_type_list - slice_op_str = '' + split_op_str = '' ret_list = [] for i, (name, type) in enumerate(zip(name_list, type_list)): if VECTOR_TYPE in type: - slice_op_name = f'{name}_slice_op' - slice_op_str += SLICE_OP_TEMPLATE.format( - op_name=slice_op_name, in_name=f'{op_inst_name}.result({i})' + split_op_name = f'{name}_split_op' + split_op_str += SPLIT_OP_TEMPLATE.format( + op_name=split_op_name, in_name=f'{op_inst_name}.result({i})' ) - ret_list.append(f'{slice_op_name}.outputs()') + ret_list.append(f'{split_op_name}.outputs()') else: ret_list.append(f'{op_inst_name}.result({i})') - return slice_op_str, ret_list + return split_op_str, ret_list def _gen_return_result(self, ret_list): - assert len(ret_list) >= 1 if len(ret_list) > 1: return 'return std::make_tuple({});'.format(', '.join(ret_list)) - else: + elif len(ret_list) == 1: return f'return {ret_list[0]};' + elif len(ret_list) == 0: + return 'return;' - def _gen_one_impl(self, op_info, op_name): + def _gen_one_impl(self, op_info, op_name, is_mutable_attr): in_combine, in_combine_op_list = self._gen_in_combine(op_info) compute_op, op_inst_name = self._gen_compute_op( - op_info, op_name, in_combine_op_list + op_info, op_name, in_combine_op_list, is_mutable_attr ) - out_slice, ret_list = self._gen_out_slice_and_ret_list( + out_split, ret_list = self._gen_out_split_and_ret_list( op_info, op_inst_name ) ret = API_IMPL_TEMPLATE.format( ret_type=self._gen_ret_type(op_info), api_name=op_name, - args=self._gen_api_args(op_info, False), + args=self._gen_api_args(op_info, False, is_mutable_attr), in_combine=in_combine, compute_op=compute_op, - out_slice=out_slice, + out_split=out_split, return_result=self._gen_return_result(ret_list), ) @@ -293,9 +300,16 @@ class CodeGen: impl_str = '' for op_info in op_info_items: for op_name in op_info.op_phi_name: - if op_name not in API_LIST: + # NOTE:When infer_meta_func is None, the Build() function generated in pd_op + # is wrong, so temporarily skip the automatic generation of these APIs + if ( + op_info.infer_meta_func is None + and op_name not in PD_MANUAL_OP_LIST + ): continue - impl_str += self._gen_one_impl(op_info, op_name) + impl_str += self._gen_one_impl(op_info, op_name, False) + if len(op_info.mutable_attribute_name_list) > 0: + impl_str += self._gen_one_impl(op_info, op_name, True) body = impl_str for namespace in reversed(namespaces): body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body) diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc index 985e896d6e073e6c78093946b72b548b693adaba..8866922e4aa34ee74d3f49389a82f31a41a4d9de 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.cc @@ -18,18 +18,5 @@ #include "paddle/ir/core/builtin_op.h" namespace paddle { -namespace dialect { -std::vector concat_grad(std::vector x, - ir::OpResult out_grad, - ir::OpResult axis) { - auto combine_op = - APIBuilder::Instance().GetBuilder()->Build(x); - paddle::dialect::ConcatGradOp concat_grad_op = - APIBuilder::Instance().GetBuilder()->Build( - combine_op.out(), out_grad, axis); - auto split_op = APIBuilder::Instance().GetBuilder()->Build( - concat_grad_op.result(0)); - return split_op.outputs(); -} -} // namespace dialect +namespace dialect {} // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h index dff38ef565cb2d8c479970fb61264baaa67510fc..de86758dddba8efe7b353f1679ab4b902c981d90 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_api.h @@ -21,10 +21,5 @@ #include "paddle/phi/common/place.h" namespace paddle { -namespace dialect { - -std::vector concat_grad(std::vector x, - ir::OpResult out_grad, - ir::OpResult axis); -} // namespace dialect +namespace dialect {} // namespace dialect } // namespace paddle