diff --git a/paddle/pten/api/ext/dispatch.h b/paddle/pten/api/ext/dispatch.h
index 945a9557c40e05c93791ce40072ed3875f27d8ea..93b4226e66d5648a69a9d2e97d687c3412d58bb5 100644
--- a/paddle/pten/api/ext/dispatch.h
+++ b/paddle/pten/api/ext/dispatch.h
@@ -302,4 +302,44 @@ namespace paddle {
     }                                                                         \
   }()
 
+#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(                   \
+    SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...)       \
+  [&] {                                                                       \
+    const auto& __dtype__ = TYPE;                                             \
+    switch (__dtype__) {                                                      \
+      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::BOOL, bool, __VA_ARGS__) \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)              \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__)             \
+      PD_PRIVATE_CASE_TYPE(NAME,                                              \
+                           ::paddle::DataType::COMPLEX64,                     \
+                           ::paddle::complex64,                               \
+                           __VA_ARGS__)                                       \
+      PD_PRIVATE_CASE_TYPE(NAME,                                              \
+                           ::paddle::DataType::COMPLEX128,                    \
+                           ::paddle::complex128,                              \
+                           __VA_ARGS__)                                       \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          SPECIFIED_TYPE1,                                                    \
+          ::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE1>::type,   \
+          __VA_ARGS__)                                                        \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          SPECIFIED_TYPE2,                                                    \
+          ::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE2>::type,   \
+          __VA_ARGS__)                                                        \
+      PD_PRIVATE_CASE_TYPE(                                                   \
+          NAME,                                                               \
+          SPECIFIED_TYPE3,                                                    \
+          ::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE3>::type,   \
+          __VA_ARGS__)                                                        \
+      default:                                                                \
+        PD_THROW("function " #NAME " is not implemented for data type `",     \
+                 __dtype__,                                                   \
+                 "`");                                                        \
+    }                                                                         \
+  }()
+
 } // namespace paddle
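Note: PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES above follows the
existing PD_VISIT_*/PD_DISPATCH_* pattern in dispatch.h: it switches on a
runtime DataType and, in each matching case, binds the case-local alias
data_t to the corresponding C++ type before invoking the visitor lambda.
The sketch below is a minimal, self-contained analogue of that mechanism,
not Paddle's actual code; DataType, EnumToCpp, CASE_TYPE, and VISIT_TYPES
are simplified stand-ins for paddle::DataType,
paddle::experimental::DataTypeToCppType, PD_PRIVATE_CASE_TYPE, and the
PD_VISIT_* macros.

    #include <complex>
    #include <cstdint>
    #include <cstdio>
    #include <stdexcept>

    enum class DataType { BOOL, FLOAT32, FLOAT64, COMPLEX64, INT64 };

    // Compile-time map from the runtime enum to a C++ type (the role
    // played by ::paddle::experimental::DataTypeToCppType above).
    template <DataType D> struct EnumToCpp;
    template <> struct EnumToCpp<DataType::BOOL> { using type = bool; };
    template <> struct EnumToCpp<DataType::FLOAT32> { using type = float; };
    template <> struct EnumToCpp<DataType::FLOAT64> { using type = double; };
    template <> struct EnumToCpp<DataType::COMPLEX64> {
      using type = std::complex<float>;
    };
    template <> struct EnumToCpp<DataType::INT64> { using type = int64_t; };

    // Each case binds data_t before running the visitor body, which is how
    // the lambdas in reduce.h receive their concrete element type.
    #define CASE_TYPE(enum_val, ...)              \
      case enum_val: {                            \
        using data_t = EnumToCpp<enum_val>::type; \
        return __VA_ARGS__();                     \
      }

    #define VISIT_TYPES(dtype, ...)                            \
      [&] {                                                    \
        switch (dtype) {                                       \
          CASE_TYPE(DataType::BOOL, __VA_ARGS__)               \
          CASE_TYPE(DataType::FLOAT32, __VA_ARGS__)            \
          CASE_TYPE(DataType::FLOAT64, __VA_ARGS__)            \
          CASE_TYPE(DataType::COMPLEX64, __VA_ARGS__)          \
          CASE_TYPE(DataType::INT64, __VA_ARGS__)              \
          default:                                             \
            throw std::runtime_error("dtype not implemented"); \
        }                                                      \
      }()

    int main() {
      DataType runtime_dtype = DataType::FLOAT64;  // known only at runtime
      VISIT_TYPES(runtime_dtype, ([&] {
                    // Here data_t is the concrete type for runtime_dtype;
                    // a real kernel template would be instantiated with it.
                    std::printf("sizeof(data_t) = %zu\n", sizeof(data_t));
                  }));
      return 0;
    }

This is the same trick that lets the reduce kernels below instantiate
templates from a dtype that is only known at runtime.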
diff --git a/paddle/pten/kernels/cpu/reduce.h b/paddle/pten/kernels/cpu/reduce.h
index 2b0659ac2e35746634fe01e6ee5263aabf4806a6..bdf9e65f541886900c58163bda4604e7f338c0c0 100644
--- a/paddle/pten/kernels/cpu/reduce.h
+++ b/paddle/pten/kernels/cpu/reduce.h
@@ -220,22 +220,15 @@ void Reduce(const DeviceContext& dev_ctx,
 
   // no need to cast dtype
   if (out_dtype == pten::DataType::UNDEFINED || out_dtype == x.dtype()) {
-    if (out_dtype == pten::DataType::UNDEFINED) {
-      out_dtype = x.dtype();
-    }
 
     // do reduce sum
     PD_VISIT_ALL_TYPES(
-        out_dtype, "ReduceKernelImpl", ([&] {
+        x.dtype(), "ReduceKernelImpl", ([&] {
          pten::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
              dev_ctx, x, out, dims, keep_dim, reduce_all);
        }));
   } else {
-    pten::DenseTensor tmp_tensor = pten::DenseTensor(
-        pten::make_intrusive<paddle::experimental::SharedStorage>(x.place()),
-        pten::DenseTensorMeta(out_dtype, x.dims(), x.layout()));
-    // cast x tensor to out_dtype
-    pten::CastKernel<T>(dev_ctx, x, out_dtype, &tmp_tensor);
+    auto tmp_tensor = pten::Cast<T>(dev_ctx, x, out_dtype);
     // do reduce sum
     PD_VISIT_ALL_TYPES(
         out_dtype, "ReduceKernelImpl", ([&] {
diff --git a/paddle/pten/kernels/gpu/reduce.h b/paddle/pten/kernels/gpu/reduce.h
index 7a76a988dee253c8e8ec5733a64534ea93307013..b0fbef9a18e8882d6383ea5ef93eb301fa56a220 100644
--- a/paddle/pten/kernels/gpu/reduce.h
+++ b/paddle/pten/kernels/gpu/reduce.h
@@ -43,6 +43,7 @@ namespace cub = hipcub;
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/enforce.h"
 #include "paddle/pten/core/utils/array.h"
+#include "paddle/pten/kernels/cast_kernel.h"
 #include "paddle/pten/kernels/funcs/elementwise_base.h"
 #include "paddle/pten/kernels/primitive/kernel_primitives.h"
 
@@ -1232,21 +1233,23 @@ void Reduce(const GPUContext& dev_ctx,
   gpuStream_t stream = dev_ctx.stream();
 
   if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
-    PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
+    auto tmp_tensor = pten::Cast<T>(dev_ctx, x, out_dtype);
+    PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
         pten::DataType::INT32,
         pten::DataType::INT64,
+        pten::DataType::FLOAT16,
         out_dtype,
         "TensorReduceFunctorImpl",
         ([&] {
           using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
-          pten::kernels::TensorReduceFunctorImpl<T,
-                                                 data_t,
-                                                 ReduceOp,
-                                                 TransformOp<T, MPType>>(
+          pten::kernels::TensorReduceFunctorImpl<data_t,
+                                                 data_t,
+                                                 ReduceOp,
+                                                 TransformOp<data_t, MPType>>(
               dev_ctx,
-              x,
+              tmp_tensor,
               out,
-              TransformOp<T, MPType>(reduce_num),
+              TransformOp<data_t, MPType>(reduce_num),
               reduce_dims,
               stream);
         }));
diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py
index 25bf60334e7f345bef1f525780a596b51a3d6353..faa67e1d6da8f44bf1a09036d0d1dc9e49ff462c 100644
--- a/python/paddle/fluid/tests/unittests/test_reduce_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -740,17 +740,24 @@ class API_TestSumOp(unittest.TestCase):
         if np_axis is None:
             np_axis = attr_axis
 
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            data = fluid.data("data", shape=shape, dtype=x_dtype)
-            result_sum = paddle.sum(x=data, axis=attr_axis, dtype=attr_dtype)
-
-            exe = fluid.Executor(fluid.CPUPlace())
-            input_data = np.random.rand(*shape).astype(x_dtype)
-            res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
-
-        self.assertTrue(
-            np.allclose(
-                res, np.sum(input_data.astype(attr_dtype), axis=np_axis)))
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data("data", shape=shape, dtype=x_dtype)
+                result_sum = paddle.sum(x=data,
+                                        axis=attr_axis,
+                                        dtype=attr_dtype)
+
+                exe = fluid.Executor(place)
+                input_data = np.random.rand(*shape).astype(x_dtype)
+                res, = exe.run(feed={"data": input_data},
+                               fetch_list=[result_sum])
+
+                self.assertTrue(
+                    np.allclose(
+                        res, np.sum(input_data.astype(attr_dtype), axis=np_axis)))
 
     def test_static(self):
         shape = [10, 10]
@@ -759,10 +766,12 @@ def test_static(self):
         self.run_static(shape, "bool", axis, attr_dtype=None)
         self.run_static(shape, "bool", axis, attr_dtype="int32")
         self.run_static(shape, "bool", axis, attr_dtype="int64")
+        self.run_static(shape, "bool", axis, attr_dtype="float16")
 
         self.run_static(shape, "int32", axis, attr_dtype=None)
         self.run_static(shape, "int32", axis, attr_dtype="int32")
         self.run_static(shape, "int32", axis, attr_dtype="int64")
+        self.run_static(shape, "int32", axis, attr_dtype="float64")
 
         self.run_static(shape, "int64", axis, attr_dtype=None)
         self.run_static(shape, "int64", axis, attr_dtype="int64")
@@ -771,6 +780,7 @@ def test_static(self):
         self.run_static(shape, "float32", axis, attr_dtype=None)
         self.run_static(shape, "float32", axis, attr_dtype="float32")
         self.run_static(shape, "float32", axis, attr_dtype="float64")
+        self.run_static(shape, "float32", axis, attr_dtype="int64")
 
         self.run_static(shape, "float64", axis, attr_dtype=None)
         self.run_static(shape, "float64", axis, attr_dtype="float32")
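Note: the net effect of the two reduce.h changes is a cast-then-reduce
flow: x is first cast to out_dtype (tmp_tensor), and the reduction is then
dispatched and instantiated on the casted element type data_t instead of
the input type T. That is what makes sum with dtype promotion (e.g. a bool
tensor summed as int64) behave the same on CPU and GPU. Below is a minimal
standalone illustration of the idea; CastThenReduceSum is a made-up helper
for this note, not a Paddle API.

    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    template <typename OutT, typename InT>
    OutT CastThenReduceSum(const std::vector<InT>& x) {
      // Analogue of `auto tmp_tensor = pten::Cast<T>(dev_ctx, x, out_dtype);`
      std::vector<OutT> tmp(x.begin(), x.end());
      // Analogue of dispatching the reduce kernel on the casted dtype:
      // the accumulation happens in OutT, not in the input type.
      return std::accumulate(tmp.begin(), tmp.end(), OutT{0});
    }

    int main() {
      std::vector<bool> mask = {true, false, true, true};
      // Mirrors paddle.sum(bool_tensor, dtype="int64"), which counts the
      // true entries; accumulating in bool would saturate at 1.
      long long n = CastThenReduceSum<std::int64_t>(mask);
      std::printf("%lld\n", n);  // prints 3
      return 0;
    }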