未验证 提交 5ee4a21a 编写于 作者: C cifar10 提交者: GitHub

fix arg_max to select first index (#44521)

上级 a2b39320
......@@ -72,7 +72,7 @@ class ArgMaxMLUKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc input_desc(
flatten_x, CNNL_LAYOUT_ARRAY, ToCnnlDataType(flatten_x.dtype()));
MLUCnnlReduceDesc reduction_desc(reduce_dims,
CNNL_REDUCE_MAX_LAST_INDEX,
CNNL_REDUCE_MAX,
ToCnnlDataType<T>(),
CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_ONLY_INDICES,
......
......@@ -52,6 +52,38 @@ class BaseTestCase(OpTest):
self.check_output_with_place(self.place)
class TestArgMaxSameValue1(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dtype = 'float32'
self.axis = 0
def setUp(self):
self.set_mlu()
self.initTestCase()
self.x = np.array([1, 2, 3, 5, 4, 5]).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
class TestArgMaxSameValue2(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dtype = 'float16'
self.axis = 0
def setUp(self):
self.set_mlu()
self.initTestCase()
self.x = np.array([[2, 3, 5, 5], [3, 2, 5, 5]]).astype(self.dtype)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': np.argmax(self.x, axis=self.axis)}
# test argmax, dtype: float16
class TestArgMaxFloat16Case1(BaseTestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册