未验证 提交 995332ef 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix reduce_max/reduce_min bug (#38476)

上级 20403fe9
......@@ -31,11 +31,12 @@ void Reduce(const DeviceContext& dev_ctx,
DataType out_dtype,
DenseTensor* out) {
// If the dims has full dim, set the reduce_all is True
const auto& input_dim_size = x.dims().size();
const int& input_dim_size = x.dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; ++i) {
if (dims_set.find(i) == dims_set.end()) {
for (int i = 0; i < input_dim_size; ++i) {
if (dims_set.find(i) == dims_set.end() &&
dims_set.find(i - input_dim_size) == dims_set.end()) {
full_dim = false;
break;
}
......
......@@ -98,6 +98,15 @@ class ApiMaxTest(unittest.TestCase):
self.assertEqual((np_z1 == z_expected).all(), True)
self.assertEqual((np_z2 == z_expected).all(), True)
def test_all_negative_axis(self):
paddle.disable_static()
x = paddle.rand(shape=[2, 2])
np_x = x.numpy()
z1 = paddle.max(x, axis=(-2, -1))
np_z1 = z1.numpy()
z_expected = np.array(np.max(np_x, axis=(0, 1)))
self.assertEqual((np_z1 == z_expected).all(), True)
class TestOutDtype(unittest.TestCase):
def test_max(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册