未验证 提交 bf669a5c 编写于 作者: Z zhangbo9674 提交者: GitHub

fix unique and close op bug (#55168)

上级 b5645956
...@@ -335,16 +335,6 @@ class OpInfoParser: ...@@ -335,16 +335,6 @@ class OpInfoParser:
): ):
for scalar_attr in self.op_compat_item['scalar'].keys(): for scalar_attr in self.op_compat_item['scalar'].keys():
if 'data_type' in self.op_compat_item['scalar'][scalar_attr]: if 'data_type' in self.op_compat_item['scalar'][scalar_attr]:
if (
self.op_compat_item['scalar'][scalar_attr]['data_type']
== "std::string"
):
# see isclose and allclose in op_compat.yaml
mutable_attribute_name_list.append(scalar_attr)
mutable_attribute_type_list.append(
["ir::StrAttribute", "std::string"]
)
else:
if ( if (
scalar_attr == "depth" scalar_attr == "depth"
and self.op_phi_name[0] == "one_hot" and self.op_phi_name[0] == "one_hot"
...@@ -352,12 +342,18 @@ class OpInfoParser: ...@@ -352,12 +342,18 @@ class OpInfoParser:
mutable_attribute_name_list.append("num_classes") mutable_attribute_name_list.append("num_classes")
else: else:
mutable_attribute_name_list.append(scalar_attr) mutable_attribute_name_list.append(scalar_attr)
data_type = self.op_compat_item['scalar'][scalar_attr][
'data_type'
]
# patch for isclose and allclose
if (self.op_compat_item['op'] == "isclose") or (
self.op_compat_item['op'] == "allclose"
):
data_type = "float"
mutable_attribute_type_list.append( mutable_attribute_type_list.append(
[ [
"paddle::dialect::ScalarAttribute", "paddle::dialect::ScalarAttribute",
self.op_compat_item['scalar'][scalar_attr][ data_type,
'data_type'
],
] ]
) )
# See eye in op_compat.yaml # See eye in op_compat.yaml
...@@ -371,7 +367,6 @@ class OpInfoParser: ...@@ -371,7 +367,6 @@ class OpInfoParser:
], ],
] ]
) )
# int_array # int_array
if (self.op_compat_item is not None) and ( if (self.op_compat_item is not None) and (
'int_array' in self.op_compat_item 'int_array' in self.op_compat_item
......
...@@ -979,6 +979,7 @@ ...@@ -979,6 +979,7 @@
kernel : kernel :
func : unique func : unique
data_type : x data_type : x
optional : indices, inverse, counts
- op : unpool - op : unpool
args: (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) args: (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册