未验证 提交 55418d3f 编写于 作者: T TTerror 提交者: GitHub

fix test_refactor_op_xpu, *test=kunlun (#39168)

上级 232bbce2
......@@ -52,8 +52,9 @@ class XPUTestArgsortOp1(XPUOpTestWrapper):
classes = []
for descending in [True, False]:
for axis in [0, 1, 2, -1, -2]:
class_name = 'XPUTestArgsortOp_axis_' + str(axis)
attr_dict = {'init_axis': axis, 'descending': descending}
class_name = 'XPUTestArgsortOp_axis_' + str(axis) + '_' + str(
descending)
attr_dict = {'init_axis': axis, 'init_descending': descending}
classes.append([class_name, attr_dict])
return base_class, classes
......@@ -64,8 +65,9 @@ class XPUTestArgsortOp1(XPUOpTestWrapper):
self.place = paddle.XPUPlace(0)
self.dtype = self.in_type
self.input_shape = (2, 2, 2, 3, 3)
self.axis = -1
self.descending = False
self.axis = -1 if not hasattr(self, 'init_axis') else self.init_axis
self.descending = False if not hasattr(
self, 'init_descending') else self.init_descending
if self.in_type == 'float32':
self.x = np.random.random(self.input_shape).astype(self.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册