未验证 提交 0d9b12fa 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix reduce_max/reduce_min bug (#38478)

上级 462ee101
......@@ -241,11 +241,12 @@ class ReduceKernel : public framework::OpKernel<T> {
framework::proto::VarType::Type cast_out_dtype;
// The dims has full dim, set the reduce_all is True
const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
const int& input_dim_size = context.Input<Tensor>("X")->dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end()) {
for (int i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end() &&
dims_set.find(i - input_dim_size) == dims_set.end()) {
full_dim = false;
break;
}
......
......@@ -98,6 +98,15 @@ class ApiMaxTest(unittest.TestCase):
self.assertEqual((np_z1 == z_expected).all(), True)
self.assertEqual((np_z2 == z_expected).all(), True)
def test_all_negative_axis(self):
paddle.disable_static()
x = paddle.rand(shape=[2, 2])
np_x = x.numpy()
z1 = paddle.max(x, axis=(-2, -1))
np_z1 = z1.numpy()
z_expected = np.array(np.max(np_x, axis=(0, 1)))
self.assertEqual((np_z1 == z_expected).all(), True)
class TestOutDtype(unittest.TestCase):
def test_max(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册