diff --git a/paddle/pten/kernels/hybird/eigen/reduce.h b/paddle/pten/kernels/hybird/eigen/reduce.h index 52ea1e68e12ae9b3e5ca34a2630b0bd88ff8ba73..e6ab872928c77dab3f22a9ce3af24f5ca29256ae 100644 --- a/paddle/pten/kernels/hybird/eigen/reduce.h +++ b/paddle/pten/kernels/hybird/eigen/reduce.h @@ -82,10 +82,18 @@ inline void GetShuffledDim(const DDim& src_dims, std::vector src_dims_check(src_dims.size(), false); size_t src_size = src_dims.size(); size_t reduce_size = reduced_dims.size(); + std::vector regular_reduced_dims = reduced_dims; + for (size_t i = 0; i < regular_reduced_dims.size(); i++) { + if (regular_reduced_dims[i] < 0) { + regular_reduced_dims[i] = src_size + regular_reduced_dims[i]; + } + } + for (size_t i = 0; i < reduce_size; ++i) { - dst_dims->at(src_size - reduce_size + i) = src_dims[reduced_dims[i]]; - (*perm_axis)[src_size - reduce_size + i] = reduced_dims[i]; - src_dims_check[reduced_dims[i]] = true; + dst_dims->at(src_size - reduce_size + i) = + src_dims[regular_reduced_dims[i]]; + (*perm_axis)[src_size - reduce_size + i] = regular_reduced_dims[i]; + src_dims_check[regular_reduced_dims[i]] = true; } size_t offset = 0; diff --git a/python/paddle/fluid/tests/unittests/test_max_op.py b/python/paddle/fluid/tests/unittests/test_max_op.py index 3a1dbc8f95f904c58a1e05a934ddf0adcc5cee25..caee7d9e5c2bab577f6c48a64325c22be3c8c7e7 100644 --- a/python/paddle/fluid/tests/unittests/test_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_max_op.py @@ -86,6 +86,18 @@ class ApiMaxTest(unittest.TestCase): z_expected = np.array(np.max(np_x, axis=0)) self.assertEqual((np_z == z_expected).all(), True) + def test_big_dimension(self): + paddle.disable_static() + x = paddle.rand(shape=[2, 2, 2, 2, 2, 2, 2]) + np_x = x.numpy() + z1 = paddle.max(x, axis=-1) + z2 = paddle.max(x, axis=6) + np_z1 = z1.numpy() + np_z2 = z2.numpy() + z_expected = np.array(np.max(np_x, axis=6)) + self.assertEqual((np_z1 == z_expected).all(), True) + self.assertEqual((np_z2 == z_expected).all(), True) + class TestOutDtype(unittest.TestCase): def test_max(self):