未验证 提交 512e4339 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix reduce_max bug (#38026)

* fix reduce_max bug

* add unittest
上级 92ad682f
...@@ -82,10 +82,18 @@ inline void GetShuffledDim(const DDim& src_dims, ...@@ -82,10 +82,18 @@ inline void GetShuffledDim(const DDim& src_dims,
std::vector<bool> src_dims_check(src_dims.size(), false); std::vector<bool> src_dims_check(src_dims.size(), false);
size_t src_size = src_dims.size(); size_t src_size = src_dims.size();
size_t reduce_size = reduced_dims.size(); size_t reduce_size = reduced_dims.size();
std::vector<int64_t> regular_reduced_dims = reduced_dims;
for (size_t i = 0; i < regular_reduced_dims.size(); i++) {
if (regular_reduced_dims[i] < 0) {
regular_reduced_dims[i] = src_size + regular_reduced_dims[i];
}
}
for (size_t i = 0; i < reduce_size; ++i) { for (size_t i = 0; i < reduce_size; ++i) {
dst_dims->at(src_size - reduce_size + i) = src_dims[reduced_dims[i]]; dst_dims->at(src_size - reduce_size + i) =
(*perm_axis)[src_size - reduce_size + i] = reduced_dims[i]; src_dims[regular_reduced_dims[i]];
src_dims_check[reduced_dims[i]] = true; (*perm_axis)[src_size - reduce_size + i] = regular_reduced_dims[i];
src_dims_check[regular_reduced_dims[i]] = true;
} }
size_t offset = 0; size_t offset = 0;
......
...@@ -86,6 +86,18 @@ class ApiMaxTest(unittest.TestCase): ...@@ -86,6 +86,18 @@ class ApiMaxTest(unittest.TestCase):
z_expected = np.array(np.max(np_x, axis=0)) z_expected = np.array(np.max(np_x, axis=0))
self.assertEqual((np_z == z_expected).all(), True) self.assertEqual((np_z == z_expected).all(), True)
def test_big_dimension(self):
paddle.disable_static()
x = paddle.rand(shape=[2, 2, 2, 2, 2, 2, 2])
np_x = x.numpy()
z1 = paddle.max(x, axis=-1)
z2 = paddle.max(x, axis=6)
np_z1 = z1.numpy()
np_z2 = z2.numpy()
z_expected = np.array(np.max(np_x, axis=6))
self.assertEqual((np_z1 == z_expected).all(), True)
self.assertEqual((np_z2 == z_expected).all(), True)
class TestOutDtype(unittest.TestCase): class TestOutDtype(unittest.TestCase):
def test_max(self): def test_max(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册