diff --git a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..76c48bdde5e1fe8eb490e33df375ab9a664b2618 --- /dev/null +++ b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py @@ -0,0 +1,613 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# generator build function +OP_BUILD_TEMPLATE = """ +void {op_name}::Build({build_args}) {{ +{get_attributes} +{build_mutable_attributes} +{build_inputs} +{build_attributes} +{build_outputs} +}} +""" + + +def GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + for_func_define=True, + mutable_attr_is_input=False, + attr_args_is_map=False, +): + ''' + Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} + ''' + # add inputs + build_args_str = "ir::Builder &builder, ir::OperationArgument &argument" + if len(op_input_name_list) > 0: + for input_name in op_input_name_list: + build_args_str += ", ir::OpResult " + input_name + "_" + + if attr_args_is_map: + build_args_str += ", ir::AttributeMap attributes" + else: + if not mutable_attr_is_input: + # add attributes + for attr_idx in range(len(op_attribute_name_list)): + build_args_str += ( + ", " + + op_attribute_build_arg_type_list[attr_idx] + + " " + + op_attribute_name_list[attr_idx] + ) + if for_func_define: + if op_attribute_default_value_list[attr_idx] is not None: + default_value = op_attribute_default_value_list[ + attr_idx + ] + if ( + op_attribute_build_arg_type_list[attr_idx] + != "std::string" + ): + if ( + default_value[0] == "'" + or default_value[0] == '"' + ): + default_value = default_value[1:] + if ( + default_value[-1] == "'" + or default_value[-1] == '"' + ): + default_value = default_value[0:-1] + build_args_str += "=" + default_value + else: + # add mutable attributes as inputs + if len(op_mutable_attribute_name_list) > 0: + for mutable_attr in op_mutable_attribute_name_list: + build_args_str += ", ir::OpResult " + mutable_attr + "_" + + # add non-mutable attributes + for attr_idx in range(len(op_non_mutable_attribute_name_list)): + build_args_str += ( + ", " + + op_non_mutable_attribute_build_arg_type_list[attr_idx] + + " " + + op_non_mutable_attribute_name_list[attr_idx] + ) + if for_func_define: + if ( + op_non_mutable_attribute_default_value_list[attr_idx] + is not None + ): + default_value = ( + op_non_mutable_attribute_default_value_list[ + attr_idx + ] + ) + if ( + op_non_mutable_attribute_build_arg_type_list[ + attr_idx + ] + != "std::string" + ): + if ( + default_value[0] == "'" + or default_value[0] == '"' + ): + default_value = default_value[1:] + if ( + default_value[-1] == "'" + or default_value[-1] == '"' + ): + default_value = default_value[0:-1] + build_args_str += "=" + default_value + + return build_args_str + + +mutable_attribute_phi_type_maps = { + 'int': 'phi::DataType::INT32', + 'int64_t': 'phi::DataType::INT64', + 'float': 'phi::DataType::FLOAT32', + 'std::vector': 'phi::DataType::INT64', + 'const std::vector&': 'phi::DataType::INT64', +} + + +def GenBuildInserFullForMutableAttribute( + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, +): + build_mutable_attribute = "" + BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name} + paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build({attr_name}, {phi_dtype}, phi::CPUPlace()); + ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); + """ + BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name} + paddle::dialect::FullOp full_{attr_name}_op = builder.Build(std::vector{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace()); + ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); + """ + for idx in range(len(op_mutable_attribute_name_list)): + attr_name = op_mutable_attribute_name_list[idx] + attr_type = op_mutable_attribute_type_list[idx][0] + if attr_name in op_attribute_name_list: + phi_dtype = mutable_attribute_phi_type_maps[ + op_attribute_build_arg_type_list[ + op_attribute_name_list.index(attr_name) + ] + ] + else: + phi_dtype = mutable_attribute_phi_type_maps[ + op_mutable_attribute_type_list[idx][1] + ] + if attr_type == "paddle::dialect::IntArrayAttribute": + build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format( + attr_name=attr_name, phi_dtype=phi_dtype + ) + else: + build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE.format( + attr_name=attr_name, phi_dtype=phi_dtype + ) + return build_mutable_attribute + + +def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list): + BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; + argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); +""" + build_input_str = ' VLOG(4) << "Builder construction inputs";\n' + input_name_list = op_input_name_list + op_mutable_attribute_name_list + if len(input_name_list) > 0: + inputs_args_str = "" + inputs_args_str += "_, ".join(input_name_list) + "_" + build_input_str += BUILD_INPUT_TEMPLATE.format( + inputs_args=inputs_args_str + ) + return build_input_str + + +def GenBuildAttributes( + op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list +): + INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr})); +""" + SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = TransToIrAttribute({attr}, ir::IrContext::Instance()); +""" + STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr}); +""" + ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector vec_{attr_name}; + for (size_t i = 0; i < static_cast({attr_size}); i++) {{ + {create_attribute} + vec_{attr_name}.push_back(attr_{attr_name}); + }} + ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name}); +""" + attr_str = ' VLOG(4) << "Builder construction attributes";\n' + for idx in range(len(op_non_mutable_attribute_name_list)): + if "ir::ArrayAttribute<" in op_non_mutable_attribute_type_list[idx]: + inner_attribute_type = op_non_mutable_attribute_type_list[idx][ + 19:-1 + ] + if inner_attribute_type == "paddle::dialect::IntArrayAttribute": + attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + attr_size=op_non_mutable_attribute_name_list[idx] + + ".size()", + create_attribute=INTARRAY_STR_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + op_attribute_type=inner_attribute_type, + attr=op_non_mutable_attribute_name_list[idx] + "[i]", + ), + ) + elif inner_attribute_type == "paddle::dialect::ScalarAttribute": + attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + attr_size=op_non_mutable_attribute_name_list[idx] + + ".size()", + create_attribute=SCALAR_STR_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + attr=op_non_mutable_attribute_name_list[idx] + "[i]", + ), + ) + else: + attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + attr_size=op_non_mutable_attribute_name_list[idx] + + ".size()", + create_attribute=STR_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + op_attribute_type=inner_attribute_type, + attr=op_non_mutable_attribute_name_list[idx] + "[i]", + ), + ) + elif ( + op_non_mutable_attribute_type_list[idx] + == "paddle::dialect::IntArrayAttribute" + ): + attr_str += INTARRAY_STR_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + op_attribute_type=op_non_mutable_attribute_type_list[idx], + attr=op_non_mutable_attribute_name_list[idx], + ) + + elif ( + op_non_mutable_attribute_type_list[idx] + == "paddle::dialect::ScalarAttribute" + ): + attr_str += SCALAR_STR_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + attr=op_non_mutable_attribute_name_list[idx], + ) + else: + attr_str += STR_TEMPLATE.format( + attr_name=op_non_mutable_attribute_name_list[idx], + op_attribute_type=op_non_mutable_attribute_type_list[idx], + attr=op_non_mutable_attribute_name_list[idx], + ) + attr_str += """ argument.AddAttribute("{attr_name}", attr_{attr_name});\n""".format( + attr_name=op_non_mutable_attribute_name_list[idx] + ) + + return attr_str + + +def GenBuildOutputs( + op_input_name_list, + op_input_type_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + mutable_attr_is_input=False, +): + build_output_str = ' VLOG(4) << "Builder construction outputs";\n' + CREATE_INPUT_METATENSOR_TEMPLATE = """ + VLOG(4) << "Builder construction dense_{name}"; + phi::DenseTensor dense_{name}(std::make_unique(paddle::platform::CPUPlace()).get(), + phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()), + {name}.dims(), + {name}.data_layout(), + {name}.lod(), + {name}.offset())); + VLOG(4) << "Builder construction meta_{name}"; + phi::MetaTensor meta_{name}(&dense_{name}); +""" + CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}; + for (size_t i=0; i < static_cast({name}.size()); i++) {{ + vec_dense_{name}.push_back(phi::DenseTensor(std::make_unique(paddle::platform::CPUPlace()).get(), + phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast().dtype()), + {name}[i].dyn_cast().dims(), + {name}[i].dyn_cast().data_layout(), + {name}[i].dyn_cast().lod(), + {name}[i].dyn_cast().offset()))); + }} + std::vector vec_meta_{name}; + for (size_t i=0; i < vec_dense_{name}.size(); i++) {{ + vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); + }} + + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} + """ + + CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().GetData(); (void){name};\n""" + CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().to<{dtype}>(); (void){name};\n""" + + CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; + phi::MetaTensor meta_{name}(&dense_{name}); +""" + CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}(({output_size}), phi::DenseTensor()); + std::vector vec_meta_{name}; + for (size_t i=0; i < static_cast({output_size}); i++) {{ + vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); + }} + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} +""" + # Prepar input type + for idx in range(len(op_input_name_list)): + # is a vector + if 'ir::VectorType' in op_input_type_list[idx]: + build_output_str += " ir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + name=op_input_name_list[idx] + ) + # is a Tensor + else: + build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + name=op_input_name_list[idx] + ) + + # Prepare mutable attributes + if mutable_attr_is_input: + for idx in range(len(op_mutable_attribute_name_list)): + attr_dtype = op_mutable_attribute_type_list[idx] + # int_array + if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": + build_output_str += ( + CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx] + ) + ) + # scalar + elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": + build_output_str += ( + CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx], + dtype=attr_dtype[1], + ) + ) + # string + elif attr_dtype[0] == "ir::StrAttribute": + build_output_str += "" + else: + assert "mutable attribtue type is not right." + build_output_str += "\n" + + # Prepare inputs_meta_tensor & attributes for infer meta + infer_meta_args = [] + for idx in range(len(op_infer_meta_map['param'])): + # is input + if op_infer_meta_map['param'][idx] in op_input_name_list: + if ( + "meta_" + op_infer_meta_map['param'][idx] + ) not in infer_meta_args: + # is a vector + if ( + 'ir::VectorType' + in op_input_type_list[ + op_input_name_list.index( + op_infer_meta_map['param'][idx] + ) + ] + ): + build_output_str += ( + CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + ) + # is a Tensor + else: + build_output_str += CREATE_INPUT_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + + infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx]) + # is attribute + else: + infer_meta_args.append(op_infer_meta_map['param'][idx]) + + # Prepare outputs_meta_tensor for infer meta + for idx in range(len(op_output_name_list)): + # is a vector + if 'ir::VectorType' in op_output_type_list[idx]: + build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_output_name_list[idx], + output_size=op_output_size_list[idx], + ) + infer_meta_args.append(f"meta_{op_output_name_list[idx]}") + # is a Tensor + else: + build_output_str += CREATE_OUTPUT_METATENSOR_TEMPLATE.format( + name=op_output_name_list[idx] + ) + infer_meta_args.append(f"&meta_{op_output_name_list[idx]}") + + # Execute infer meta function + CREATE_INFER_META_FUNC_TEMPLATE = """ + phi::{func}({args}); +""" + build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format( + func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) + ) + + # use dense_{name} or vec_dense_{name} to create Outputs type + build_output_str += "\n std::vector argument_outputs;" + + CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ + ir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); + argument_outputs.push_back({name}_dense_tensor_type); +""" + CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ + std::vector {name}_types; + for (size_t i=0; i < static_cast({output_size}); i++) {{ + {name}_types.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); + }} + ir::Type {name}_vector_type = ir::VectorType::get(ir::IrContext::Instance(), {name}_types); + argument_outputs.push_back({name}_vector_type); +""" + for idx in range(len(op_output_name_list)): + # is a vector + if 'ir::VectorType' in op_output_type_list[idx]: + build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format( + name=op_output_name_list[idx], + output_size=op_output_size_list[idx], + ) + # is a Tensor + else: + build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( + name=op_output_name_list[idx] + ) + + build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n" + + return build_output_str + + +def gen_build_func_str( + op_class_name, + op_input_name_list, + op_input_type_list, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + muta_attr_is_input=False, + attr_args_is_map=False, +): + build_args_for_declare = "" + build_func = "" + + build_args_for_declare = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + True, + muta_attr_is_input, + attr_args_is_map, + ) + + build_args_for_define = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + False, + muta_attr_is_input, + attr_args_is_map, + ) + inset_full_for_mutable_attributes_str = "" + if not muta_attr_is_input: + inset_full_for_mutable_attributes_str = ( + GenBuildInserFullForMutableAttribute( + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + ) + ) + + build_inputs_str = GenBuildInputs( + op_input_name_list, op_mutable_attribute_name_list + ) + build_attributes_str = GenBuildAttributes( + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + ) + build_outputs_str = GenBuildOutputs( + op_input_name_list, + op_input_type_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + muta_attr_is_input, + ) + + GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); +""" + GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + {attr_type} {attribute_name}; + for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ + {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast()[i].dyn_cast<{inner_type}>().data()); + }} +""" + GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().GetData(); +""" + GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """ + {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().to<{attr_type}>(); +""" + + get_attributes_str = "" + if attr_args_is_map: + for idx in range(len(op_attribute_name_list)): + attr_type = op_attribute_build_arg_type_list[idx] + attr_type = attr_type.replace("const ", "") + attr_type = attr_type.replace("&", "") + # if op_attribute_build_arg_type_list[idx] == "const std::vector&": + # attr_type = "std::vector" + if "ir::ArrayAttribute" in op_attribute_type_list[idx]: + inner_type = op_attribute_type_list[idx][19:-1] + get_attributes_str += ( + GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + attr_type=attr_type, + attribute_name=op_attribute_name_list[idx], + inner_type=inner_type, + ) + ) + elif ( + "paddle::dialect::IntArrayAttribute" + in op_attribute_type_list[idx] + ): + get_attributes_str += ( + GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + attr_type=attr_type, + attribute_name=op_attribute_name_list[idx], + ) + ) + elif ( + "paddle::dialect::ScalarAttribute" + in op_attribute_type_list[idx] + ): + get_attributes_str += ( + GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE.format( + attr_type=attr_type, + attribute_name=op_attribute_name_list[idx], + ) + ) + else: + get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format( + attr_type=attr_type, + attribute_name=op_attribute_name_list[idx], + attr_ir_type=op_attribute_type_list[idx], + ) + + build_func = OP_BUILD_TEMPLATE.format( + op_name=op_class_name, + build_args=build_args_for_define, + build_mutable_attributes=inset_full_for_mutable_attributes_str, + get_attributes=get_attributes_str, + build_inputs=build_inputs_str, + build_attributes=build_attributes_str, + build_outputs=build_outputs_str, + ) + + return (build_args_for_declare, build_func) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 29a96b9c386b0def8e17a536e08e51345f3bed1f..d07331e76f3af3dfb919b3924f86a143da1ce194 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -16,6 +16,7 @@ import argparse import os import yaml +from op_build_gen import gen_build_func_str from op_interface_gen import gen_exclusive_interface_str, gen_op_infer_meta_str from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str @@ -68,6 +69,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static OpInfoTuple GetOpInfo(); static void Build({build_args}); {build_mutable_attr_is_input} + {build_attr_num_over_1} void Verify(); {get_inputs_and_outputs} {exclusive_interface} @@ -130,15 +132,6 @@ CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( """OpAttributeInfo("{name}", "{typename}", "{data_type}")""" ) -# build -OP_BUILD_TEMPLATE = """ -void {op_name}::Build({build_args}) {{ -{build_mutable_attributes} -{build_inputs} -{build_attributes} -{build_outputs} -}} -""" DEFINE_OP_TYPE_ID = """ IR_DEFINE_EXPLICIT_TYPE_ID({op_name}) @@ -626,511 +619,6 @@ def to_pascal_case(s): return "".join([word.capitalize() for word in words]) + "" -# ===================================== -# Generate Op Definition Files -# ===================================== -def GenBuildInputArgsStr( - op_input_name_list, - op_attribute_name_list, - op_attribute_build_arg_type_list, - op_attribute_default_value_list, - op_mutable_attribute_name_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_build_arg_type_list, - op_non_mutable_attribute_default_value_list, - for_func_define=True, - mutable_attr_is_input=False, -): - ''' - Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} - ''' - # add inputs - build_args_str = "ir::Builder &builder, ir::OperationArgument &argument" - if len(op_input_name_list) > 0: - for input_name in op_input_name_list: - build_args_str += ", ir::OpResult " + input_name + "_" - - if not mutable_attr_is_input: - # add attributes - for attr_idx in range(len(op_attribute_name_list)): - build_args_str += ( - ", " - + op_attribute_build_arg_type_list[attr_idx] - + " " - + op_attribute_name_list[attr_idx] - ) - if for_func_define: - if op_attribute_default_value_list[attr_idx] is not None: - default_value = op_attribute_default_value_list[attr_idx] - if ( - op_attribute_build_arg_type_list[attr_idx] - != "std::string" - ): - if default_value[0] == "'" or default_value[0] == '"': - default_value = default_value[1:] - if default_value[-1] == "'" or default_value[-1] == '"': - default_value = default_value[0:-1] - build_args_str += "=" + default_value - else: - # add mutable attributes as inputs - if len(op_mutable_attribute_name_list) > 0: - for mutable_attr in op_mutable_attribute_name_list: - build_args_str += ", ir::OpResult " + mutable_attr + "_" - - # add non-mutable attributes - for attr_idx in range(len(op_non_mutable_attribute_name_list)): - build_args_str += ( - ", " - + op_non_mutable_attribute_build_arg_type_list[attr_idx] - + " " - + op_non_mutable_attribute_name_list[attr_idx] - ) - if for_func_define: - if ( - op_non_mutable_attribute_default_value_list[attr_idx] - is not None - ): - default_value = op_non_mutable_attribute_default_value_list[ - attr_idx - ] - if ( - op_non_mutable_attribute_build_arg_type_list[attr_idx] - != "std::string" - ): - if default_value[0] == "'" or default_value[0] == '"': - default_value = default_value[1:] - if default_value[-1] == "'" or default_value[-1] == '"': - default_value = default_value[0:-1] - build_args_str += "=" + default_value - - return build_args_str - - -mutable_attribute_phi_type_maps = { - 'int': 'phi::DataType::INT32', - 'int64_t': 'phi::DataType::INT64', - 'float': 'phi::DataType::FLOAT32', - 'std::vector': 'phi::DataType::INT64', - 'const std::vector&': 'phi::DataType::INT64', -} - - -def GenBuildInserFullForMutableAttribute( - op_attribute_name_list, - op_attribute_build_arg_type_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, -): - build_mutable_attribute = "" - BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """ // Generate int_array mutable attribute: {attr_name} - paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build({attr_name}, {phi_dtype}, phi::CPUPlace()); - ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); - """ - BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """ // Generate scalar mutable attribute: {attr_name} - paddle::dialect::FullOp full_{attr_name}_op = builder.Build(std::vector{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace()); - ir::OpResult {attr_name}_ = full_{attr_name}_op->result(0); - """ - for idx in range(len(op_mutable_attribute_name_list)): - attr_name = op_mutable_attribute_name_list[idx] - attr_type = op_mutable_attribute_type_list[idx][0] - if attr_name in op_attribute_name_list: - phi_dtype = mutable_attribute_phi_type_maps[ - op_attribute_build_arg_type_list[ - op_attribute_name_list.index(attr_name) - ] - ] - else: - phi_dtype = mutable_attribute_phi_type_maps[ - op_mutable_attribute_type_list[idx][1] - ] - if attr_type == "paddle::dialect::IntArrayAttribute": - build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format( - attr_name=attr_name, phi_dtype=phi_dtype - ) - else: - build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE.format( - attr_name=attr_name, phi_dtype=phi_dtype - ) - return build_mutable_attribute - - -def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list): - BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; - argument.AddOperands(argument_inputs.begin(), argument_inputs.end()); -""" - build_input_str = ' VLOG(4) << "Builder construction inputs";\n' - input_name_list = op_input_name_list + op_mutable_attribute_name_list - if len(input_name_list) > 0: - inputs_args_str = "" - inputs_args_str += "_, ".join(input_name_list) + "_" - build_input_str += BUILD_INPUT_TEMPLATE.format( - inputs_args=inputs_args_str - ) - return build_input_str - - -def GenBuildAttributes( - op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list -): - INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr})); -""" - SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = TransToIrAttribute({attr}, ir::IrContext::Instance()); -""" - STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr}); -""" - ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector vec_{attr_name}; - for (size_t i = 0; i < static_cast({attr_size}); i++) {{ - {create_attribute} - vec_{attr_name}.push_back(attr_{attr_name}); - }} - ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name}); -""" - attr_str = ' VLOG(4) << "Builder construction attributes";\n' - for idx in range(len(op_non_mutable_attribute_name_list)): - if "ir::ArrayAttribute<" in op_non_mutable_attribute_type_list[idx]: - inner_attribute_type = op_non_mutable_attribute_type_list[idx][ - 19:-1 - ] - if inner_attribute_type == "paddle::dialect::IntArrayAttribute": - attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - attr_size=op_non_mutable_attribute_name_list[idx] - + ".size()", - create_attribute=INTARRAY_STR_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - op_attribute_type=inner_attribute_type, - attr=op_non_mutable_attribute_name_list[idx] + "[i]", - ), - ) - elif inner_attribute_type == "paddle::dialect::ScalarAttribute": - attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - attr_size=op_non_mutable_attribute_name_list[idx] - + ".size()", - create_attribute=SCALAR_STR_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - attr=op_non_mutable_attribute_name_list[idx] + "[i]", - ), - ) - else: - attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - attr_size=op_non_mutable_attribute_name_list[idx] - + ".size()", - create_attribute=STR_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - op_attribute_type=inner_attribute_type, - attr=op_non_mutable_attribute_name_list[idx] + "[i]", - ), - ) - elif ( - op_non_mutable_attribute_type_list[idx] - == "paddle::dialect::IntArrayAttribute" - ): - attr_str += INTARRAY_STR_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - op_attribute_type=op_non_mutable_attribute_type_list[idx], - attr=op_non_mutable_attribute_name_list[idx], - ) - - elif ( - op_non_mutable_attribute_type_list[idx] - == "paddle::dialect::ScalarAttribute" - ): - attr_str += SCALAR_STR_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - attr=op_non_mutable_attribute_name_list[idx], - ) - else: - attr_str += STR_TEMPLATE.format( - attr_name=op_non_mutable_attribute_name_list[idx], - op_attribute_type=op_non_mutable_attribute_type_list[idx], - attr=op_non_mutable_attribute_name_list[idx], - ) - attr_str += """ argument.AddAttribute("{attr_name}", attr_{attr_name});\n""".format( - attr_name=op_non_mutable_attribute_name_list[idx] - ) - - return attr_str - - -def GenBuildOutputs( - op_input_name_list, - op_input_type_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_infer_meta_map, - mutable_attr_is_input=False, -): - build_output_str = ' VLOG(4) << "Builder construction outputs";\n' - CREATE_INPUT_METATENSOR_TEMPLATE = """ - VLOG(4) << "Builder construction dense_{name}"; - phi::DenseTensor dense_{name}(std::make_unique(paddle::platform::CPUPlace()).get(), - phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()), - {name}.dims(), - {name}.data_layout(), - {name}.lod(), - {name}.offset())); - VLOG(4) << "Builder construction meta_{name}"; - phi::MetaTensor meta_{name}(&dense_{name}); -""" - CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}; - for (size_t i=0; i < static_cast({name}.size()); i++) {{ - vec_dense_{name}.push_back(phi::DenseTensor(std::make_unique(paddle::platform::CPUPlace()).get(), - phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast().dtype()), - {name}[i].dyn_cast().dims(), - {name}[i].dyn_cast().data_layout(), - {name}[i].dyn_cast().lod(), - {name}[i].dyn_cast().offset()))); - }} - std::vector vec_meta_{name}; - for (size_t i=0; i < vec_dense_{name}.size(); i++) {{ - vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); - }} - - std::vector meta_{name}; - for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ - meta_{name}.push_back(&vec_meta_{name}[i]); - }} - """ - - CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().GetData(); (void){name};\n""" - CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().to<{dtype}>(); (void){name};\n""" - - CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; - phi::MetaTensor meta_{name}(&dense_{name}); -""" - CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}(({output_size}), phi::DenseTensor()); - std::vector vec_meta_{name}; - for (size_t i=0; i < static_cast({output_size}); i++) {{ - vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); - }} - std::vector meta_{name}; - for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ - meta_{name}.push_back(&vec_meta_{name}[i]); - }} -""" - # Prepar input type - for idx in range(len(op_input_name_list)): - # is a vector - if 'ir::VectorType' in op_input_type_list[idx]: - build_output_str += " ir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( - name=op_input_name_list[idx] - ) - # is a Tensor - else: - build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( - name=op_input_name_list[idx] - ) - - # Prepare mutable attributes - if mutable_attr_is_input: - for idx in range(len(op_mutable_attribute_name_list)): - attr_dtype = op_mutable_attribute_type_list[idx] - # int_array - if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": - build_output_str += ( - CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format( - name=op_mutable_attribute_name_list[idx] - ) - ) - # scalar - elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": - build_output_str += ( - CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format( - name=op_mutable_attribute_name_list[idx], - dtype=attr_dtype[1], - ) - ) - # string - elif attr_dtype[0] == "ir::StrAttribute": - build_output_str += "" - else: - assert "mutable attribtue type is not right." - build_output_str += "\n" - - # Prepare inputs_meta_tensor & attributes for infer meta - infer_meta_args = [] - for idx in range(len(op_infer_meta_map['param'])): - # is input - if op_infer_meta_map['param'][idx] in op_input_name_list: - if ( - "meta_" + op_infer_meta_map['param'][idx] - ) not in infer_meta_args: - # is a vector - if ( - 'ir::VectorType' - in op_input_type_list[ - op_input_name_list.index( - op_infer_meta_map['param'][idx] - ) - ] - ): - build_output_str += ( - CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format( - name=op_infer_meta_map['param'][idx] - ) - ) - # is a Tensor - else: - build_output_str += CREATE_INPUT_METATENSOR_TEMPLATE.format( - name=op_infer_meta_map['param'][idx] - ) - - infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx]) - # is attribute - else: - infer_meta_args.append(op_infer_meta_map['param'][idx]) - - # Prepare outputs_meta_tensor for infer meta - for idx in range(len(op_output_name_list)): - # is a vector - if 'ir::VectorType' in op_output_type_list[idx]: - build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format( - name=op_output_name_list[idx], - output_size=op_output_size_list[idx], - ) - infer_meta_args.append(f"meta_{op_output_name_list[idx]}") - # is a Tensor - else: - build_output_str += CREATE_OUTPUT_METATENSOR_TEMPLATE.format( - name=op_output_name_list[idx] - ) - infer_meta_args.append(f"&meta_{op_output_name_list[idx]}") - - # Execute infer meta function - CREATE_INFER_META_FUNC_TEMPLATE = """ - phi::{func}({args}); -""" - build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format( - func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) - ) - - # use dense_{name} or vec_dense_{name} to create Outputs type - build_output_str += "\n std::vector argument_outputs;" - - CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ - ir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); - argument_outputs.push_back({name}_dense_tensor_type); -""" - CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ - std::vector {name}_types; - for (size_t i=0; i < static_cast({output_size}); i++) {{ - {name}_types.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); - }} - ir::Type {name}_vector_type = ir::VectorType::get(ir::IrContext::Instance(), {name}_types); - argument_outputs.push_back({name}_vector_type); -""" - for idx in range(len(op_output_name_list)): - # is a vector - if 'ir::VectorType' in op_output_type_list[idx]: - build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format( - name=op_output_name_list[idx], - output_size=op_output_size_list[idx], - ) - # is a Tensor - else: - build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( - name=op_output_name_list[idx] - ) - - build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n" - - return build_output_str - - -def GenBuild( - op_class_name, - op_input_name_list, - op_input_type_list, - op_attribute_name_list, - op_attribute_build_arg_type_list, - op_attribute_default_value_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, - op_non_mutable_attribute_default_value_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_infer_meta_map, - muta_attr_is_input=False, -): - build_args_for_declare = "" - build_func = "" - - build_args_for_declare = GenBuildInputArgsStr( - op_input_name_list, - op_attribute_name_list, - op_attribute_build_arg_type_list, - op_attribute_default_value_list, - op_mutable_attribute_name_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_build_arg_type_list, - op_non_mutable_attribute_default_value_list, - True, - muta_attr_is_input, - ) - - build_args_for_define = GenBuildInputArgsStr( - op_input_name_list, - op_attribute_name_list, - op_attribute_build_arg_type_list, - op_attribute_default_value_list, - op_mutable_attribute_name_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_build_arg_type_list, - op_non_mutable_attribute_default_value_list, - False, - muta_attr_is_input, - ) - inset_full_for_mutable_attributes_str = "" - if not muta_attr_is_input: - inset_full_for_mutable_attributes_str = ( - GenBuildInserFullForMutableAttribute( - op_attribute_name_list, - op_attribute_build_arg_type_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - ) - ) - - build_inputs_str = GenBuildInputs( - op_input_name_list, op_mutable_attribute_name_list - ) - build_attributes_str = GenBuildAttributes( - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - ) - build_outputs_str = GenBuildOutputs( - op_input_name_list, - op_input_type_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_infer_meta_map, - muta_attr_is_input, - ) - - build_func = OP_BUILD_TEMPLATE.format( - op_name=op_class_name, - build_args=build_args_for_define, - build_mutable_attributes=inset_full_for_mutable_attributes_str, - build_inputs=build_inputs_str, - build_attributes=build_attributes_str, - build_outputs=build_outputs_str, - ) - - return (build_args_for_declare, build_func) - - def OpGenerator( op_yaml_files, op_compat_yaml_file, @@ -1243,17 +731,20 @@ def OpGenerator( build_args_with_muta_attr_not_input_for_declare = "" build_func_with_muta_attr_not_input = "" build_mutable_attr_is_input = "" + build_attr_num_over_1 = "" + build_func_with_attr_is_map = "" build_func_with_muta_attr_is_input = "" if op_infer_meta_map is not None: ( build_args_with_muta_attr_not_input_for_declare, build_func_with_muta_attr_not_input, - ) = GenBuild( + ) = gen_build_func_str( op_class_name, op_input_name_list, op_input_type_list, op_attribute_name_list, + op_attribute_type_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, @@ -1268,15 +759,47 @@ def OpGenerator( op_infer_meta_map, muta_attr_is_input=False, ) + if len(op_attribute_name_list) > 1: + ( + build_args_with_attr_is_map_for_declare, + build_func_with_attr_is_map, + ) = gen_build_func_str( + op_class_name, + op_input_name_list, + op_input_type_list, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + muta_attr_is_input=False, + attr_args_is_map=True, + ) + build_attr_num_over_1 = ( + "static void Build({build_args});".format( + build_args=build_args_with_attr_is_map_for_declare + ) + ) + if len(op_mutable_attribute_name_list) > 0: ( build_args_with_muta_attr_is_input_for_declare, build_func_with_muta_attr_is_input, - ) = GenBuild( + ) = gen_build_func_str( op_class_name, op_input_name_list, op_input_type_list, op_attribute_name_list, + op_attribute_type_list, op_attribute_build_arg_type_list, op_attribute_default_value_list, op_mutable_attribute_name_list, @@ -1307,6 +830,7 @@ def OpGenerator( attribute_num=0, build_args=build_args_with_muta_attr_not_input_for_declare, build_mutable_attr_is_input=build_mutable_attr_is_input, + build_attr_num_over_1=build_attr_num_over_1, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) @@ -1323,6 +847,7 @@ def OpGenerator( attribute_num=len(op_non_mutable_attribute_name_list), build_args=build_args_with_muta_attr_not_input_for_declare, build_mutable_attr_is_input=build_mutable_attr_is_input, + build_attr_num_over_1=build_attr_num_over_1, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) @@ -1457,6 +982,7 @@ def OpGenerator( ops_defined_list.append(op_defined_str) ops_defined_list.append(op_info_func_str) ops_defined_list.append(build_func_with_muta_attr_not_input) + ops_defined_list.append(build_func_with_attr_is_map) if len(op_mutable_attribute_name_list) > 0: ops_defined_list.append(build_func_with_muta_attr_is_input) ops_defined_list.append(op_verify_str) diff --git a/test/cpp/ir/core/ir_exe_test.cc b/test/cpp/ir/core/ir_exe_test.cc index 0c23e49e805ec1e7d39fa5ca59853ba38cdd083d..0b3f956cd47c9f386f0a888b00e80d36bc182e82 100644 --- a/test/cpp/ir/core/ir_exe_test.cc +++ b/test/cpp/ir/core/ir_exe_test.cc @@ -64,13 +64,27 @@ TEST(program_test, program) { // Def: A = paddle::dialect::UniformOp(std::vector shape, // phi::DataType dtype, float min, float max, int seed, phi::Place place) + ir::AttributeMap uniform1_attributes; + uniform1_attributes.insert({"shape", + paddle::dialect::IntArrayAttribute::get( + ir::IrContext::Instance(), + phi::IntArray(std::vector{2, 2}))}); + uniform1_attributes.insert( + {"dtype", + paddle::dialect::DataTypeAttribute::get(ir::IrContext::Instance(), + phi::DataType::FLOAT32)}); + uniform1_attributes.insert( + {"min", ir::FloatAttribute::get(ir::IrContext::Instance(), 0.0)}); + uniform1_attributes.insert( + {"max", ir::FloatAttribute::get(ir::IrContext::Instance(), 1.0)}); + uniform1_attributes.insert( + {"seed", ir::Int32Attribute::get(ir::IrContext::Instance(), 2)}); + uniform1_attributes.insert({"place", + paddle::dialect::PlaceAttribute::get( + ir::IrContext::Instance(), phi::CPUPlace())}); paddle::dialect::UniformOp uniform1 = - builder.Build(std::vector{2, 2}, - phi::DataType::FLOAT32, - 0.0, - 1.0, - 2, - phi::CPUPlace()); + builder.Build(uniform1_attributes); + EXPECT_EQ(uniform1->result(0).type().isa(), true); EXPECT_EQ(block->size(), 4u);