From bf669a5c710709db0a2f22e7b4982243052d12c6 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 6 Jul 2023 17:32:43 +0800 Subject: [PATCH] fix unique and close op bug (#55168) --- .../fluid/ir/dialect/op_generator/op_gen.py | 41 ++++++++----------- paddle/phi/api/yaml/legacy_ops.yaml | 1 + 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 4b25921c8bb..36f4c1550b7 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -336,30 +336,26 @@ class OpInfoParser: for scalar_attr in self.op_compat_item['scalar'].keys(): if 'data_type' in self.op_compat_item['scalar'][scalar_attr]: if ( - self.op_compat_item['scalar'][scalar_attr]['data_type'] - == "std::string" + scalar_attr == "depth" + and self.op_phi_name[0] == "one_hot" ): - # see isclose and allclose in op_compat.yaml - mutable_attribute_name_list.append(scalar_attr) - mutable_attribute_type_list.append( - ["ir::StrAttribute", "std::string"] - ) + mutable_attribute_name_list.append("num_classes") else: - if ( - scalar_attr == "depth" - and self.op_phi_name[0] == "one_hot" - ): - mutable_attribute_name_list.append("num_classes") - else: - mutable_attribute_name_list.append(scalar_attr) - mutable_attribute_type_list.append( - [ - "paddle::dialect::ScalarAttribute", - self.op_compat_item['scalar'][scalar_attr][ - 'data_type' - ], - ] - ) + mutable_attribute_name_list.append(scalar_attr) + data_type = self.op_compat_item['scalar'][scalar_attr][ + 'data_type' + ] + # patch for isclose and allclose + if (self.op_compat_item['op'] == "isclose") or ( + self.op_compat_item['op'] == "allclose" + ): + data_type = "float" + mutable_attribute_type_list.append( + [ + "paddle::dialect::ScalarAttribute", + data_type, + ] + ) # See eye in op_compat.yaml else: mutable_attribute_name_list.append(scalar_attr) @@ -371,7 +367,6 @@ class OpInfoParser: ], ] ) - # int_array if (self.op_compat_item is not None) and ( 'int_array' in self.op_compat_item diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 0468d5d8e21..a4b0fa02d69 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -979,6 +979,7 @@ kernel : func : unique data_type : x + optional : indices, inverse, counts - op : unpool args: (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -- GitLab