From 3a452e4e1cf3d64a9edbee99fced8aa867d8d2b0 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Fri, 9 Jun 2023 10:22:54 +0800 Subject: [PATCH] [IR] Refine IR builder and throw methods (#54396) * refine code * refine code * refine code * refine code * refine code * refine code * refine code * fix bug * refine code * refine code * refine code * refine code * refine code * delete unused code * delete unused code * refine code --- paddle/fluid/ir/dialect/op_gen.py | 605 ++++++++++++++++-------- paddle/fluid/ir/dialect/pd_attribute.cc | 18 + paddle/fluid/ir/dialect/pd_attribute.h | 4 + paddle/fluid/ir/dialect/utils.h | 7 +- paddle/ir/core/builder.h | 2 +- paddle/ir/core/builtin_op.cc | 159 +++---- paddle/ir/core/builtin_op.h | 4 +- paddle/ir/core/operation.cc | 9 +- paddle/ir/core/storage_manager.cc | 8 +- paddle/phi/api/yaml/legacy_ops.yaml | 12 + paddle/phi/api/yaml/op_compat.yaml | 10 +- paddle/phi/infermeta/nullary.cc | 9 + paddle/phi/infermeta/nullary.h | 4 + paddle/phi/kernels/cpu/full_kernel.cc | 15 + paddle/phi/kernels/full_kernel.h | 6 + test/cpp/ir/core/ir_exe_test.cc | 99 ++-- test/cpp/ir/core/ir_op_test.cc | 9 +- test/cpp/ir/core/phi_kernel_adaptor.h | 52 +- 18 files changed, 655 insertions(+), 377 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py index 106229e141e..ab8354f1fb8 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_gen.py @@ -57,6 +57,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static constexpr uint32_t attributes_num = {attribute_num}; static OpInfoTuple GetOpInfo(); static void Build({build_args}); + {build_mutable_attr_is_input} static void Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); {get_inputs_and_outputs} {exclusive_interface} @@ -94,6 +95,7 @@ CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_g #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/ternary.h" #include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/api/lib/utils/allocator.h" {input} """ @@ -112,9 +114,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ return std::make_tuple(inputs, attributes, outputs, run_time_info); }} """ -CONSTRUCT_INPUT_INFO_TEMPLATE = ( - """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})""" -) +CONSTRUCT_INPUT_INFO_TEMPLATE = """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})""" CONSTRUCT_OUTPUT_INFO_TEMPLATE = ( """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" ) @@ -125,6 +125,7 @@ CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( # build OP_BUILD_TEMPLATE = """ void {op_name}::Build({build_args}) {{ +{build_mutable_attributes} {build_inputs} {build_attributes} {build_outputs} @@ -303,6 +304,7 @@ class OpInfoParser: self.output_type_list, self.output_optional_list, ) + # parse attributes self.attr_types_map = { 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], @@ -313,7 +315,7 @@ class OpInfoParser: 'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'], 'Scalar[]': [ 'ir::ArrayAttribute', - 'std::vector', + 'const std::vector&', ], 'int': ['ir::Int32_tAttribute', 'int'], 'int32_t': ['ir::Int32_tAttribute', 'int32_t'], @@ -323,18 +325,18 @@ class OpInfoParser: 'float': ['ir::FloatAttribute', 'float'], 'float[]': [ 'ir::ArrayAttribute', - 'std::vector', + 'const std::vector&', ], 'double': ['ir::DoubleAttribute', 'double'], 'bool': 
['ir::BoolAttribute', 'bool'],
             'bool[]': [
                 'ir::ArrayAttribute',
-                'std::vector<bool>',
+                'const std::vector<bool>&',
             ],
             'str': ['ir::StrAttribute', 'std::string'],
             'str[]': [
                 'ir::ArrayAttribute',
-                'std::vector<std::string>',
+                'const std::vector<std::string>&',
             ],
             'Place': ['paddle::dialect::PlaceAttribute', 'Place'],
             'DataLayout': [
                 'paddle::dialect::DataLayoutAttribute',
                 'DataLayout',
             ],
             'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
             'int64_t[]': [
                 'ir::ArrayAttribute',
-                'std::vector<int64_t>',
+                'const std::vector<int64_t>&',
             ],
             'int[]': [
                 'ir::ArrayAttribute',
-                'std::vector<int>',
+                'const std::vector<int>&',
             ],
         }
         self.attribute_name_list = self.parse_attribute_name_list()
@@ -368,6 +370,14 @@ class OpInfoParser:
             self.mutable_attribute_type_list,
         ) = self.parse_mutable_attribute()

+        (
+            self.non_mutable_attribute_name_list,
+            self.non_mutable_attribute_type_list,
+            self.non_mutable_attribute_data_type_list,
+            self.non_mutable_attribute_build_arg_type_list,
+            self.non_mutable_attribute_default_value_list,
+        ) = self.parse_non_mutable_attribute()
+
         # parse infermeta && kernel
         self.infer_meta_map = self.parse_infer_meta_map()
         self.kernel_map = self.parse_kernel_map()
@@ -424,7 +434,13 @@ class OpInfoParser:
                     ["ir::StrAttribute", "std::string"]
                 )
             else:
-                mutable_attribute_name_list.append(scalar_attr)
+                if (
+                    scalar_attr == "depth"
+                    and self.op_phi_name[0] == "one_hot"
+                ):
+                    mutable_attribute_name_list.append("num_classes")
+                else:
+                    mutable_attribute_name_list.append(scalar_attr)
                 mutable_attribute_type_list.append(
                     [
                         "paddle::dialect::ScalarAttribute",
@@ -459,7 +475,55 @@ class OpInfoParser:
                     ],
                 ]
             )
-        return mutable_attribute_name_list, mutable_attribute_type_list
+        sorted_mutable_attribute_name_list = []
+        sorted_mutable_attribute_type_list = []
+        for attr_name in self.attribute_name_list:
+            if attr_name in mutable_attribute_name_list:
+                sorted_mutable_attribute_name_list.append(attr_name)
+                sorted_mutable_attribute_type_list.append(
+                    mutable_attribute_type_list[
+                        mutable_attribute_name_list.index(attr_name)
+                    ]
+                )
+
+        return (
+            sorted_mutable_attribute_name_list,
+            sorted_mutable_attribute_type_list,
+        )
+
+    def parse_non_mutable_attribute(self):
+        op_non_mutable_attribute_name_list = []
+        op_non_mutable_attribute_type_list = []
+        op_non_mutable_attribute_data_type_list = []
+        op_non_mutable_attribute_build_arg_type_list = []
+        op_non_mutable_attribute_default_value_list = []
+        for idx in range(len(self.attribute_name_list)):
+            if (
+                self.attribute_name_list[idx]
+                not in self.mutable_attribute_name_list
+            ):
+                op_non_mutable_attribute_name_list.append(
+                    self.attribute_name_list[idx]
+                )
+                op_non_mutable_attribute_type_list.append(
+                    self.attribute_type_list[idx]
+                )
+                op_non_mutable_attribute_data_type_list.append(
+                    self.attribute_data_type_list[idx]
+                )
+                op_non_mutable_attribute_build_arg_type_list.append(
+                    self.attribute_build_arg_type_list[idx]
+                )
+                op_non_mutable_attribute_default_value_list.append(
+                    self.attribute_default_value_list[idx]
+                )
+        return (
+            op_non_mutable_attribute_name_list,
+            op_non_mutable_attribute_type_list,
+            op_non_mutable_attribute_data_type_list,
+            op_non_mutable_attribute_build_arg_type_list,
+            op_non_mutable_attribute_default_value_list,
+        )

     def parse_input_name_list(self):
         name_list = []
@@ -649,53 +713,129 @@ def to_pascal_case(s):
 # =====================================
 def GenBuildInputArgsStr(
     op_input_name_list,
+    op_attribute_name_list,
+    op_attribute_build_arg_type_list,
+    op_attribute_default_value_list,
     op_mutable_attribute_name_list,
     op_non_mutable_attribute_name_list,
     op_non_mutable_attribute_build_arg_type_list,
     op_non_mutable_attribute_default_value_list,
     for_func_define=True,
+    mutable_attr_is_input=False,
 ):
     '''
-    Example: ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
+    Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
     '''
     # add inputs
-    build_args_str = "ir::OperationArgument &argument"
+    build_args_str = "ir::Builder &builder, ir::OperationArgument &argument"
     if len(op_input_name_list) > 0:
         for input_name in op_input_name_list:
             build_args_str += ", ir::OpResult " + input_name + "_"

-    # add mutable attributes as inputs
-    if len(op_mutable_attribute_name_list) > 0:
-        for mutable_attr in op_mutable_attribute_name_list:
-            build_args_str += ", ir::OpResult " + mutable_attr + "_"
-
-    # add non-mutable attributes
-    for attr_idx in range(len(op_non_mutable_attribute_name_list)):
-        build_args_str += (
-            ", "
-            + op_non_mutable_attribute_build_arg_type_list[attr_idx]
-            + " "
-            + op_non_mutable_attribute_name_list[attr_idx]
-        )
-        if for_func_define:
-            if (
-                op_non_mutable_attribute_default_value_list[attr_idx]
-                is not None
-            ):
-                default_value = op_non_mutable_attribute_default_value_list[
-                    attr_idx
-                ]
+
+    if not mutable_attr_is_input:
+        # add attributes
+        for attr_idx in range(len(op_attribute_name_list)):
+            build_args_str += (
+                ", "
+                + op_attribute_build_arg_type_list[attr_idx]
+                + " "
+                + op_attribute_name_list[attr_idx]
+            )
+            if for_func_define:
+                if op_attribute_default_value_list[attr_idx] is not None:
+                    default_value = op_attribute_default_value_list[attr_idx]
+                    if (
+                        op_attribute_build_arg_type_list[attr_idx]
+                        != "std::string"
+                    ):
+                        if default_value[0] == "'" or default_value[0] == '"':
+                            default_value = default_value[1:]
+                        if default_value[-1] == "'" or default_value[-1] == '"':
+                            default_value = default_value[0:-1]
+                    build_args_str += "=" + default_value
+    else:
+        # add mutable attributes as inputs
+        if len(op_mutable_attribute_name_list) > 0:
+            for mutable_attr in op_mutable_attribute_name_list:
+                build_args_str += ", ir::OpResult " + mutable_attr + "_"
+
+        # add non-mutable attributes
+        for attr_idx in range(len(op_non_mutable_attribute_name_list)):
+            build_args_str += (
+                ", "
+                + op_non_mutable_attribute_build_arg_type_list[attr_idx]
+                + " "
+                + op_non_mutable_attribute_name_list[attr_idx]
+            )
+            if for_func_define:
                 if (
-                    op_non_mutable_attribute_build_arg_type_list[attr_idx]
-                    != "std::string"
+                    op_non_mutable_attribute_default_value_list[attr_idx]
+                    is not None
                 ):
-                    if default_value[0] == "'" or default_value[0] == '"':
-                        default_value = default_value[1:]
-                    if default_value[-1] == "'" or default_value[-1] == '"':
-                        default_value = default_value[0:-1]
-                build_args_str += "=" + default_value
+                    default_value = op_non_mutable_attribute_default_value_list[
+                        attr_idx
+                    ]
+                    if (
+                        op_non_mutable_attribute_build_arg_type_list[attr_idx]
+                        != "std::string"
+                    ):
+                        if default_value[0] == "'" or default_value[0] == '"':
+                            default_value = default_value[1:]
+                        if default_value[-1] == "'" or default_value[-1] == '"':
+                            default_value = default_value[0:-1]
+                    build_args_str += "=" + default_value
+
     return build_args_str

+mutable_attribute_phi_type_maps = {
+    'int': 'phi::DataType::INT32',
+    'int64_t': 'phi::DataType::INT64',
+    'float': 'phi::DataType::FLOAT32',
+    'std::vector<int64_t>': 'phi::DataType::INT64',
+    'const std::vector<int64_t>&': 'phi::DataType::INT64',
+}
+
+
+def GenBuildInserFullForMutableAttribute(
+    op_attribute_name_list,
+    op_attribute_build_arg_type_list,
+    op_mutable_attribute_name_list,
+    op_mutable_attribute_type_list,
+):
+    build_mutable_attribute = ""
+    BUILD_INTARRAY_ATTRIBUTE_TEMPLATE = """  // Generate int_array mutable attribute: {attr_name}
+  paddle::dialect::FullIntArrayOp full_{attr_name}_op = builder.Build<paddle::dialect::FullIntArrayOp>({attr_name}, {phi_dtype}, phi::CPUPlace());
+  ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
+  """
+    BUILD_SCALAR_ATTRIBUTE_TEMPLATE = """  // Generate scalar mutable attribute: {attr_name}
+  paddle::dialect::FullOp full_{attr_name}_op = builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{{1}}, {attr_name}, {phi_dtype}, phi::CPUPlace());
+  ir::OpResult {attr_name}_ = full_{attr_name}_op->GetResultByIndex(0);
+  """
+    for idx in range(len(op_mutable_attribute_name_list)):
+        attr_name = op_mutable_attribute_name_list[idx]
+        attr_type = op_mutable_attribute_type_list[idx][0]
+        if attr_name in op_attribute_name_list:
+            phi_dtype = mutable_attribute_phi_type_maps[
+                op_attribute_build_arg_type_list[
+                    op_attribute_name_list.index(attr_name)
+                ]
+            ]
+        else:
+            phi_dtype = mutable_attribute_phi_type_maps[
+                op_mutable_attribute_type_list[idx][1]
+            ]
+        if attr_type == "paddle::dialect::IntArrayAttribute":
+            build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format(
+                attr_name=attr_name, phi_dtype=phi_dtype
+            )
+        else:
+            build_mutable_attribute += BUILD_SCALAR_ATTRIBUTE_TEMPLATE.format(
+                attr_name=attr_name, phi_dtype=phi_dtype
+            )
+    return build_mutable_attribute
+
+
 def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list):
     BUILD_INPUT_TEMPLATE = """  std::vector<ir::OpResult> argument_inputs = {{{inputs_args}}};
   argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
@@ -805,38 +945,39 @@ def GenBuildOutputs(
     op_output_type_list,
     op_output_size_list,
     op_infer_meta_map,
+    mutable_attr_is_input=False,
 ):
     build_output_str = '  VLOG(4) << "Builder construction outputs";\n'
-    CREATE_INPUT_METATENSOR_TEMPLATE = """  phi::DenseTensor dense_{name};
-  dense_{name}.set_meta(
-    phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
-                         {name}.dims(),
-                         {name}.data_layout(),
-                         {name}.lod(),
-                         {name}.offset())
-    );
+    CREATE_INPUT_METATENSOR_TEMPLATE = """
+  VLOG(4) << "Builder construction dense_{name}";
+  phi::DenseTensor dense_{name}(std::make_unique<paddle::experimental::DefaultAllocator>(paddle::platform::CPUPlace()).get(),
+                                phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
+                                                     {name}.dims(),
+                                                     {name}.data_layout(),
+                                                     {name}.lod(),
+                                                     {name}.offset()));
+  VLOG(4) << "Builder construction meta_{name}";
   phi::MetaTensor meta_{name}(&dense_{name});
 """
-    CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """  std::vector<phi::DenseTensor> vec_dense_{name}({name}.size(), phi::DenseTensor());
-  std::vector<phi::MetaTensor> vec_meta_{name};
+    CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """  std::vector<phi::DenseTensor> vec_dense_{name};
   for (size_t i=0; i < static_cast<size_t>({name}.size()); i++) {{
-    vec_dense_{name}[i].set_meta(
-      phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dtype()),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().data_layout(),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
-                           {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())
-      );
+    vec_dense_{name}.push_back(phi::DenseTensor(std::make_unique<paddle::experimental::DefaultAllocator>(paddle::platform::CPUPlace()).get(),
+                                                phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dtype()),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().data_layout(),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
+                                                                     {name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())));
+  }}
+  std::vector<phi::MetaTensor> vec_meta_{name};
+  for (size_t i=0; i < vec_dense_{name}.size(); i++) {{
vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); }} + std::vector meta_{name}; for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ meta_{name}.push_back(&vec_meta_{name}[i]); }} """ - CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().value().dyn_cast().data().GetData(); (void){name};\n""" - CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().value().dyn_cast<{ir_type}>().data(); (void){name};\n""" - CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE = """ std::string {name} = {name}_.owner()->dyn_cast().value().dyn_cast().data(); (void){name};\n""" CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; phi::MetaTensor meta_{name}(&dense_{name}); @@ -863,31 +1004,6 @@ def GenBuildOutputs( build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( name=op_input_name_list[idx] ) - # Prepare mutable attributes - for idx in range(len(op_mutable_attribute_name_list)): - attr_dtype = op_mutable_attribute_type_list[idx] - # int_array - if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": - build_output_str += ( - CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format( - name=op_mutable_attribute_name_list[idx] - ) - ) - # scalar - elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": - build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format( - name=op_mutable_attribute_name_list[idx], - dtype=attr_dtype[1], - ir_type=scalar_type_maps[attr_dtype[1]], - ) - # string - elif attr_dtype[0] == "ir::StrAttribute": - build_output_str += CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE.format( - name=op_mutable_attribute_name_list[idx] - ) - else: - assert "mutable attribtue type is not right." 
- build_output_str += "\n" # Prepare inputs_meta_tensor & attributes for infer meta infer_meta_args = [] @@ -979,6 +1095,95 @@ def GenBuildOutputs( return build_output_str +def GenBuild( + op_class_name, + op_input_name_list, + op_input_type_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + muta_attr_is_input=False, +): + build_args_for_declare = "" + build_func = "" + + build_args_for_declare = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + True, + muta_attr_is_input, + ) + + build_args_for_define = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + False, + muta_attr_is_input, + ) + inset_full_for_mutable_attributes_str = "" + if not muta_attr_is_input: + inset_full_for_mutable_attributes_str = ( + GenBuildInserFullForMutableAttribute( + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + ) + ) + + build_inputs_str = GenBuildInputs( + op_input_name_list, op_mutable_attribute_name_list + ) + build_attributes_str = GenBuildAttributes( + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + ) + build_outputs_str = GenBuildOutputs( + op_input_name_list, + op_input_type_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + False, + ) + + build_func = OP_BUILD_TEMPLATE.format( + op_name=op_class_name, + build_args=build_args_for_define, + build_mutable_attributes=inset_full_for_mutable_attributes_str, + build_inputs=build_inputs_str, + build_attributes=build_attributes_str, + build_outputs=build_outputs_str, + ) + + return (build_args_for_declare, build_func) + + def OpGenerator( op_yaml_files, op_compat_yaml_file, @@ -1032,31 +1237,22 @@ def OpGenerator( op_attribute_data_type_list = op_info.attribute_data_type_list op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list op_attribute_default_value_list = op_info.attribute_default_value_list - op_non_mutable_attribute_name_list = [] - op_non_mutable_attribute_type_list = [] - op_non_mutable_attribute_data_type_list = [] - op_non_mutable_attribute_build_arg_type_list = [] - op_non_mutable_attribute_default_value_list = [] - for idx in range(len(op_attribute_name_list)): - if ( - op_attribute_name_list[idx] - not in op_mutable_attribute_name_list - ): - op_non_mutable_attribute_name_list.append( - op_attribute_name_list[idx] - ) - op_non_mutable_attribute_type_list.append( - op_attribute_type_list[idx] - ) - op_non_mutable_attribute_data_type_list.append( - op_attribute_data_type_list[idx] - ) - 
op_non_mutable_attribute_build_arg_type_list.append( - op_attribute_build_arg_type_list[idx] - ) - op_non_mutable_attribute_default_value_list.append( - op_attribute_default_value_list[idx] - ) + op_non_mutable_attribute_name_list = ( + op_info.non_mutable_attribute_name_list + ) + op_non_mutable_attribute_type_list = ( + op_info.non_mutable_attribute_type_list + ) + op_non_mutable_attribute_data_type_list = ( + op_info.non_mutable_attribute_data_type_list + ) + op_non_mutable_attribute_build_arg_type_list = ( + op_info.non_mutable_attribute_build_arg_type_list + ) + op_non_mutable_attribute_default_value_list = ( + op_info.non_mutable_attribute_default_value_list + ) + # others op_infer_meta_map = op_info.infer_meta_map op_kernel_map = op_info.kernel_map @@ -1075,7 +1271,9 @@ def OpGenerator( op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." + op_name - # gen interface/trait str + # =================================== # + # gen interface/trait list str # + # =================================== # op_interfaces_str = "" if len(op_interfaces) > 0: op_interfaces_str = "," + ",".join(op_interfaces) @@ -1083,6 +1281,9 @@ def OpGenerator( if len(op_traits) > 0: op_traits_str = "," + ",".join(op_traits) + # =================================== # + # gen get input/output methods str # + # =================================== # op_get_inputs_outputs_str = "" for idx in range(len(op_input_name_list)): op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( @@ -1100,59 +1301,73 @@ def OpGenerator( output_index=idx, ) - # gen build str - build_define_input_args_str = "" - build_declare_input_args_str = "" - build_func_declare_str = "" + # =================================== # + # gen Build methods str # + # =================================== # + build_args_with_muta_attr_not_input_for_declare = "" + build_func_with_muta_attr_not_input = "" + build_mutable_attr_is_input = "" + build_func_with_muta_attr_is_input = "" + if op_infer_meta_map is not None: - build_define_input_args_str = GenBuildInputArgsStr( - op_input_name_list, - op_mutable_attribute_name_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_build_arg_type_list, - op_non_mutable_attribute_default_value_list, - True, - ) - build_declare_input_args_str = GenBuildInputArgsStr( + ( + build_args_with_muta_attr_not_input_for_declare, + build_func_with_muta_attr_not_input, + ) = GenBuild( + op_class_name, op_input_name_list, + op_input_type_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, op_mutable_attribute_name_list, + op_mutable_attribute_type_list, op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, - False, - ) - build_inputs_str = GenBuildInputs( - op_input_name_list, op_mutable_attribute_name_list - ) - build_attributes_str = GenBuildAttributes( - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - ) - build_outputs_str = GenBuildOutputs( - op_input_name_list, - op_input_type_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, op_output_name_list, op_output_type_list, op_output_size_list, op_infer_meta_map, + muta_attr_is_input=False, ) - build_func_declare_str = OP_BUILD_TEMPLATE.format( - op_name=op_class_name, - build_args=build_declare_input_args_str, - build_inputs=build_inputs_str, - build_attributes=build_attributes_str, - build_outputs=build_outputs_str, - ) - else: - 
build_func_declare_str = OP_BUILD_TEMPLATE.format( - op_name=op_class_name, - build_args=build_declare_input_args_str, - build_inputs="", - build_attributes="", - build_outputs="", - ) + op_infer_meta_args = op_infer_meta_map['param'] + if (len(op_mutable_attribute_name_list) > 0) and ( + len( + list( + set(op_infer_meta_args) + & set(op_mutable_attribute_name_list) + ) + ) + == 0 + ): + ( + build_args_with_muta_attr_is_input_for_declare, + build_func_with_muta_attr_is_input, + ) = GenBuild( + op_class_name, + op_input_name_list, + op_input_type_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + muta_attr_is_input=True, + ) + + build_mutable_attr_is_input = "static void Build({build_args});".format( + build_args=build_args_with_muta_attr_is_input_for_declare + ) # gen op_declare_str/op_defined_str if len(op_non_mutable_attribute_name_list) == 0: @@ -1163,7 +1378,8 @@ def OpGenerator( traits=op_traits_str, attribute_declare=op_0_attribute_declare_str, attribute_num=0, - build_args=build_define_input_args_str, + build_args=build_args_with_muta_attr_not_input_for_declare, + build_mutable_attr_is_input=build_mutable_attr_is_input, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) @@ -1178,7 +1394,8 @@ def OpGenerator( attribute_num=len(op_non_mutable_attribute_name_list) ), attribute_num=len(op_non_mutable_attribute_name_list), - build_args=build_define_input_args_str, + build_args=build_args_with_muta_attr_not_input_for_declare, + build_mutable_attr_is_input=build_mutable_attr_is_input, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) @@ -1191,33 +1408,35 @@ def OpGenerator( attribute_names=attribute_names_str, ) + # =================================== # + # gen GetOpInfo func str # + # =================================== # # generate get op info funciton: inputs - inputs_info_str = "" input_info_list = [] - if len(op_input_name_list) > 0: - for idx in range(len(op_input_name_list)): - input_info_list.append( - CONSTRUCT_INPUT_INFO_TEMPLATE.format( - name=op_input_name_list[idx], - typename=op_input_type_list[idx], - optional=op_input_optional_list[idx], - no_need_buffer=op_input_no_need_buffer_list[idx], - ) + for idx in range(len(op_input_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_input_name_list[idx], + typename=op_input_type_list[idx], + optional=op_input_optional_list[idx], + no_need_buffer=op_input_no_need_buffer_list[idx], + is_mutable_attribute='false', ) - - # add mutable attribute as input - if len(op_mutable_attribute_name_list) > 0: - for idx in range(len(op_mutable_attribute_name_list)): - input_info_list.append( - CONSTRUCT_INPUT_INFO_TEMPLATE.format( - name=op_mutable_attribute_name_list[idx], - typename=op_mutable_attribute_type_list[idx], - optional='false', - no_need_buffer='false', - ) + ) + for idx in range(len(op_mutable_attribute_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx], + typename=op_mutable_attribute_type_list[idx][0], + optional='false', + no_need_buffer='false', + 
is_mutable_attribute='true',
             )
-        inputs_info_str = ", ".join(input_info_list)
-
+        )
+    if len(input_info_list) > 0:
+        inputs_info_str = ", ".join(input_info_list)
+    else:
+        inputs_info_str = ""
     # generate get op info function: outputs
     outputs_info_str = ""
     if len(op_output_name_list) > 0:
@@ -1232,25 +1451,21 @@ def OpGenerator(
                 )
             )
         outputs_info_str = ", ".join(output_info_list)
-
     # generate get op info function: attributes
     attribute_info_str = ""
-    op_mutable_attribute_name_set = set(op_mutable_attribute_name_list)
-    if len(op_attribute_name_list) > 0:
+    if len(op_non_mutable_attribute_name_list) > 0:
         attribute_info_list = []
-        for idx in range(len(op_attribute_name_list)):
-            attribute_name = op_attribute_name_list[idx]
-            if attribute_name in op_mutable_attribute_name_set:
-                continue
+        for idx in range(len(op_non_mutable_attribute_name_list)):
             attribute_info_list.append(
                 CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
-                    name=attribute_name,
-                    typename=op_attribute_type_list[idx],
-                    data_type=op_attribute_data_type_list[idx],
+                    name=op_non_mutable_attribute_name_list[idx],
+                    typename=op_non_mutable_attribute_type_list[idx],
+                    data_type=op_non_mutable_attribute_data_type_list[
+                        idx
+                    ],
                 )
             )
         attribute_info_str = ", ".join(attribute_info_list)
-
     # generate runtime info
     infer_meta_func_str = ""
     infer_meta_param_str = ""
@@ -1274,6 +1489,9 @@ def OpGenerator(
         kernel_param=kernel_param_str,
     )

+    # =================================== #
+    #         gen Verify func str         #
+    # =================================== #
     # generate op verify function: inputs_type_check_str
     if (
         len(op_input_type_list) + len(op_mutable_attribute_name_list)
         == 0
     ):
@@ -1327,7 +1545,6 @@ def OpGenerator(
                     standard="paddle::dialect::DenseTensorType",
                 )
             inputs_type_check_str += check_str
-
     # generate op verify function: outputs_type_check_str
     if len(op_output_type_list) == 0:
         outputs_type_check_str = (
@@ -1364,7 +1581,6 @@ def OpGenerator(
                     index=idx, standard=output_type
                 )
             outputs_type_check_str += check_str
-
     # generate op verify function: attributes_check_str
     if len(op_non_mutable_attribute_name_list) == 0:
         attributes_check_str = (
@@ -1387,7 +1603,6 @@ def OpGenerator(
             attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
                 attribute_name=attribute_name, standard=attribute_type
             )
-
     # generate op verify function
     if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
         op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
@@ -1415,7 +1630,9 @@ def OpGenerator(
         ops_declare_list.append(op_declare_str)
         ops_defined_list.append(op_defined_str)
         ops_defined_list.append(op_info_func_str)
-        ops_defined_list.append(build_func_declare_str)
+        ops_defined_list.append(build_func_with_muta_attr_not_input)
+        if len(op_mutable_attribute_name_list) > 0:
+            ops_defined_list.append(build_func_with_muta_attr_is_input)
         ops_defined_list.append(op_verify_str)
         ops_defined_list.append(op_infer_shape_str)
diff --git a/paddle/fluid/ir/dialect/pd_attribute.cc b/paddle/fluid/ir/dialect/pd_attribute.cc
index 4829010609d..0bb42226d49 100644
--- a/paddle/fluid/ir/dialect/pd_attribute.cc
+++ b/paddle/fluid/ir/dialect/pd_attribute.cc
@@ -26,5 +26,23 @@ phi::DataLayout DataLayoutAttribute::data() const {
   return storage()->GetAsKey();
 }

+phi::Scalar ScalarAttribute::data() {
+  if (isa<ir::BoolAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
+  } else if (isa<ir::FloatAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
+  } else if (isa<ir::DoubleAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
+  } else if (isa<ir::Int32_tAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data());
+  } else if (isa<ir::Int64_tAttribute>()) {
+    return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data());
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Unsupported ir attribute when casting it into "
+        "phi scalar."));
+  }
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/pd_attribute.h b/paddle/fluid/ir/dialect/pd_attribute.h
index 67093f96553..1cc1c90950d 100644
--- a/paddle/fluid/ir/dialect/pd_attribute.h
+++ b/paddle/fluid/ir/dialect/pd_attribute.h
@@ -17,6 +17,8 @@
 #include "paddle/fluid/ir/dialect/pd_attribute_storage.h"
 #include "paddle/ir/core/attribute.h"
 #include "paddle/ir/core/builtin_attribute.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/enforce.h"

 namespace paddle {
 namespace dialect {
@@ -45,6 +47,8 @@ class ScalarAttribute : public ir::Attribute {
            (val.type_id() == ir::Int32_tAttribute::type_id()) ||
            (val.type_id() == ir::Int64_tAttribute::type_id());
   }
+
+  phi::Scalar data();
 };

 class DataTypeAttribute : public ir::Attribute {
diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h
index 65f68db2c6a..14f83d81809 100644
--- a/paddle/fluid/ir/dialect/utils.h
+++ b/paddle/fluid/ir/dialect/utils.h
@@ -101,14 +101,17 @@ struct OpInputInfo {
   std::string type_name;
   bool optional = false;
   bool no_need_buffer = false;
+  bool is_mutable_attribute = false;
   OpInputInfo(std::string name,
               std::string type_name,
               bool optional,
-              bool no_need_buffer)
+              bool no_need_buffer,
+              bool is_mutable_attribute)
       : name(name),
         type_name(type_name),
         optional(optional),
-        no_need_buffer(no_need_buffer) {}
+        no_need_buffer(no_need_buffer),
+        is_mutable_attribute(is_mutable_attribute) {}
 };

 struct OpOutputInfo {
diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h
index c9fef2da234..80255c3280b 100644
--- a/paddle/ir/core/builder.h
+++ b/paddle/ir/core/builder.h
@@ -56,7 +56,7 @@ class Builder {
   template <typename OpTy, typename... Args>
   OpTy Build(Args &&...args) {
     OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
-    OpTy::Build(argument, std::forward<Args>(args)...);
+    OpTy::Build(*this, argument, std::forward<Args>(args)...);
     Operation *op = Build(std::move(argument));
     return op->dyn_cast<OpTy>();
   }
diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc
index 63e8567a5c4..55bfc48bcf9 100644
--- a/paddle/ir/core/builtin_op.cc
+++ b/paddle/ir/core/builtin_op.cc
@@ -57,20 +57,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
                       const ir::AttributeMap &attributes) {
   VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
   // Verify inputs type:
-  if (inputs.size() != 0) {
-    throw("The size of inputs must be equal to 0.");
-  }
+  IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");

   // Verify if attributes contain attribute name in attributes_name:
   auto iter = attributes.find("program");
-  if (iter == attributes.end() || !iter->second.isa<ir::PointerAttribute>()) {
-    throw("Type of attribute: program is not right.");
-  }
+  IR_ENFORCE(iter != attributes.end() && iter->second.isa<ir::PointerAttribute>(),
+             "Type of attribute: program is not right.");

   // Verify outputs type:
-  if (outputs.size() != 0) {
-    throw("The size of outputs must be equal to 0.");
-  }
+  IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
 }

 const char *GetParameterOp::attributes_name[attributes_num] = {
@@ -81,17 +76,15 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
                             const ir::AttributeMap &attributes) {
   VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
   // Verify inputs type:
-  if (inputs.size() != 0) {
-    throw("The size of inputs must be equal to 0.");
-  }
-  // Verify outputs type:
-  if (outputs.size() != 1) {
- throw("The size of outputs must be equal to 1."); - } + IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0."); + // Verify if attributes contain attribute name in attributes_name: - if (!attributes.at("parameter_name").isa()) { - throw("Type of attribute: parameter_name is not right."); - } + auto iter = attributes.find("parameter_name"); + IR_ENFORCE(iter != attributes.end() && iter->second.isa(), + "Type of attribute: parameter_name is not right."); + + // Verify outputs type: + IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1."); } const char *SetParameterOp::attributes_name[attributes_num] = { @@ -102,54 +95,45 @@ void SetParameterOp::Verify(const std::vector &inputs, const ir::AttributeMap &attributes) { VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; // Verify inputs type: - if (inputs.size() != 1) { - throw("The size of inputs must be equal to 1."); - } - // Verify outputs type: - if (outputs.size() != 0) { - throw("The size of outputs must be equal to 0."); - } + IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1."); + // Verify if attributes contain attribute name in attributes_name: - if (!attributes.at("parameter_name").isa()) { - throw("Type of attribute: parameter_name is not right."); - } + auto iter = attributes.find("parameter_name"); + IR_ENFORCE(iter != attributes.end() && iter->second.isa(), + "Type of attribute: parameter_name is not right."); + + // Verify outputs type: + IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0."); } void CombineOp::Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) { // outputs.size() == 1 - PADDLE_ENFORCE_EQ( - outputs.size(), - 1, - phi::errors::PreconditionNotMet( - "The size %d of outputs must be equal to 1.", outputs.size())); + IR_ENFORCE(outputs.size() == 1, + "The size %d of outputs must be equal to 1.", + outputs.size()); + // outputs[0].type == Vector - PADDLE_ENFORCE(outputs[0].isa(), - phi::errors::PreconditionNotMet( - "The type %s of outputs[0] must be equal to VectorType.", - outputs[0])); + IR_ENFORCE(outputs[0].isa(), + "The type %s of outputs[0] must be equal to VectorType.", + outputs[0]); ir::VectorType output_type = outputs[0].dyn_cast(); // inputs.size() == outputs[0].size() - PADDLE_ENFORCE_EQ( - output_type.size(), - inputs.size(), - phi::errors::PreconditionNotMet( - "The size %d of outputs[0] must be equal to size %d of inputs.", - output_type.size(), - inputs.size())); + IR_ENFORCE(output_type.size() == inputs.size(), + "The size %d of outputs[0] must be equal to size %d of inputs.", + output_type.size(), + inputs.size()); // forall i in inputs.size(): inputs[i].type == outputs[0][i].type for (size_t i = 0; i < inputs.size(); i++) { - PADDLE_ENFORCE_EQ( - output_type[i], - inputs[i].type(), - phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be " - "equal to type %s of inputs[%d].", - output_type[i], - i, - inputs[i].type(), - i)); + IR_ENFORCE(output_type[i] == inputs[i].type(), + "The type %s of outputs[0][%d] must be " + "equal to type %s of inputs[%d].", + output_type[i], + i, + inputs[i].type(), + i); } } @@ -158,65 +142,50 @@ void SliceOp::Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) { // inputs.size() == 1 - PADDLE_ENFORCE_EQ( - inputs.size(), - 1, - phi::errors::PreconditionNotMet( - "The size %d of inputs must be equal to 1.", inputs.size())); + IR_ENFORCE(inputs.size() == 1, + 
"The size %d of inputs must be equal to 1.", + inputs.size()); // inputs[0].type == Vector - PADDLE_ENFORCE(inputs[0].type().isa(), - phi::errors::PreconditionNotMet( - "The type %s of inputs[0] must be equal to VectorType.", - inputs[0].type())); + IR_ENFORCE(inputs[0].type().isa(), + "The type %s of inputs[0] must be equal to VectorType.", + inputs[0].type()); ir::VectorType input_type = inputs[0].type().dyn_cast(); // outputs.size() == 1 - PADDLE_ENFORCE_EQ( - outputs.size(), - 1, - phi::errors::PreconditionNotMet( - "The size %d of outputs must be equal to 1.", outputs.size())); + IR_ENFORCE(outputs.size() == 1, + "The size %d of outputs must be equal to 1.", + outputs.size()); // attributes contains index: Int32 - PADDLE_ENFORCE_NE( - attributes.count("index"), - 0, - phi::errors::PreconditionNotMet("The attributes must contains index.")); + IR_ENFORCE(attributes.count("index") != 0, + "The attributes must contains index."); const ir::Attribute &attr = attributes.at("index"); - PADDLE_ENFORCE( - attr.isa(), - phi::errors::PreconditionNotMet("The attribute index must be INT32.")); + IR_ENFORCE(attr.isa(), + "The attribute index must be INT32."); auto index = attr.dyn_cast().data(); // index >= 0 and < inputs[0].size() - PADDLE_ENFORCE_GE( - index, - 0, - phi::errors::PreconditionNotMet( - "The index %d must be greater or equal than 0.", index)); - PADDLE_ENFORCE_LT( - index, - input_type.size(), - phi::errors::PreconditionNotMet( - "The index %d must be less or equal than size %d of inputs[0].", - index, - input_type.size())); + IR_ENFORCE( + index >= 0, "The index %d must be greater or equal than 0.", index); + IR_ENFORCE(static_cast(index) < input_type.size(), + "The index %d must be less or equal than size %d of inputs[0].", + index, + input_type.size()); // inputs[index].type == outputs[0].type - PADDLE_ENFORCE_EQ( + IR_ENFORCE( + input_type[index] == outputs[0], + "The type %s of inputs[%d] must be equal to type %s of outputs[0].", input_type[index], - outputs[0], - phi::errors::PreconditionNotMet( - "The type %s of inputs[%d] must be equal to type %s of outputs[0].", - input_type[index], - index, - outputs[0])); + index, + outputs[0]); } const char *ConstantOp::attributes_name[attributes_num] = {"value"}; -void ConstantOp::Build(OperationArgument &argument, +void ConstantOp::Build(Builder &builder, + OperationArgument &argument, Attribute value, Type output_type) { argument.AddAttribute("value", value); diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index 6e54841fe90..8ce62b1e886 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -86,6 +86,7 @@ class CombineOp : public ir::Op { static constexpr uint32_t attributes_num = 0; static constexpr const char **attributes_name = nullptr; + static void Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); @@ -125,7 +126,8 @@ class ConstantOp : public Op { static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - static void Build(OperationArgument &argument, // NOLINT + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT Attribute value, Type output_type); diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 5532d256d89..198259c2416 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -16,6 +16,7 @@ #include "paddle/ir/core/block.h" #include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/enforce.h" #include 
"paddle/ir/core/op_info.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/program.h" @@ -85,7 +86,7 @@ Operation *Operation::Create(const std::vector &inputs, base_ptr += sizeof(Operation); // 3.3. Construct OpOperands. if ((reinterpret_cast(base_ptr) & 0x7) != 0) { - throw("The address of OpOperandImpl must be divisible by 8."); + IR_THROW("The address of OpOperandImpl must be divisible by 8."); } for (size_t idx = 0; idx < num_operands; idx++) { new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op); @@ -147,7 +148,7 @@ void Operation::Destroy() { // 2.2. Deconstruct Operation. if (reinterpret_cast(base_ptr) != reinterpret_cast(this)) { - throw("Operation address error"); + IR_THROW("Operation address error"); } reinterpret_cast(base_ptr)->~Operation(); base_ptr += sizeof(Operation); @@ -178,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes, ir::OpResult Operation::GetResultByIndex(uint32_t index) const { if (index >= num_results_) { - throw("index exceeds OP output range."); + IR_THROW("index exceeds OP output range."); } uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex(); const char *ptr = @@ -199,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const { ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const { if (index >= num_operands_) { - throw("index exceeds OP input range."); + IR_THROW("index exceeds OP input range."); } const char *ptr = reinterpret_cast(this) + sizeof(Operation) + (index) * sizeof(detail::OpOperandImpl); diff --git a/paddle/ir/core/storage_manager.cc b/paddle/ir/core/storage_manager.cc index 41a52e85b30..8bc74b23c1d 100644 --- a/paddle/ir/core/storage_manager.cc +++ b/paddle/ir/core/storage_manager.cc @@ -17,6 +17,8 @@ #include #include +#include "paddle/ir/core/enforce.h" + namespace ir { // This is a structure for creating, caching, and looking up Storage of // parametric types. 
@@ -76,7 +78,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl( << std::hash()(type_id) << ", param_hash=" << hash_value << "]."; if (parametric_instance_.find(type_id) == parametric_instance_.end()) { - throw("The input data pointer is null."); + IR_THROW("The input data pointer is null."); } ParametricStorageManager ¶metric_storage = *parametric_instance_[type_id]; return parametric_storage.GetOrCreate(hash_value, equal_func, constructor); @@ -88,7 +90,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl( VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash=" << std::hash()(type_id) << "]."; if (parameterless_instance_.find(type_id) == parameterless_instance_.end()) - throw("TypeId not found in IrContext."); + IR_THROW("TypeId not found in IrContext."); StorageBase *parameterless_instance = parameterless_instance_[type_id]; return parameterless_instance; } @@ -107,7 +109,7 @@ void StorageManager::RegisterParameterlessStorageImpl( VLOG(4) << "Register a parameterless storage of: [TypeId_hash=" << std::hash()(type_id) << "]."; if (parameterless_instance_.find(type_id) != parameterless_instance_.end()) - throw("storage class already registered"); + IR_THROW("storage class already registered"); parameterless_instance_.emplace(type_id, constructor()); } diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 9ca58e7738a..98bc678baaa 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -427,6 +427,18 @@ data_type : dtype backend : place +- op : full_int_array + args : (IntArray value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) + output: Tensor(out) + infer_meta : + func : CreateIntArrayInferMeta + param : [value, dtype] + kernel : + func : full_int_array + param : [value, dtype] + data_type : dtype + backend : place + - op : full_like args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {}) output: Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 2c9e720aedc..e591c2c2cd8 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -62,9 +62,9 @@ beta2 : data_type : float tensor_name : Beta2Tensor - episilon : + epsilon : data_type : float - tensor_name : EpisilonTensor + tensor_name : EpsilonTensor manual_signature : [adam_] - op : adamax_ @@ -85,9 +85,9 @@ beta2 : data_type : float tensor_name : Beta2Tensor - episilon : + epsilon : data_type : float - tensor_name : EpisilonTensor + tensor_name : EpsilonTensor - op : add (elementwise_add) backward : add_grad (elementwise_add_grad) @@ -1970,7 +1970,7 @@ outputs: out : Out int_array: - axis : + dims : data_type : int extra : attrs : [bool use_mkldnn = false] diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index b89001572e4..0025a2bdf60 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -41,6 +41,15 @@ void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out) { CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out); } +void CreateIntArrayInferMeta(const IntArray& data, + DataType dtype, + MetaTensor* out) { + CreateInferMetaBase({static_cast(data.GetData().size())}, + dtype, + DataLayout::NCHW, + out); +} + void CreateInferMetaBase(const std::vector& shape, DataType dtype, DataLayout layout, diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 
e26d918b5fe..bc630fcba74 100644
--- a/paddle/phi/infermeta/nullary.h
+++ b/paddle/phi/infermeta/nullary.h
@@ -35,6 +35,10 @@ void AssignValueInferMeta(const std::vector<int>& shape,
                           DataType dtype,
                           MetaTensor* out);

+void CreateIntArrayInferMeta(const IntArray& data,
+                             DataType dtype,
+                             MetaTensor* out);
+
 void CreateInferMeta(const IntArray& shape, DataType dtype, MetaTensor* out);

 void CreateInferMetaBase(const std::vector<int64_t>& shape,
diff --git a/paddle/phi/kernels/cpu/full_kernel.cc b/paddle/phi/kernels/cpu/full_kernel.cc
index a295bfcc20c..20d50d88f68 100644
--- a/paddle/phi/kernels/cpu/full_kernel.cc
+++ b/paddle/phi/kernels/cpu/full_kernel.cc
@@ -80,6 +80,18 @@ void FullLikeKernel(const Context& dev_ctx,
   FullValue<T>(dev_ctx, out, value);
 }

+template <typename T, typename Context>
+void FullIntArrayKernel(const Context& dev_ctx,
+                        const IntArray& val,
+                        DataType dtype UNUSED,
+                        DenseTensor* out) {
+  out->Resize(phi::make_ddim({static_cast<int64_t>(val.GetData().size())}));
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  for (size_t i = 0; i < val.GetData().size(); ++i) {
+    out_data[i] = static_cast<T>(val.GetData()[i]);
+  }
+}
+
 }  // namespace phi

 PD_REGISTER_KERNEL(full,
@@ -115,3 +127,6 @@ PD_REGISTER_KERNEL(full_like,
                    phi::dtype::complex<double>) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
+
+PD_REGISTER_KERNEL(
+    full_int_array, CPU, ALL_LAYOUT, phi::FullIntArrayKernel, int, int64_t) {}
diff --git a/paddle/phi/kernels/full_kernel.h b/paddle/phi/kernels/full_kernel.h
index 598bd040e58..a5c5c705c9e 100644
--- a/paddle/phi/kernels/full_kernel.h
+++ b/paddle/phi/kernels/full_kernel.h
@@ -83,4 +83,10 @@ DenseTensor FullLike(const Context& dev_ctx,
   return dense_out;
 }

+template <typename T, typename Context>
+void FullIntArrayKernel(const Context& dev_ctx,
+                        const IntArray& val,
+                        DataType dtype,
+                        DenseTensor* out);
+
 }  // namespace phi
diff --git a/test/cpp/ir/core/ir_exe_test.cc b/test/cpp/ir/core/ir_exe_test.cc
index 0ca3b996818..7c7dc76a5f4 100644
--- a/test/cpp/ir/core/ir_exe_test.cc
+++ b/test/cpp/ir/core/ir_exe_test.cc
@@ -44,76 +44,61 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "test/cpp/ir/core/phi_kernel_adaptor.h"

+PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);

 bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; }

 TEST(program_test, program) {
+  // Prepare ir env
   ir::IrContext* ctx = ir::IrContext::Instance();
   ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
-
   ir::Program program(ctx);
+  ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
   ir::Block* block = program.block();

-  ir::Type fp32_dtype = ir::Float32Type::get(ctx);
-
-  paddle::dialect::DenseTensorTypeStorage::Dim dims = {2, 2};
-  paddle::dialect::DenseTensorTypeStorage::DataLayout data_layout =
-      paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
-  paddle::dialect::DenseTensorTypeStorage::LoD lod = {};
-  size_t offset = 0;
-  ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
-      ctx, fp32_dtype, dims, data_layout, lod, offset);
-
-  // (1) Def a = GetParameterOp("a")
-  std::string op1_name = std::string(paddle::dialect::UniformOp::name());
-  ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
-
-  // ir::Attribute shape_1 = ir::ArrayAttribute::get(ctx, {ten} );
-  ir::Attribute shape_1 = paddle::dialect::IntArrayAttribute::get(
-      ctx, std::vector<int64_t>({2, 2}));
-  ir::Attribute data_type =
-      paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
-  ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0);
-  ir::Attribute max = ir::FloatAttribute::get(ctx, 1.0);
-  ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
-  ir::Attribute uni_place = paddle::dialect::PlaceAttribute::get(
-      ctx, phi::Place(phi::AllocationType::CPU));
-  std::unordered_map<std::string, ir::Attribute> op1_attribute{
-      {"shape", shape_1},
-      {"dtype", data_type},
-      {"min", min},
-      {"max", max},
-      {"seed", seed},
-      {"place", uni_place}};
-  ir::Operation* op1 =
-      ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
-
-  block->push_back(op1);
-
-  // (2) Def b = GetParameterOp("b")
-  std::string op2_name = std::string(paddle::dialect::UniformOp::name());
-  ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
-  ir::Attribute ten2 = ir::Int32_tAttribute::get(ctx, 3);
-  std::unordered_map<std::string, ir::Attribute> op2_attribute{{"shape", ten2}};
-  ir::Operation* op2 =
-      ir::Operation::Create({}, op1_attribute, {dense_tensor_dtype}, op2_info);
-  block->push_back(op2);
-
-  // (3) Def out = AddOp(a, b)
-  std::string add_op_name = std::string(paddle::dialect::AddOp::name());
-  ir::OpInfo add_op_info = ctx->GetRegisteredOpInfo(add_op_name);
-  ir::Operation* add_op = ir::Operation::Create(
-      {op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
-      {},
-      {dense_tensor_dtype},
-      add_op_info);
-  block->push_back(add_op);
+  // Def: A = paddle::dialect::UniformOp(std::vector<int64_t> shape,
+  // phi::DataType dtype, float min, float max, int seed, phi::Place place)
+  paddle::dialect::UniformOp uniform1 =
+      builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
+                                                phi::DataType::FLOAT32,
+                                                0.0,
+                                                1.0,
+                                                2,
+                                                phi::CPUPlace());
+  EXPECT_EQ(uniform1->GetResultByIndex(0)
+                .type()
+                .isa<paddle::dialect::DenseTensorType>(),
+            true);
+  EXPECT_EQ(block->size(), 4u);
+
+  // Def: B = paddle::dialect::UniformOp(...)
+  paddle::dialect::UniformOp uniform2 =
+      builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
+                                                phi::DataType::FLOAT32,
+                                                0.0,
+                                                1.0,
+                                                2,
+                                                phi::CPUPlace());
+  EXPECT_EQ(uniform2->GetResultByIndex(0)
+                .type()
+                .isa<paddle::dialect::DenseTensorType>(),
+            true);
+  EXPECT_EQ(block->size(), 8u);
+
+  // Def: C = paddle::dialect::AddOp(ir::OpResult x_, ir::OpResult y_)
+  paddle::dialect::AddOp add = builder.Build<paddle::dialect::AddOp>(
+      uniform1->GetResultByIndex(0), uniform2->GetResultByIndex(0));
+  EXPECT_EQ(
+      add->GetResultByIndex(0).type().isa<paddle::dialect::DenseTensorType>(),
+      true);
+  EXPECT_EQ(block->size(), 9u);
+
+  // Execute program
   paddle::framework::Scope scope;
-
   PhiKernelAdaptor phi_kernel_adaptor(&scope);
-
   phi_kernel_adaptor.run(&program);

   auto out_tensor =
diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc
index db839781615..37b29b0175e 100644
--- a/test/cpp/ir/core/ir_op_test.cc
+++ b/test/cpp/ir/core/ir_op_test.cc
@@ -20,6 +20,7 @@
 #include "paddle/ir/core/builtin_op.h"
 #include "paddle/ir/core/builtin_type.h"
 #include "paddle/ir/core/dialect.h"
+#include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/ir_context.h"
 #include "paddle/ir/core/op_base.h"
 #include "paddle/ir/core/program.h"
@@ -240,10 +241,12 @@ TEST(op_test, module_op_death) {
   ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
   std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};

-  EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info), const char *);
-  EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info), const char *);
+  EXPECT_THROW(ir::Operation::Create(inputs, {}, {}, op_info),
+               ir::IrNotMetException);
+  EXPECT_THROW(ir::Operation::Create({}, attrs, {}, op_info),
+               ir::IrNotMetException);
   EXPECT_THROW(ir::Operation::Create({}, {}, output_types, op_info),
-               const char *);
+               ir::IrNotMetException);

   ir::Program program(ctx);
diff --git a/test/cpp/ir/core/phi_kernel_adaptor.h b/test/cpp/ir/core/phi_kernel_adaptor.h
index 4916582dfbf..2e44ff55848 100644
--- a/test/cpp/ir/core/phi_kernel_adaptor.h
+++ b/test/cpp/ir/core/phi_kernel_adaptor.h
@@ -98,27 +98,30 @@ void build_context(ir::Operation* op,
           op->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
   auto op_info_res = op_info_interface.GetOpInfo();

+  // inputs include input and mutable attributes
   auto input_info = std::get<0>(op_info_res);
-
-  std::set<std::string> input_set;
+  std::map<std::string, int> input_index_map;
+  std::map<std::string, std::string> mutable_attr_type_map;
+  int input_index = 0;
   for (auto& t : input_info) {
     VLOG(6) << t.name << "\t" << t.type_name;
-
-    input_set.insert(t.name);
+    input_index_map[t.name] = input_index++;
+    if (t.is_mutable_attribute) {
+      mutable_attr_type_map[t.name] = t.type_name;
+    }
   }

-  auto attr_map = op->attributes();
-
-  std::map<std::string, std::string> attr_type_map;
   auto attr_info = std::get<1>(op_info_res);
+  std::map<std::string, std::string> attr_type_map;
   for (auto& t : attr_info) {
     VLOG(6) << t.name << "\t" << t.type_name;
     attr_type_map[t.name] = t.type_name;
   }

+  auto attr_map = op->attributes();
   auto runtime_info = std::get<3>(op_info_res);

-  int input_index = 0;
+  // int input_index = 0;
   std::vector<std::string> vec_param_list;
   if (is_infer_meta) {
     vec_param_list = runtime_info.infer_meta_param;
   } else {
     vec_param_list = runtime_info.kernel_param;
   }
   for (auto& t : vec_param_list) {
-    if (input_set.count(t)) {
+    if (input_index_map.count(t)) {
       // get information from input
-      ir::Value ptr = op->GetOperandByIndex(input_index++).source();
+      ir::Value ptr = op->GetOperandByIndex(input_index_map[t]).source();
       auto in_var_name = name_map.at(ptr);

-      ctx->EmplaceBackInput(
-          scope->Var(in_var_name)->GetMutable<phi::DenseTensor>());
+      if (mutable_attr_type_map.count(t)) {
+        VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t"
+                << in_var_name;
+        if (mutable_attr_type_map[t] == "paddle::dialect::IntArrayAttribute") {
+          ctx->EmplaceBackAttr(phi::IntArray(
+              *(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
+        } else if (mutable_attr_type_map[t] ==
+                   "paddle::dialect::ScalarAttribute") {
+          ctx->EmplaceBackAttr(phi::Scalar(
+              *(scope->Var(in_var_name)->GetMutable<phi::DenseTensor>())));
+        } else {
+          PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
+                                                  mutable_attr_type_map[t]));
+        }
+
+      } else {
+        VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
+        ctx->EmplaceBackInput(
+            scope->Var(in_var_name)->GetMutable<phi::DenseTensor>());
+      }
     }

     if (attr_type_map.count(t)) {
@@ -149,10 +170,14 @@ void build_context(ir::Operation* op,
       } else if (type_name == "paddle::dialect::PlaceAttribute") {
         ctx->EmplaceBackAttr(
             attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
+      } else if (type_name == "paddle::dialect::ScalarAttribute") {
+        ctx->EmplaceBackAttr(
+            attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
       } else {
         PADDLE_THROW(phi::errors::Unimplemented("attr type not support [%s] ",
                                                 type_name));
       }
+      VLOG(6) << "ctx->EmplaceBackAttr: " << t;
     }
   }

@@ -197,6 +222,9 @@ class PhiKernelAdaptor {
     phi::KernelKey kernel_key(phi::TransToPhiBackend(cpu_place),
                               phi::DataLayout::ANY,
                               phi::DataType::FLOAT32);
+    if (runtime_info.kernel_func[0] == "full_int_array") {
+      kernel_key.set_dtype(phi::DataType::INT64);
+    }
     auto found_it = phi_kernels.find(kernel_key);
     if (found_it == phi_kernels.end()) {
       std::cerr << "kernel name " << runtime_info.kernel_func[0] << std::endl;
--
GitLab
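
Usage sketch (appended for review; not part of the patch). After this change every
generated op exposes up to two static Build overloads, and ir::Builder::Build
forwards itself as the first argument. The sketch below is modeled on the new
ir_exe_test.cc hunk above; the two pd_dialect headers and the parameter order of
the second overload are assumptions not confirmed by this patch, and are flagged
in the comments.

    #include "paddle/fluid/ir/dialect/pd_dialect.h"  // assumed header for PaddleDialect
    #include "paddle/fluid/ir/dialect/pd_op.h"       // assumed header for generated ops
    #include "paddle/ir/core/builder.h"
    #include "paddle/ir/core/ir_context.h"
    #include "paddle/ir/core/program.h"

    void BuildUniformExample() {
      ir::IrContext* ctx = ir::IrContext::Instance();
      ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
      ir::Program program(ctx);
      ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());

      // Default overload: mutable attributes (shape/min/max) are passed by value.
      // The generated Build first materializes each mutable attribute with a
      // full_int_array/full op, so this single call appends four operations to
      // the block -- exactly what EXPECT_EQ(block->size(), 4u) checks above.
      paddle::dialect::UniformOp uniform =
          builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
                                                    phi::DataType::FLOAT32,
                                                    0.0,
                                                    1.0,
                                                    2,
                                                    phi::CPUPlace());
      (void)uniform;

      // Second overload (only generated when none of the op's infer_meta
      // parameters is itself a mutable attribute): the mutable attributes are
      // taken as ir::OpResult inputs that the caller materializes explicitly,
      // e.g. via builder.Build<paddle::dialect::FullIntArrayOp>(...). The
      // assumed shape is Build(builder, argument, inputs..., mutable attrs as
      // OpResults..., remaining non-mutable attributes...).
    }

A related behavioral note grounded in the test changes: verification failures now
surface as ir::IrNotMetException raised by IR_ENFORCE/IR_THROW rather than a
thrown const char*, which is what the updated EXPECT_THROW assertions in
ir_op_test.cc exercise.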