diff --git a/paddle/phi/kernels/cpu/reduce.h b/paddle/phi/kernels/cpu/reduce.h
index bfcbe0eee1f60728c01663dd8121fe4f60ad0a09..4c44d3aa79b15622cf0bd0447ef135ffd390d300 100644
--- a/paddle/phi/kernels/cpu/reduce.h
+++ b/paddle/phi/kernels/cpu/reduce.h
@@ -52,6 +52,7 @@ void Reduce(const DeviceContext& dev_ctx,
           phi::funcs::ReduceKernelImpl(
               dev_ctx, x, out, dims, keep_dim, reduce_all);
         }));
+
   } else {
     // cast x tensor to out_dtype
     auto tmp_tensor = phi::Cast(dev_ctx, x, out_dtype);
@@ -65,7 +66,7 @@ void Reduce(const DeviceContext& dev_ctx,
   }
 }
 
-template
+template
 void BoolReduceKernel(const DeviceContext& dev_ctx,
                       const phi::DenseTensor& input,
                       const std::vector& dims,
@@ -73,7 +74,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
                       bool reduce_all,
                       phi::DenseTensor* output) {
   reduce_all = recompute_reduce_all(input, dims, reduce_all);
-  dev_ctx.template Alloc(output);
+  dev_ctx.template Alloc(output);
 
   // The dims has full dim, set the reduce_all is True
   const auto& input_dim_size = input.dims().size();
@@ -86,9 +87,15 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
     }
   }
   reduce_all = (reduce_all || full_dim);
-
-  funcs::ReduceKernelImpl(
-      dev_ctx, input, output, dims, keep_dim, reduce_all);
+  DenseTensor tmp_tensor;
+  if (input.dtype() != phi::DataType::BOOL) {
+    tmp_tensor =
+        phi::Cast(dev_ctx, input, phi::DataType::BOOL);
+  } else {
+    tmp_tensor = input;
+  }
+  funcs::ReduceKernelImpl(
+      dev_ctx, tmp_tensor, output, dims, keep_dim, reduce_all);
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/cpu/reduce_all_kernel.cc b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
index 1dea17d4b7af3066c74bd301944c7265fd977503..5c863b1a95a3cfef8066fcc903a76c842c9e0b86 100644
--- a/paddle/phi/kernels/cpu/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_all_kernel.cc
@@ -35,6 +35,14 @@ void AllRawKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(all_raw, CPU, ALL_LAYOUT, phi::AllRawKernel, bool) {
+PD_REGISTER_KERNEL(all_raw,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::AllRawKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
diff --git a/paddle/phi/kernels/cpu/reduce_any_kernel.cc b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
index 553393e7dba35a825ab808562fd1b14fb83bad70..cb82fc3a71cb9c75135be427328d47f868d3318c 100644
--- a/paddle/phi/kernels/cpu/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_any_kernel.cc
@@ -35,4 +35,14 @@ void AnyRawKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(any_raw, CPU, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+PD_REGISTER_KERNEL(any_raw,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::AnyRawKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
diff --git a/paddle/phi/kernels/kps/reduce_all_kernel.cu b/paddle/phi/kernels/kps/reduce_all_kernel.cu
index c0c338bb4f249726c9a14ffd45f6f1f28ac7b4b3..f85ffdcba4f4207ee70e98337e9741903fd3bc22 100644
--- a/paddle/phi/kernels/kps/reduce_all_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_all_kernel.cu
@@ -26,7 +26,7 @@ void AllRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DenseTensor* out) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
-  auto out_dtype = x.dtype();
+  auto out_dtype = phi::DataType::BOOL;
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -38,7 +38,15 @@ PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
 #else
-PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {
+PD_REGISTER_KERNEL(all_raw,
+                   KPS,
+                   ALL_LAYOUT,
+                   phi::AllRawKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
 #endif
diff --git a/paddle/phi/kernels/kps/reduce_any_kernel.cu b/paddle/phi/kernels/kps/reduce_any_kernel.cu
index 3210f23c3b205951cdf2116a4229f253f3d0b801..a6b79540e29278872f930c0cb9e8364ed9cbacb4 100644
--- a/paddle/phi/kernels/kps/reduce_any_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_any_kernel.cu
@@ -26,7 +26,7 @@ void AnyRawKernel(const Context& dev_ctx,
                   bool reduce_all,
                   DenseTensor* out) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
-  auto out_dtype = x.dtype();
+  auto out_dtype = phi::DataType::BOOL;
   phi::Reduce(
       dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
 }
@@ -36,5 +36,15 @@ void AnyRawKernel(const Context& dev_ctx,
 #ifdef PADDLE_WITH_XPU_KP
 PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
 #else
-PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
+PD_REGISTER_KERNEL(any_raw,
+                   KPS,
+                   ALL_LAYOUT,
+                   phi::AnyRawKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   bool) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
 #endif
diff --git a/paddle/phi/kernels/reduce_all_kernel.cc b/paddle/phi/kernels/reduce_all_kernel.cc
index 9e799f0d219fc0c24e02043d04753c0cc00f3ca5..d6f88a596af3ac13132c7d2c8c1bc09dd9c5a533 100644
--- a/paddle/phi/kernels/reduce_all_kernel.cc
+++ b/paddle/phi/kernels/reduce_all_kernel.cc
@@ -38,10 +38,16 @@ void AllKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(all, CPU, ALL_LAYOUT, phi::AllKernel, bool) {}
+PD_REGISTER_KERNEL(
+    all, CPU, ALL_LAYOUT, phi::AllKernel, float, double, int, int64_t, bool) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(all, GPU, ALL_LAYOUT, phi::AllKernel, bool) {}
+PD_REGISTER_KERNEL(
+    all, GPU, ALL_LAYOUT, phi::AllKernel, float, double, int, int64_t, bool) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
 #endif
 
 #if defined(PADDLE_WITH_XPU_KP)
diff --git a/paddle/phi/kernels/reduce_any_kernel.cc b/paddle/phi/kernels/reduce_any_kernel.cc
index 9d162f8e0203330d95dc3ef042a9cbfd03a1de63..076aacfa3ed82cb0e47020da3e35cf7a3d03e8a6 100644
--- a/paddle/phi/kernels/reduce_any_kernel.cc
+++ b/paddle/phi/kernels/reduce_any_kernel.cc
@@ -31,10 +31,16 @@ void AnyKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(any, CPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
+PD_REGISTER_KERNEL(
+    any, CPU, ALL_LAYOUT, phi::AnyKernel, float, double, int64_t, int, bool) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(any, GPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
+PD_REGISTER_KERNEL(
+    any, GPU, ALL_LAYOUT, phi::AnyKernel, float, double, int, int64_t, bool) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
 #endif
 
 #if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index e4172fefcc3fe7c0d5dd762ac1f72cb95c2744b3..974db208cbaecaa79354b403c26c65652ac02df6 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -4096,12 +4096,13 @@ def all(x, axis=None, keepdim=False, name=None):
             'keep_dim': keepdim,
             'reduce_all': reduce_all,
         }
-        check_variable_and_dtype(x, 'x', ['bool'], 'all')
-
+        check_variable_and_dtype(
+            x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'all'
+        )
         check_type(axis, 'axis', (int, list, tuple, type(None)), 'all')
 
         helper = LayerHelper('all', **locals())
-        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out = helper.create_variable_for_type_inference(dtype=paddle.bool)
         helper.append_op(
             type='reduce_all',
             inputs={'X': x},
@@ -4170,13 +4171,13 @@ def any(x, axis=None, keepdim=False, name=None):
             'keep_dim': keepdim,
             'reduce_all': reduce_all,
         }
-
-        check_variable_and_dtype(x, 'x', ['bool'], 'any')
-
+        check_variable_and_dtype(
+            x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'any'
+        )
         check_type(axis, 'axis', (int, list, tuple, type(None)), 'any')
 
         helper = LayerHelper('any', **locals())
-        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        out = helper.create_variable_for_type_inference(dtype=paddle.bool)
         helper.append_op(
             type='reduce_any',
             inputs={'X': x},
diff --git a/test/legacy_test/test_reduce_op.py b/test/legacy_test/test_reduce_op.py
index 5875e959c35b2b1565903db64d35e4a67310c5fe..9617f39537b4be4f4dd0ddc8b12fd76c3e49e24e 100644
--- a/test/legacy_test/test_reduce_op.py
+++ b/test/legacy_test/test_reduce_op.py
@@ -804,6 +804,30 @@ class TestAllOp(OpTest):
         self.check_output()
 
 
+class TestAllFloatOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_all"
+        self.python_api = reduce_all_wrapper
+        self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("float")}
+        self.outputs = {'Out': self.inputs['X'].all()}
+        self.attrs = {'reduce_all': True}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestAllIntOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_all"
+        self.python_api = reduce_all_wrapper
+        self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("int")}
+        self.outputs = {'Out': self.inputs['X'].all()}
+        self.attrs = {'reduce_all': True}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestAllOp_ZeroDim(OpTest):
     def setUp(self):
         self.python_api = paddle.all
@@ -900,11 +924,6 @@ class TestAllOpError(unittest.TestCase):
             # The input type of reduce_all_op must be Variable.
             input1 = 12
             self.assertRaises(TypeError, paddle.all, input1)
-            # The input dtype of reduce_all_op must be bool.
-            input2 = paddle.static.data(
-                name='input2', shape=[-1, 12, 10], dtype="int32"
-            )
-            self.assertRaises(TypeError, paddle.all, input2)
 
 
 def reduce_any_wrapper(x, axis=None, keepdim=False, reduce_all=True, name=None):
@@ -923,6 +942,30 @@ class TestAnyOp(OpTest):
         self.check_output()
 
 
+class TestAnyFloatOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_any"
+        self.python_api = reduce_any_wrapper
+        self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("float")}
+        self.outputs = {'Out': self.inputs['X'].any()}
+        self.attrs = {'reduce_all': True}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestAnyIntOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_any"
+        self.python_api = reduce_any_wrapper
+        self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("int")}
+        self.outputs = {'Out': self.inputs['X'].any()}
+        self.attrs = {'reduce_all': True}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestAnyOp_ZeroDim(OpTest):
     def setUp(self):
         self.python_api = paddle.any
@@ -1021,11 +1064,6 @@ class TestAnyOpError(unittest.TestCase):
             # The input type of reduce_any_op must be Variable.
             input1 = 12
             self.assertRaises(TypeError, paddle.any, input1)
-            # The input dtype of reduce_any_op must be bool.
-            input2 = paddle.static.data(
-                name='input2', shape=[-1, 12, 10], dtype="int32"
-            )
-            self.assertRaises(TypeError, paddle.any, input2)
 
 
 class Test1DReduce(OpTest):
@@ -1645,11 +1683,43 @@ class TestAllAPI(unittest.TestCase):
                 feed={"input": input_np},
                 fetch_list=[result],
             )
-            np.testing.assert_allclose(fetches[0], np.all(input_np), rtol=1e-05)
+            self.assertTrue((fetches[0] == np.all(input_np)).all())
+
+    def check_static_float_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input = paddle.static.data(
+                name="input", shape=[4, 4], dtype="float"
+            )
+            result = paddle.all(x=input)
+            input_np = np.random.randint(0, 2, [4, 4]).astype("float")
+
+            exe = fluid.Executor(place)
+            fetches = exe.run(
+                fluid.default_main_program(),
+                feed={"input": input_np},
+                fetch_list=[result],
+            )
+            self.assertTrue((fetches[0] == np.all(input_np)).all())
+
+    def check_static_int_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input = paddle.static.data(name="input", shape=[4, 4], dtype="int")
+            result = paddle.all(x=input)
+            input_np = np.random.randint(0, 2, [4, 4]).astype("int")
+
+            exe = fluid.Executor(place)
+            fetches = exe.run(
+                fluid.default_main_program(),
+                feed={"input": input_np},
+                fetch_list=[result],
+            )
+            self.assertTrue((fetches[0] == np.all(input_np)).all())
 
     def test_static(self):
         for place in self.places:
             self.check_static_result(place=place)
+            self.check_static_float_result(place=place)
+            self.check_static_int_result(place=place)
 
     def test_dygraph(self):
         paddle.disable_static()
@@ -1679,6 +1749,18 @@ class TestAllAPI(unittest.TestCase):
             expect_res4 = np.all(np_x, axis=1, keepdims=True)
             self.assertTrue((np_out4 == expect_res4).all())
 
+            x = paddle.cast(x, 'float')
+            out5 = paddle.all(x)
+            np_out5 = out5.numpy()
+            expect_res5 = np.all(np_x)
+            self.assertTrue((np_out5 == expect_res5).all())
+
+            x = paddle.cast(x, 'int')
+            out6 = paddle.all(x)
+            np_out6 = out6.numpy()
+            expect_res6 = np.all(np_x)
+            self.assertTrue((np_out6 == expect_res6).all())
+
         paddle.enable_static()
 
 
@@ -1702,11 +1784,43 @@ class TestAnyAPI(unittest.TestCase):
                 feed={"input": input_np},
                 fetch_list=[result],
             )
-            np.testing.assert_allclose(fetches[0], np.any(input_np), rtol=1e-05)
+            self.assertTrue((fetches[0] == np.any(input_np)).all())
+
+    def check_static_float_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input = paddle.static.data(
+                name="input", shape=[4, 4], dtype="float"
+            )
+            result = paddle.any(x=input)
+            input_np = np.random.randint(0, 2, [4, 4]).astype("float")
+
+            exe = fluid.Executor(place)
+            fetches = exe.run(
+                fluid.default_main_program(),
+                feed={"input": input_np},
+                fetch_list=[result],
+            )
+            self.assertTrue((fetches[0] == np.any(input_np)).all())
+
+    def check_static_int_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input = paddle.static.data(name="input", shape=[4, 4], dtype="int")
+            result = paddle.any(x=input)
+            input_np = np.random.randint(0, 2, [4, 4]).astype("int")
+
+            exe = fluid.Executor(place)
+            fetches = exe.run(
+                fluid.default_main_program(),
+                feed={"input": input_np},
+                fetch_list=[result],
+            )
+            self.assertTrue((fetches[0] == np.any(input_np)).all())
 
     def test_static(self):
         for place in self.places:
             self.check_static_result(place=place)
+            self.check_static_float_result(place=place)
+            self.check_static_int_result(place=place)
 
     def test_dygraph(self):
         paddle.disable_static()
@@ -1736,6 +1850,21 @@ class TestAnyAPI(unittest.TestCase):
             expect_res4 = np.any(np_x, axis=1, keepdims=True)
             self.assertTrue((np_out4 == expect_res4).all())
 
+            np_x = np.random.randint(0, 2, (12, 10)).astype(np.float32)
+            x = paddle.assign(np_x)
+            x = paddle.cast(x, 'float32')
+
+            out5 = paddle.any(x)
+            np_out5 = out5.numpy()
+            expect_res5 = np.any(np_x)
+            self.assertTrue((np_out5 == expect_res5).all())
+
+            x = paddle.cast(x, 'int')
+            out6 = paddle.any(x)
+            np_out6 = out6.numpy()
+            expect_res6 = np.any(np_x)
+            self.assertTrue((np_out6 == expect_res6).all())
+
         paddle.enable_static()
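
As a quick sanity check of the new behavior (not part of the patch), the sketch below mirrors the dygraph tests added above: with this change `paddle.all` and `paddle.any` accept float and integer inputs, treat non-zero values as True, and always produce a bool result.

```python
import numpy as np
import paddle

# Float input with 0/1 values; previously only bool inputs were accepted.
np_x = np.random.randint(0, 2, (12, 10)).astype(np.float32)
x = paddle.to_tensor(np_x)

out_all = paddle.all(x)                        # reduce over all elements
out_any = paddle.any(x, axis=1, keepdim=True)  # reduce along axis 1

assert out_all.dtype == paddle.bool            # output dtype is fixed to bool
assert (out_all.numpy() == np_x.all()).all()
assert (out_any.numpy() == np_x.any(axis=1, keepdims=True)).all()

# Integer input works the same way.
out_int = paddle.any(paddle.cast(x, 'int32'))
assert (out_int.numpy() == np_x.any()).all()
```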