From 4bd5b695b324f6ebdab15b83fb5cc5c5844a01fe Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 1 Jun 2023 22:02:53 +0800 Subject: [PATCH] [IR] Support static build function for op builder (#54197) * add build * add build * refine code * refine code * refine code * refine code * refine interface * fix bug * fix bug * fix bug * refine yaml --- paddle/fluid/dialect/op_gen.py | 621 ++++++++++++++++++--- paddle/fluid/dialect/pd_dialect.cc | 27 +- paddle/fluid/dialect/pd_interface.h | 15 +- paddle/fluid/dialect/pd_op.yaml | 22 - paddle/fluid/dialect/pd_type.cc | 13 +- paddle/fluid/dialect/pd_type.h | 7 +- paddle/fluid/dialect/pd_type_storage.h | 61 +- paddle/fluid/dialect/utils.h | 144 +++-- paddle/fluid/translator/type_translator.cc | 2 +- paddle/ir/core/builtin_attribute_storage.h | 2 +- paddle/phi/api/yaml/backward.yaml | 2 +- paddle/phi/api/yaml/legacy_backward.yaml | 2 +- test/cpp/ir/core/ir_program_test.cc | 10 +- 13 files changed, 672 insertions(+), 256 deletions(-) diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 11ce656a2e9..0d8c4d336f1 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -31,6 +31,8 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST #include +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/operation_utils.h" #include "paddle/ir/core/op_base.h" #include "paddle/fluid/dialect/utils.h" #include "paddle/fluid/dialect/pd_interface.h" @@ -54,6 +56,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ {attribute_declare} static constexpr uint32_t attributes_num = {attribute_num}; static OpInfoTuple GetOpInfo(); + static void build({build_args}); static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); {get_inputs_and_outputs} {exclusive_interface} @@ -81,6 +84,14 @@ CC_FILE_TEMPLATE = """#include "{h_file}" #include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/ir_context.h" #include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/infermeta/nullary.h" +#include "paddle/phi/infermeta/unary.h" +#include "paddle/phi/infermeta/ternary.h" +#include "paddle/phi/infermeta/backward.h" + #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/nullary.h" @@ -97,45 +108,35 @@ OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """ const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }}; """ -# get op input info +# get op info OP_INFO_TEMPLATE = """ OpInfoTuple {op_name}::GetOpInfo() {{ - std::vector inputs = {{ {inputs} }}; - std::vector attributes = {{ {attributes} }}; - std::vector outputs = {{ {outputs} }}; - return std::make_tuple(inputs, attributes, outputs); -}} -""" - -OP_INPUT_INFO_TEMPLATE = """ -std::vector {op_name}::inputs_info() {{ - return {{ {impl} }}; + std::vector inputs = {{ {inputs} }}; + std::vector attributes = {{ {attributes} }}; + std::vector outputs = {{ {outputs} }}; + paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}); + return std::make_tuple(inputs, attributes, outputs, run_time_info); }} """ CONSTRUCT_INPUT_INFO_TEMPLATE = ( """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})""" ) - -# get op output info -OP_OUTPUT_INFO_TEMPLATE = """ -std::vector {op_name}::outputs_info() {{ - return {{ {impl} }}; -}} -""" CONSTRUCT_OUTPUT_INFO_TEMPLATE = ( """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" ) - -# get op attribute info -OP_ATTRIBUTE_INFO_TEMPLATE = """ -std::vector {op_name}::attributes_info() {{ - return {{ {impl} }}; -}} -""" CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( """OpAttributeInfo("{name}", "{typename}", "{data_type}")""" ) +# build +OP_BUILD_TEMPLATE = """ +void {op_name}::build({build_args}) {{ +{build_inputs} +{build_attributes} +{build_outputs} +}} +""" + # verify OP_VERIFY_TEMPLATE = """ void {op_name}::verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ @@ -154,6 +155,14 @@ void {op_name}::verify(const std::vector &inputs, const std::vecto }} """ +GRAD_OP_VERIFY_TEMPLATE = """ +void {op_name}::verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ + (void)inputs; + (void)outputs; + (void)attributes; +}} +""" + INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true, phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); """ @@ -216,14 +225,10 @@ OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{ }} """ -ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true, - phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}.")); - PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true, +ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), true, phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); """ -ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true, - phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}.")); - PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa(), true, +ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa(), true, phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast()[i].isa<{standard}>(), true, @@ -286,6 +291,7 @@ class OpInfoParser: # parse outputs self.output_name_list = self.parse_output_name_list() self.output_type_list = self.parse_output_type_list() + self.output_size_list = self.parse_output_size_list() self.output_optional_list = self.parse_output_optional_list() self.output_intermediate_list = self.parse_output_intermediate_list() self.cross_check( @@ -294,11 +300,67 @@ class OpInfoParser: self.output_optional_list, ) # parse attributes + self.attr_types_map = { + 'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'], + 'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'], + 'Scalar(int)': ['paddle::dialect::ScalarAttribute', 'int'], + 'Scalar(int64_t)': ['paddle::dialect::ScalarAttribute', 'int64_t'], + 'Scalar(float)': ['paddle::dialect::ScalarAttribute', 'float'], + 'Scalar(dobule)': ['paddle::dialect::ScalarAttribute', 'dobule'], + 'Scalar[]': [ + 'ir::ArrayAttribute', + 'std::vector', + ], + 'int': ['ir::Int32_tAttribute', 'int'], + 'int32_t': ['ir::Int32_tAttribute', 'int32_t'], + 'int64_t': ['ir::Int64_tAttribute', 'int64_t'], + 'long': ['ir::LongAttribute', 'long'], + 'size_t': ['ir::Size_tAttribute', 'size_t'], + 'float': ['ir::FloatAttribute', 'float'], + 'float[]': [ + 'ir::ArrayAttribute', + 'std::vector', + ], + 'double': ['ir::DoubleAttribute', 'double'], + 'bool': ['ir::BoolAttribute', 'bool'], + 'bool[]': [ + 'ir::ArrayAttribute', + 'std::vecot', + ], + 'str': ['ir::StrAttribute', 'std::string'], + 'str[]': [ + 'ir::ArrayAttribute', + 'std::vector', + ], + 'Place': ['paddle::dialect::PlaceAttribute', 'Place'], + 'DataLayout': [ + 'paddle::dialect::DataLayoutAttribute', + 'DataLayout', + ], + 'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'], + 'int64_t[]': [ + 'ir::ArrayAttribute', + 'std::vector', + ], + 'int[]': [ + 'ir::ArrayAttribute', + 'std::vector', + ], + } self.attribute_name_list = self.parse_attribute_name_list() self.attribute_type_list = self.parse_attribute_type_list() + self.attribute_build_arg_type_list = ( + self.parse_attribute_build_arg_type_list() + ) self.attribute_data_type_list = self.parse_attribute_data_type_list() + self.attribute_default_value_list = ( + self.parse_attribute_default_value_list() + ) self.cross_check(self.attribute_name_list, self.attribute_type_list) + # parse infermeta && kernel + self.infer_meta_map = self.parse_infer_meta_map() + self.kernel_map = self.parse_kernel_map() if 'infer_meta' in self.op_yaml_item: self.infer_shape_func = self.op_yaml_item['infer_meta']["func"] else: @@ -313,6 +375,23 @@ class OpInfoParser: optional_list ), "type list size != optional list size." + def parse_op_phi_name(self): + if self.parse_op_inplace_info() is None: + return [self.op_yaml_item['name']] + else: + if self.op_yaml_item['name'][-1] == "_": + return [self.op_yaml_item['name']] + else: + return [ + self.op_yaml_item['name'], + self.op_yaml_item['name'] + "_", + ] + + def parse_op_inplace_info(self): + if 'inplace' in self.op_yaml_item: + return self.op_yaml_item['inplace'] + return None + def parse_input_name_list(self): name_list = [] for input_info in self.op_yaml_item['inputs']: @@ -369,6 +448,15 @@ class OpInfoParser: type_list.append(output_type_map[output_info['typename']]) return type_list + def parse_output_size_list(self): + size_list = [] + for output_info in self.op_yaml_item['outputs']: + if 'size' in output_info: + size_list.append(output_info['size']) + else: + size_list.append(None) + return size_list + def parse_output_optional_list(self): optional_list = [] for output_info in self.op_yaml_item['outputs']: @@ -399,39 +487,31 @@ class OpInfoParser: name_list.append(attribute_info['name']) return name_list + def parse_attribute_build_arg_type_list(self): + type_list = [] + for attribute_info in self.op_yaml_item['attrs']: + assert ( + attribute_info['typename'] in self.attr_types_map + ), f"{self.op_phi_name} : Attr type error." + + # Scalar & IntArray has data_type + temp_type = self.attr_types_map[attribute_info['typename']][1] + if 'Scalar' in temp_type: + if 'data_type' in attribute_info: + temp_type = attribute_info['data_type'] + if 'IntArray' in temp_type: + if 'data_type' in attribute_info: + temp_type = attribute_info['data_type'] + type_list.append(self.get_phi_dtype_name(temp_type)) + return type_list + def parse_attribute_type_list(self): - attr_types_map = { - 'IntArray': 'paddle::dialect::IntArrayAttribute', - 'Scalar': 'paddle::dialect::ScalarAttribute', - 'Scalar(int)': 'paddle::dialect::ScalarAttribute', - 'Scalar(int64_t)': 'paddle::dialect::ScalarAttribute', - 'Scalar(float)': 'paddle::dialect::ScalarAttribute', - 'Scalar(dobule)': 'paddle::dialect::ScalarAttribute', - 'Scalar[]': 'ir::ArrayAttribute', - 'int': 'ir::Int32_tAttribute', - 'int32_t': 'ir::Int32_tAttribute', - 'int64_t': 'ir::Int64_tAttribute', - 'long': 'ir::LongAttribute', - 'size_t': 'ir::Size_tAttribute', - 'float': 'ir::FloatAttribute', - 'float[]': 'ir::ArrayAttribute', - 'double': 'ir::DoubleAttribute', - 'bool': 'ir::BoolAttribute', - 'bool[]': 'ir::ArrayAttribute', - 'str': 'ir::StrAttribute', - 'str[]': 'ir::ArrayAttribute', - 'Place': 'paddle::dialect::PlaceAttribute', - 'DataLayout': 'paddle::dialect::DataLayoutAttribute', - 'DataType': 'paddle::dialect::DataTypeAttribute', - 'int64_t[]': 'ir::ArrayAttribute', - 'int[]': 'ir::ArrayAttribute', - } type_list = [] for attribute_info in self.op_yaml_item['attrs']: assert ( - attribute_info['typename'] in attr_types_map + attribute_info['typename'] in self.attr_types_map ), f"{self.op_phi_name} : Attr type error." - type_list.append(attr_types_map[attribute_info['typename']]) + type_list.append(self.attr_types_map[attribute_info['typename']][0]) return type_list def parse_attribute_data_type_list(self): @@ -443,22 +523,48 @@ class OpInfoParser: data_type_list.append("") return data_type_list - def parse_op_phi_name(self): - if self.parse_op_inplace_info() is None: - return [self.op_yaml_item['name']] - else: - if self.op_yaml_item['name'][-1] == "_": - return [self.op_yaml_item['name']] + def parse_attribute_default_value_list(self): + default_value_list = [] + for attribute_info in self.op_yaml_item['attrs']: + if 'default_value' in attribute_info: + default_value = attribute_info['default_value'] + default_value_list.append( + self.get_phi_dtype_name(default_value) + ) else: - return [ - self.op_yaml_item['name'], - self.op_yaml_item['name'] + "_", - ] + default_value_list.append(None) + return default_value_list - def parse_op_inplace_info(self): - if 'inplace' in self.op_yaml_item: - return self.op_yaml_item['inplace'] - return None + def parse_infer_meta_map(self): + if 'infer_meta' in self.op_yaml_item: + return self.op_yaml_item['infer_meta'] + else: + return None + + def parse_kernel_map(self): + if 'kernel' in self.op_yaml_item: + return self.op_yaml_item['kernel'] + else: + return None + + def get_phi_dtype_name(self, name): + name = name.replace('Scalar', 'phi::Scalar') + name = name.replace('IntArray', 'phi::IntArray') + name = name.replace('DataLayout', 'phi::DataLayout') + name = name.replace('DataType', 'phi::DataType') + if name.startswith( + ( + "Place", + "CPUPlace", + "GPUPlace", + "GPUPinnedPlace", + "XPUPlace", + "IPUPlace", + "CustomPlace", + ) + ): + return "phi::" + name + return name def to_pascal_case(s): @@ -472,6 +578,280 @@ def to_pascal_case(s): # ===================================== # Generate Op Definition Files # ===================================== +def GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + for_func_define=True, +): + ''' + Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={} + ''' + build_args_str = "ir::Builder &builder, ir::OperationArgument &argument" + if len(op_input_name_list) > 0: + for input_name in op_input_name_list: + build_args_str += ", ir::OpResult " + input_name + "_" + for attr_idx in range(len(op_attribute_name_list)): + build_args_str += ( + ", " + + op_attribute_build_arg_type_list[attr_idx] + + " " + + op_attribute_name_list[attr_idx] + ) + if for_func_define: + if op_attribute_default_value_list[attr_idx] is not None: + default_value = op_attribute_default_value_list[attr_idx] + if op_attribute_build_arg_type_list[attr_idx] != "std::string": + if default_value[0] == "'" or default_value[0] == '"': + default_value = default_value[1:] + if default_value[-1] == "'" or default_value[-1] == '"': + default_value = default_value[0:-1] + build_args_str += "=" + default_value + return build_args_str + + +def GenBuildInputs(op_input_name_list): + BUILD_INPUT_TEMPLATE = """ std::vector argument_inputs = {{{inputs_args}}}; + argument.addOperands(argument_inputs.begin(), argument_inputs.end()); +""" + build_input_str = "" + if len(op_input_name_list) > 0: + inputs_args_str = "_, ".join(op_input_name_list) + "_" + build_input_str = BUILD_INPUT_TEMPLATE.format( + inputs_args=inputs_args_str + ) + return build_input_str + + +def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list): + INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr})); +""" + SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::Scalar({attr})); +""" + STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr}); +""" + ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector vec_{attr_name}; + for (size_t i = 0; i < static_cast({attr_size}); i++) {{ + {create_attribute} + vec_{attr_name}.push_back(attr_{attr_name}); + }} + ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name}); +""" + attr_str = "" + for idx in range(len(op_attribute_name_list)): + if "ir::ArrayAttribute<" in op_attribute_type_list[idx]: + inner_attribute_type = op_attribute_type_list[idx][19:-1] + if inner_attribute_type == "paddle::dialect::IntArrayAttribute": + attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + attr_size=op_attribute_name_list[idx] + ".size()", + create_attribute=INTARRAY_STR_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + op_attribute_type=inner_attribute_type, + attr=op_attribute_name_list[idx] + "[i]", + ), + ) + elif inner_attribute_type == "paddle::dialect::ScalarAttribute": + attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + attr_size=op_attribute_name_list[idx] + ".size()", + create_attribute=SCALAR_STR_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + op_attribute_type=inner_attribute_type, + attr=op_attribute_name_list[idx] + "[i]", + ), + ) + else: + attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + attr_size=op_attribute_name_list[idx] + ".size()", + create_attribute=STR_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + op_attribute_type=inner_attribute_type, + attr=op_attribute_name_list[idx] + "[i]", + ), + ) + elif ( + op_attribute_type_list[idx] == "paddle::dialect::IntArrayAttribute" + ): + attr_str += INTARRAY_STR_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + op_attribute_type=op_attribute_type_list[idx], + attr=op_attribute_name_list[idx], + ) + + elif op_attribute_type_list[idx] == "paddle::dialect::ScalarAttribute": + attr_str += SCALAR_STR_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + op_attribute_type=op_attribute_type_list[idx], + attr=op_attribute_name_list[idx], + ) + else: + attr_str += STR_TEMPLATE.format( + attr_name=op_attribute_name_list[idx], + op_attribute_type=op_attribute_type_list[idx], + attr=op_attribute_name_list[idx], + ) + attr_str += """ argument.addAttribute("{attr_name}", attr_{attr_name});\n""".format( + attr_name=op_attribute_name_list[idx] + ) + + return attr_str + + +def GenBuildOutputs( + op_input_name_list, + op_input_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, +): + build_output_str = "" + CREATE_INPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; + dense_{name}.set_meta( + phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()), + {name}.dims(), + {name}.data_layout(), + {name}.lod(), + {name}.offset()) + ); + phi::MetaTensor meta_{name}(&dense_{name}); +""" + CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}({name}.size(), phi::DenseTensor()); + std::vector vec_meta_{name}; + for (size_t i=0; i < static_cast({name}.size()); i++) {{ + vec_dense_{name}[i].set_meta( + phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast().dtype()), + {name}[i].dyn_cast().dims(), + {name}[i].dyn_cast().data_layout(), + {name}[i].dyn_cast().lod(), + {name}[i].dyn_cast().offset()) + ); + vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); + }} + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} + """ + CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; + phi::MetaTensor meta_{name}(&dense_{name}); +""" + CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_dense_{name}(({output_size}), phi::DenseTensor()); + std::vector vec_meta_{name}; + for (size_t i=0; i < static_cast({output_size}); i++) {{ + vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i])); + }} + std::vector meta_{name}; + for (size_t i=0; i < static_cast(vec_meta_{name}.size()); i++) {{ + meta_{name}.push_back(&vec_meta_{name}[i]); + }} +""" + # Prepar input type + for idx in range(len(op_input_name_list)): + # is a vector + if 'ir::VectorType' in op_input_type_list[idx]: + build_output_str += " ir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + name=op_input_name_list[idx] + ) + # is a Tensor + else: + build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( + name=op_input_name_list[idx] + ) + + # Prepare inputs for infer meta + infer_meta_args = [] + for idx in range(len(op_infer_meta_map['param'])): + # is input + if op_infer_meta_map['param'][idx] in op_input_name_list: + if ( + "meta_" + op_infer_meta_map['param'][idx] + ) not in infer_meta_args: + # is a vector + if ( + 'ir::VectorType' + in op_input_type_list[ + op_input_name_list.index( + op_infer_meta_map['param'][idx] + ) + ] + ): + build_output_str += ( + CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + ) + # is a Tensor + else: + build_output_str += CREATE_INPUT_METATENSOR_TEMPLATE.format( + name=op_infer_meta_map['param'][idx] + ) + + infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx]) + # is attribute + else: + infer_meta_args.append(op_infer_meta_map['param'][idx]) + + # Prepare outputs for infer meta + for idx in range(len(op_output_name_list)): + # is a vector + if 'ir::VectorType' in op_output_type_list[idx]: + build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format( + name=op_output_name_list[idx], + output_size=op_output_size_list[idx], + ) + infer_meta_args.append(f"meta_{op_output_name_list[idx]}") + # is a Tensor + else: + build_output_str += CREATE_OUTPUT_METATENSOR_TEMPLATE.format( + name=op_output_name_list[idx] + ) + infer_meta_args.append(f"&meta_{op_output_name_list[idx]}") + + # Execute infer meta function + CREATE_INFER_META_FUNC_TEMPLATE = """ + phi::{func}({args}); +""" + build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format( + func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) + ) + + # use dense_{name} or vec_dense_{name} to create Outputs type + build_output_str += "\n std::vector argument_outputs;" + + CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ + ir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); + argument_outputs.push_back({name}_dense_tensor_type); +""" + CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ + std::vector {name}_types; + for (size_t i=0; i < static_cast({output_size}); i++) {{ + {name}_types.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); + }} + ir::Type {name}_vector_type = ir::VectorType::get(ir::IrContext::Instance(), {name}_types); + argument_outputs.push_back({name}_vector_type); +""" + for idx in range(len(op_output_name_list)): + # is a vector + if 'ir::VectorType' in op_output_type_list[idx]: + build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format( + name=op_output_name_list[idx], + output_size=op_output_size_list[idx], + ) + # is a Tensor + else: + build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( + name=op_output_name_list[idx] + ) + + build_output_str += " argument.addTypes(argument_outputs.begin(), argument_outputs.end());\n" + + return build_output_str + + def OpGenerator( op_yaml_files, op_compat_yaml_file, @@ -512,11 +892,16 @@ def OpGenerator( op_input_no_need_buffer_list = op_info.input_no_need_buffer_list op_output_name_list = op_info.output_name_list op_output_type_list = op_info.output_type_list + op_output_size_list = op_info.output_size_list op_output_optional_list = op_info.output_optional_list op_output_intermediate_list = op_info.output_intermediate_list op_attribute_name_list = op_info.attribute_name_list op_attribute_type_list = op_info.attribute_type_list op_attribute_data_type_list = op_info.attribute_data_type_list + op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list + op_attribute_default_value_list = op_info.attribute_default_value_list + op_infer_meta_map = op_info.infer_meta_map + op_kernel_map = op_info.kernel_map op_interfaces = ["GetOpInfoInterface"] op_traits = [] @@ -552,6 +937,53 @@ def OpGenerator( output_index=idx, ) + # gen build str + build_define_input_args_str = "" + build_declare_input_args_str = "" + build_func_declare_str = "" + if op_infer_meta_map is not None: + build_define_input_args_str = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + True, + ) + build_declare_input_args_str = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + False, + ) + build_inputs_str = GenBuildInputs(op_input_name_list) + build_attributes_str = GenBuildAttributes( + op_attribute_name_list, op_attribute_type_list + ) + build_outputs_str = GenBuildOutputs( + op_input_name_list, + op_input_type_list, + op_output_name_list, + op_output_type_list, + op_output_size_list, + op_infer_meta_map, + ) + build_func_declare_str = OP_BUILD_TEMPLATE.format( + op_name=op_class_name, + build_args=build_declare_input_args_str, + build_inputs=build_inputs_str, + build_attributes=build_attributes_str, + build_outputs=build_outputs_str, + ) + else: + build_func_declare_str = OP_BUILD_TEMPLATE.format( + op_name=op_class_name, + build_args=build_declare_input_args_str, + build_inputs="", + build_attributes="", + build_outputs="", + ) + # gen op_declare_str/op_defined_str if len(op_attribute_name_list) == 0: op_declare_str = OP_DECLARE_TEMPLATE.format( @@ -561,6 +993,7 @@ def OpGenerator( traits=op_traits_str, attribute_declare=op_0_attribute_declare_str, attribute_num=0, + build_args=build_define_input_args_str, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) @@ -575,6 +1008,7 @@ def OpGenerator( attribute_num=len(op_attribute_name_list) ), attribute_num=len(op_attribute_name_list), + build_args=build_define_input_args_str, get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, ) @@ -631,11 +1065,27 @@ def OpGenerator( ) attribute_info_str = ", ".join(attribute_info_list) + # generate runtiem info + infer_meta_func_str = "" + infer_meta_param_str = "" + if op_infer_meta_map is not None: + infer_meta_func_str = op_infer_meta_map['func'] + infer_meta_param_str = '", "'.join(op_infer_meta_map['param']) + kernel_func_str = "" + kernel_param_str = "" + if op_kernel_map is not None: + kernel_func_str = '", "'.join(op_kernel_map['func']) + kernel_param_str = '", "'.join(op_kernel_map['param']) + op_info_func_str = OP_INFO_TEMPLATE.format( op_name=op_class_name, inputs=inputs_info_str, attributes=attribute_info_str, outputs=outputs_info_str, + infer_meta_func=infer_meta_func_str, + infer_meta_param=infer_meta_param_str, + kernel_func=kernel_func_str, + kernel_param=kernel_param_str, ) # generate op verify function: inputs_type_check_str @@ -736,14 +1186,19 @@ def OpGenerator( ) # generate op verify function - op_verify_str = OP_VERIFY_TEMPLATE.format( - op_name=op_class_name, - inputs_size=len(op_input_type_list), - outputs_size=len(op_output_type_list), - inputs_type_check=inputs_type_check_str, - outputs_type_check=outputs_type_check_str, - attributes_check=attributes_check_str, - ) + if "GradOp" in op_class_name or "Grad_Op" in op_class_name: + op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + ) + else: + op_verify_str = OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + inputs_size=len(op_input_type_list), + outputs_size=len(op_output_type_list), + inputs_type_check=inputs_type_check_str, + outputs_type_check=outputs_type_check_str, + attributes_check=attributes_check_str, + ) op_infer_shape_str = "" if op_info.infer_shape_func: @@ -756,6 +1211,7 @@ def OpGenerator( ops_declare_list.append(op_declare_str) ops_defined_list.append(op_defined_str) ops_defined_list.append(op_info_func_str) + ops_defined_list.append(build_func_declare_str) ops_defined_list.append(op_verify_str) ops_defined_list.append(op_infer_shape_str) @@ -786,7 +1242,7 @@ def OpGenerator( namespace=name, input=source_file_str ) # Add namespaces source_file_str = CC_FILE_TEMPLATE.format( - h_file=op_def_h_file, input=source_file_str + h_file=op_def_h_file[:-4], input=source_file_str ) # Add head # (5) Generate pd_op.h.tmp, pd_op.cc.tmp @@ -817,6 +1273,7 @@ def ParseArguments(): # ===================================== if __name__ == "__main__": # parse arguments + print("auto gen op") args = ParseArguments() op_yaml_files = args.op_yaml_files.split(",") op_compat_yaml_file = args.op_compat_yaml_file diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index e9802e790e6..b9772a4dcd1 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -35,13 +35,13 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) { std::make_shared(); phi::DenseTensor *tensor = var->GetMutable(); // Init DenseTensor - auto dim = parameter->type().dyn_cast().dim(); + auto dim = parameter->type().dyn_cast().dims(); phi::DenseTensorMeta meta( TransToPhiDataType( parameter->type().dyn_cast().dtype()), - phi::DDim(dim.data(), dim.size()), - TransToPhiDataLayout( - parameter->type().dyn_cast().data_layout()), + dim, + + parameter->type().dyn_cast().data_layout(), parameter->type().dyn_cast().lod(), parameter->type().dyn_cast().offset()); tensor->set_meta(meta); @@ -67,17 +67,13 @@ std::unique_ptr ParameterConvertInterface::VariableToParameter( // Get Meta ir::IrContext *ctx = ir::IrContext::Instance(); ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx); - DenseTensorTypeStorage::Dim dims(tensor->dims().size()); - std::copy(tensor->dims().Get(), - tensor->dims().Get() + tensor->dims().size(), - dims.data()); - DenseTensorTypeStorage::DataLayout data_layout = - TransToIrDataLayout(tensor->layout()); - DenseTensorTypeStorage::LoD lod = tensor->lod(); - size_t offset = tensor->meta().offset; void *data = tensor->data(); - ir::Type dense_tensor_type = - DenseTensorType::get(ctx, data_type, dims, data_layout, lod, offset); + ir::Type dense_tensor_type = DenseTensorType::get(ctx, + data_type, + tensor->dims(), + tensor->layout(), + tensor->lod(), + tensor->meta().offset); return std::make_unique( data, tensor->numel() * phi::SizeOf(tensor->dtype()), @@ -116,8 +112,7 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { DenseTensorType tensor_type = type.dyn_cast(); os << "tensor<"; - auto &dims = tensor_type.dim(); - for (auto d : dims) { + for (auto d : phi::vectorize(tensor_type.dims())) { os << d; os << "x"; } diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h index 45c32bc8370..f49ba35fcc3 100644 --- a/paddle/fluid/dialect/pd_interface.h +++ b/paddle/fluid/dialect/pd_interface.h @@ -19,25 +19,22 @@ using OpInfoTuple = std::tuple, std::vector, - std::vector>; + std::vector, + paddle::dialect::OpRunTimeInfo>; namespace paddle { namespace dialect { class GetOpInfoInterface : public ir::OpInterfaceBase { public: struct Concept { - explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *)) + explicit Concept(OpInfoTuple (*get_op_info)()) : get_op_info_(get_op_info) {} - OpInfoTuple (*get_op_info_)(ir::Operation *); + OpInfoTuple (*get_op_info_)(); }; template struct Model : public Concept { - static OpInfoTuple GetOpInfo(ir::Operation *op) { - ConcreteOp concret_op = op->dyn_cast(); - if (concret_op == nullptr) throw("concret_op is nullptr"); - return concret_op.GetOpInfo(); - } + static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); } Model() : Concept(GetOpInfo) {} }; @@ -45,7 +42,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase { GetOpInfoInterface(ir::Operation *op, Concept *impl) : ir::OpInterfaceBase(op), impl_(impl) {} - OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); } + OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); } private: Concept *impl_; diff --git a/paddle/fluid/dialect/pd_op.yaml b/paddle/fluid/dialect/pd_op.yaml index 7ca8646a038..4dcc6cff8fc 100644 --- a/paddle/fluid/dialect/pd_op.yaml +++ b/paddle/fluid/dialect/pd_op.yaml @@ -11,17 +11,6 @@ - {typename: Tensor, name: out, optional: false, intermediate: false} no_need_buffer: null data_transform: null - infer_meta: - func: null - param: null - kernel: - func: null - param: null - backend: null - layout: null - data_type: null - dispatch: null - force_backend: null inplace: null backward: null - name: fetch @@ -37,16 +26,5 @@ - {typename: 'Tensor[]', name: out, optional: false, intermediate: false} no_need_buffer: null data_transform: null - infer_meta: - func: null - param: null - kernel: - func: null - param: null - backend: null - layout: null - data_type: null - dispatch: null - force_backend: null inplace: null backward: null diff --git a/paddle/fluid/dialect/pd_type.cc b/paddle/fluid/dialect/pd_type.cc index 017af3ffd5a..1c3de528bcf 100644 --- a/paddle/fluid/dialect/pd_type.cc +++ b/paddle/fluid/dialect/pd_type.cc @@ -18,20 +18,13 @@ namespace paddle { namespace dialect { const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; } -const paddle::dialect::DenseTensorTypeStorage::Dim& DenseTensorType::dim() - const { - return storage()->dims_; -} +const phi::DDim& DenseTensorType::dims() const { return storage()->dims_; } -const paddle::dialect::DenseTensorTypeStorage::DataLayout& -DenseTensorType::data_layout() const { +const phi::DataLayout& DenseTensorType::data_layout() const { return storage()->layout_; } -const paddle::dialect::DenseTensorTypeStorage::LoD& DenseTensorType::lod() - const { - return storage()->lod_; -} +const phi::LoD& DenseTensorType::lod() const { return storage()->lod_; } const size_t& DenseTensorType::offset() const { return storage()->offset_; } diff --git a/paddle/fluid/dialect/pd_type.h b/paddle/fluid/dialect/pd_type.h index 8e9f1e6c54c..0e23916234e 100644 --- a/paddle/fluid/dialect/pd_type.h +++ b/paddle/fluid/dialect/pd_type.h @@ -30,12 +30,11 @@ class DenseTensorType : public ir::Type { const ir::Type &dtype() const; - const paddle::dialect::DenseTensorTypeStorage::Dim &dim() const; + const phi::DDim &dims() const; - const paddle::dialect::DenseTensorTypeStorage::DataLayout &data_layout() - const; + const phi::DataLayout &data_layout() const; - const paddle::dialect::DenseTensorTypeStorage::LoD &lod() const; + const phi::LoD &lod() const; const size_t &offset() const; }; diff --git a/paddle/fluid/dialect/pd_type_storage.h b/paddle/fluid/dialect/pd_type_storage.h index 3104edb80f8..dbdb3b374e4 100644 --- a/paddle/fluid/dialect/pd_type_storage.h +++ b/paddle/fluid/dialect/pd_type_storage.h @@ -18,6 +18,7 @@ #include "paddle/ir/core/type.h" #include "paddle/ir/core/utils.h" +#include "paddle/phi/core/tensor_meta.h" namespace std { /// @@ -46,46 +47,20 @@ namespace dialect { /// (3)define HashValue method, (4)overload operator==. /// struct DenseTensorTypeStorage : public ir::TypeStorage { - /// - /// \brief It is consistent with the DataLayout defined by Phi operator - /// library. See the file for details: paddle/phi/common/layout.h. - /// - enum class DataLayout : unsigned int { - UNDEFINED = 0, - NHWC, - NCHW, - NCDHW, - NDHWC, - ONEDNN, - SPARSE_COO, - SPARSE_CSR, - PSTRING_UNION, - - NUM_DATA_LAYOUTS, - - // See Note [ Why we need ALL in basic kernel key member? ] - ALL_LAYOUT = UNDEFINED, - - // Note: Unify phi DataLayout and fluid::framework::DataLayout, - // for compatible with fluid DataLayout, here need prefix `k` - kNHWC = NHWC, - kNCHW = NCHW, - kMKLDNN = ONEDNN, // all layouts supported by ONEDNN internally - kNDHWC = NDHWC, - kNCDHW = NCDHW, - }; - - using Dim = std::vector; - + using DataLayout = phi::DataLayout; + using Dim = phi::DDim; using LoD = std::vector>; - /// /// \brief Declare ParamKey according to parameter type. /// - using ParamKey = std::tuple; - - DenseTensorTypeStorage( - ir::Type dtype, Dim dims, DataLayout layout, LoD lod, size_t offset) + using ParamKey = + std::tuple; + + DenseTensorTypeStorage(ir::Type dtype, + phi::DDim dims, + phi::DataLayout layout, + phi::LoD lod, + size_t offset) : dtype_(dtype), dims_(dims), layout_(layout), @@ -114,16 +89,16 @@ struct DenseTensorTypeStorage : public ir::TypeStorage { ir::hash_combine(hash_value, std::hash()(std::get<0>(key))); // hash dims hash_value = - ir::hash_combine(hash_value, std::hash()(std::get<1>(key))); + ir::hash_combine(hash_value, std::hash()(std::get<1>(key))); // hash layout hash_value = ir::hash_combine( hash_value, - std::hash::type>()( - static_cast::type>( + std::hash::type>()( + static_cast::type>( std::get<2>(key)))); // hash lod hash_value = - ir::hash_combine(hash_value, std::hash()(std::get<3>(key))); + ir::hash_combine(hash_value, std::hash()(std::get<3>(key))); // hash offset hash_value = ir::hash_combine(hash_value, std::hash()(std::get<4>(key))); @@ -146,9 +121,9 @@ struct DenseTensorTypeStorage : public ir::TypeStorage { /// layout, lod, offset. /// ir::Type dtype_; - Dim dims_; - DataLayout layout_; - LoD lod_; + phi::DDim dims_; + phi::DataLayout layout_; + phi::LoD lod_; size_t offset_; }; diff --git a/paddle/fluid/dialect/utils.h b/paddle/fluid/dialect/utils.h index bfc28f06267..56f44d87271 100644 --- a/paddle/fluid/dialect/utils.h +++ b/paddle/fluid/dialect/utils.h @@ -70,67 +70,76 @@ inline ir::Type TransToIrDataType(phi::DataType dtype, } } -inline phi::DataLayout TransToPhiDataLayout( - DenseTensorTypeStorage::DataLayout data_layout) { - switch (data_layout) { - case DenseTensorTypeStorage::DataLayout::NHWC: - return phi::DataLayout::NHWC; - case DenseTensorTypeStorage::DataLayout::NCHW: - return phi::DataLayout::NCHW; - case DenseTensorTypeStorage::DataLayout::NCDHW: - return phi::DataLayout::NCDHW; - case DenseTensorTypeStorage::DataLayout::NDHWC: - return phi::DataLayout::NDHWC; - case DenseTensorTypeStorage::DataLayout::ONEDNN: - return phi::DataLayout::ONEDNN; - case DenseTensorTypeStorage::DataLayout::SPARSE_COO: - return phi::DataLayout::SPARSE_COO; - case DenseTensorTypeStorage::DataLayout::SPARSE_CSR: - return phi::DataLayout::SPARSE_CSR; - case DenseTensorTypeStorage::DataLayout::PSTRING_UNION: - return phi::DataLayout::PSTRING_UNION; - case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS: - return phi::DataLayout::NUM_DATA_LAYOUTS; - case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT: - return phi::DataLayout::ALL_LAYOUT; - default: - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported ir data layout `%s` when casting it into " - "phi data type.", - static_cast(data_layout))); - } -} +// inline phi::DataLayout TransToPhiDataLayout( +// DenseTensorTypeStorage::DataLayout data_layout) { +// switch (data_layout) { +// case DenseTensorTypeStorage::DataLayout::NHWC: +// return phi::DataLayout::NHWC; +// case DenseTensorTypeStorage::DataLayout::NCHW: +// return phi::DataLayout::NCHW; +// case DenseTensorTypeStorage::DataLayout::NCDHW: +// return phi::DataLayout::NCDHW; +// case DenseTensorTypeStorage::DataLayout::NDHWC: +// return phi::DataLayout::NDHWC; +// case DenseTensorTypeStorage::DataLayout::ONEDNN: +// return phi::DataLayout::ONEDNN; +// case DenseTensorTypeStorage::DataLayout::SPARSE_COO: +// return phi::DataLayout::SPARSE_COO; +// case DenseTensorTypeStorage::DataLayout::SPARSE_CSR: +// return phi::DataLayout::SPARSE_CSR; +// case DenseTensorTypeStorage::DataLayout::PSTRING_UNION: +// return phi::DataLayout::PSTRING_UNION; +// case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS: +// return phi::DataLayout::NUM_DATA_LAYOUTS; +// case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT: +// return phi::DataLayout::ALL_LAYOUT; +// default: +// PADDLE_THROW(phi::errors::Unimplemented( +// "Unsupported ir data layout `%s` when casting it into " +// "phi data type.", +// static_cast(data_layout))); +// } +// } -inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout( - phi::DataLayout data_layout) { - switch (data_layout) { - case phi::DataLayout::NHWC: - return DenseTensorTypeStorage::DataLayout::NHWC; - case phi::DataLayout::NCHW: - return DenseTensorTypeStorage::DataLayout::NCHW; - case phi::DataLayout::NCDHW: - return DenseTensorTypeStorage::DataLayout::NCDHW; - case phi::DataLayout::NDHWC: - return DenseTensorTypeStorage::DataLayout::NDHWC; - case phi::DataLayout::ONEDNN: - return DenseTensorTypeStorage::DataLayout::ONEDNN; - case phi::DataLayout::SPARSE_COO: - return DenseTensorTypeStorage::DataLayout::SPARSE_COO; - case phi::DataLayout::SPARSE_CSR: - return DenseTensorTypeStorage::DataLayout::SPARSE_CSR; - case phi::DataLayout::PSTRING_UNION: - return DenseTensorTypeStorage::DataLayout::PSTRING_UNION; - case phi::DataLayout::NUM_DATA_LAYOUTS: - return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS; - case phi::DataLayout::ALL_LAYOUT: - return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT; - default: - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported phi data layout `%s` when casting it into " - "ir data type.", - static_cast(data_layout))); - } -} +// inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout( +// phi::DataLayout data_layout) { +// switch (data_layout) { +// case phi::DataLayout::NHWC: +// return DenseTensorTypeStorage::DataLayout::NHWC; +// case phi::DataLayout::NCHW: +// return DenseTensorTypeStorage::DataLayout::NCHW; +// case phi::DataLayout::NCDHW: +// return DenseTensorTypeStorage::DataLayout::NCDHW; +// case phi::DataLayout::NDHWC: +// return DenseTensorTypeStorage::DataLayout::NDHWC; +// case phi::DataLayout::ONEDNN: +// return DenseTensorTypeStorage::DataLayout::ONEDNN; +// case phi::DataLayout::SPARSE_COO: +// return DenseTensorTypeStorage::DataLayout::SPARSE_COO; +// case phi::DataLayout::SPARSE_CSR: +// return DenseTensorTypeStorage::DataLayout::SPARSE_CSR; +// case phi::DataLayout::PSTRING_UNION: +// return DenseTensorTypeStorage::DataLayout::PSTRING_UNION; +// case phi::DataLayout::NUM_DATA_LAYOUTS: +// return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS; +// case phi::DataLayout::ALL_LAYOUT: +// return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT; +// default: +// PADDLE_THROW(phi::errors::Unimplemented( +// "Unsupported phi data layout `%s` when casting it into " +// "ir data type.", +// static_cast(data_layout))); +// } +// } + +// inline phi::DenseTensorMeta TransToDenseTensorMeta( +// paddle::dialect::DenseTensorType type) { +// return phi::DenseTensorMeta(TransToPhiDataType(type.dtype()), +// type.dim(), +// type.data_layout(), +// type.lod(), +// type.offset()); +// } struct OpInputInfo { std::string name; @@ -172,5 +181,20 @@ struct OpAttributeInfo { : name(name), type_name(type_name), data_type(data_type) {} }; +struct OpRunTimeInfo { + std::string infer_meta_func; + std::vector infer_meta_param; + std::vector kernel_func; + std::vector kernel_param; + OpRunTimeInfo(std::string infer_meta_func, + std::vector infer_meta_param, + std::vector kernel_func, + std::vector kernel_param) + : infer_meta_func(infer_meta_func), + infer_meta_param(infer_meta_param), + kernel_func(kernel_func), + kernel_param(kernel_param) {} +}; + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/translator/type_translator.cc b/paddle/fluid/translator/type_translator.cc index c677ec8682d..1681728a5c9 100644 --- a/paddle/fluid/translator/type_translator.cc +++ b/paddle/fluid/translator/type_translator.cc @@ -50,7 +50,7 @@ TypeTranslator::TypeTranslator() { ir::Type dtype = this->operator[](var_desc.GetDataType())(ctx, var_desc); - DenseTensorTypeStorage::Dim dim = var_desc.GetShape(); + DenseTensorTypeStorage::Dim dim = phi::make_ddim(var_desc.GetShape()); DenseTensorTypeStorage::DataLayout layout = DenseTensorTypeStorage::DataLayout::UNDEFINED; DenseTensorTypeStorage::LoD lod = {}; diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/ir/core/builtin_attribute_storage.h index 101c4781beb..3d2e23bc047 100644 --- a/paddle/ir/core/builtin_attribute_storage.h +++ b/paddle/ir/core/builtin_attribute_storage.h @@ -25,7 +25,7 @@ namespace ir { #define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \ struct concrete_storage : public ir::AttributeStorage { \ - using ParamKey = bool; \ + using ParamKey = base_type; \ \ explicit concrete_storage(const ParamKey &key) { data_ = key; } \ \ diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 306f7a77a05..bf9bc57c691 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -221,7 +221,7 @@ - backward_op : broadcast_tensors_grad forward : broadcast_tensors (Tensor[] input) -> Tensor[](out) args : (Tensor[] input, Tensor[] out_grad) - output : Tensor[](input_grad) + output : Tensor[](input_grad){input.size()} infer_meta : func : UnchangedMultiInferMeta param : [input] diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 58ecf0604ad..b312fa2658e 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -235,7 +235,7 @@ - backward_op : einsum_grad forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape) args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation) - output : Tensor[](x_grad){x.size()} + output : Tensor[](x_grad){x_shape.size()} infer_meta : func : UnchangedMultiInferMeta param : [x_shape] diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 8b78b739a65..3150519a315 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -107,10 +107,9 @@ TEST(program_test, program) { a_interface->ParameterToVariable(program.GetParameter("a")); const phi::DenseTensor &a_tensor = a_var->Get(); EXPECT_EQ(a_tensor.numel(), 4); - EXPECT_EQ(a_tensor.dims(), phi::DDim(dims.data(), dims.size())); + EXPECT_EQ(a_tensor.dims(), dims); EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype)); - EXPECT_EQ(a_tensor.layout(), - paddle::dialect::TransToPhiDataLayout(data_layout)); + EXPECT_EQ(a_tensor.layout(), data_layout); EXPECT_EQ(a_tensor.lod(), lod); EXPECT_EQ(a_tensor.offset(), offset); for (int64_t i = 0; i < a_tensor.numel(); i++) { @@ -137,10 +136,9 @@ TEST(program_test, program) { b_interface->ParameterToVariable(program.GetParameter("b")); const phi::DenseTensor &b_tensor = b_var->Get(); EXPECT_EQ(b_tensor.numel(), 4); - EXPECT_EQ(b_tensor.dims(), phi::DDim(dims.data(), dims.size())); + EXPECT_EQ(b_tensor.dims(), dims); EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype)); - EXPECT_EQ(b_tensor.layout(), - paddle::dialect::TransToPhiDataLayout(data_layout)); + EXPECT_EQ(b_tensor.layout(), data_layout); EXPECT_EQ(b_tensor.lod(), lod); EXPECT_EQ(b_tensor.offset(), offset); for (int64_t i = 0; i < b_tensor.numel(); i++) { -- GitLab