From 5ee4a21a10cc2b22121239992b562f8889d0b56d Mon Sep 17 00:00:00 2001 From: cifar10 <41565156+cifar10@users.noreply.github.com> Date: Fri, 22 Jul 2022 10:29:28 +0800 Subject: [PATCH] fix arg_max to select first index (#44521) --- paddle/fluid/operators/arg_max_op_mlu.cc | 2 +- .../unittests/mlu/test_arg_max_op_mlu.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/arg_max_op_mlu.cc b/paddle/fluid/operators/arg_max_op_mlu.cc index ca8966a795f..44f74f016c0 100644 --- a/paddle/fluid/operators/arg_max_op_mlu.cc +++ b/paddle/fluid/operators/arg_max_op_mlu.cc @@ -72,7 +72,7 @@ class ArgMaxMLUKernel : public framework::OpKernel { MLUCnnlTensorDesc input_desc( flatten_x, CNNL_LAYOUT_ARRAY, ToCnnlDataType(flatten_x.dtype())); MLUCnnlReduceDesc reduction_desc(reduce_dims, - CNNL_REDUCE_MAX_LAST_INDEX, + CNNL_REDUCE_MAX, ToCnnlDataType(), CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_ONLY_INDICES, diff --git a/python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py index bd943e05b2d..45a79b5ece5 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py @@ -52,6 +52,38 @@ class BaseTestCase(OpTest): self.check_output_with_place(self.place) +class TestArgMaxSameValue1(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dtype = 'float32' + self.axis = 0 + + def setUp(self): + self.set_mlu() + self.initTestCase() + self.x = np.array([1, 2, 3, 5, 4, 5]).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + self.outputs = {'Out': np.argmax(self.x, axis=self.axis)} + + +class TestArgMaxSameValue2(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dtype = 'float16' + self.axis = 0 + + def setUp(self): + self.set_mlu() + self.initTestCase() + self.x = np.array([[2, 3, 5, 5], [3, 2, 5, 5]]).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + self.outputs = {'Out': np.argmax(self.x, axis=self.axis)} + + # test argmax, dtype: float16 class TestArgMaxFloat16Case1(BaseTestCase): -- GitLab