diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 4b25921c8bb3cff3f7fc5ea9373b4c6becdb397f..36f4c1550b71e73a88e8acc0f8653981a40d2a81 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -336,30 +336,26 @@ class OpInfoParser: for scalar_attr in self.op_compat_item['scalar'].keys(): if 'data_type' in self.op_compat_item['scalar'][scalar_attr]: if ( - self.op_compat_item['scalar'][scalar_attr]['data_type'] - == "std::string" + scalar_attr == "depth" + and self.op_phi_name[0] == "one_hot" ): - # see isclose and allclose in op_compat.yaml - mutable_attribute_name_list.append(scalar_attr) - mutable_attribute_type_list.append( - ["ir::StrAttribute", "std::string"] - ) + mutable_attribute_name_list.append("num_classes") else: - if ( - scalar_attr == "depth" - and self.op_phi_name[0] == "one_hot" - ): - mutable_attribute_name_list.append("num_classes") - else: - mutable_attribute_name_list.append(scalar_attr) - mutable_attribute_type_list.append( - [ - "paddle::dialect::ScalarAttribute", - self.op_compat_item['scalar'][scalar_attr][ - 'data_type' - ], - ] - ) + mutable_attribute_name_list.append(scalar_attr) + data_type = self.op_compat_item['scalar'][scalar_attr][ + 'data_type' + ] + # patch for isclose and allclose + if (self.op_compat_item['op'] == "isclose") or ( + self.op_compat_item['op'] == "allclose" + ): + data_type = "float" + mutable_attribute_type_list.append( + [ + "paddle::dialect::ScalarAttribute", + data_type, + ] + ) # See eye in op_compat.yaml else: mutable_attribute_name_list.append(scalar_attr) @@ -371,7 +367,6 @@ class OpInfoParser: ], ] ) - # int_array if (self.op_compat_item is not None) and ( 'int_array' in self.op_compat_item diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 0468d5d8e21990e6c5c3db57f8e8d5465c5d9c90..a4b0fa02d6920eb03dda1d24b914d7c06ea58329 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -979,6 +979,7 @@ kernel : func : unique data_type : x + optional : indices, inverse, counts - op : unpool args: (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)