diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu
index 1db74770a7da69d67f06ff0172f476bb110bed78..a471790771ec9dd5f20e903f8f0855e8cd2d4c57 100644
--- a/paddle/phi/kernels/gpu/cum_kernel.cu
+++ b/paddle/phi/kernels/gpu/cum_kernel.cu
@@ -27,11 +27,25 @@ namespace cub = hipcub;
 #endif
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
+template <typename T>
+class CumTypeTrait {
+ public:
+  using Type = T;
+};
+
+template <>
+class CumTypeTrait<phi::dtype::float16> {
+ public:
+  using Type = __half;
+};
+
 template <typename T, int BLOCK_SIZE>
 __device__ void BlockReverse(
     const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
@@ -39,7 +53,7 @@ __device__ void BlockReverse(
   int tx = threadIdx.x;
   int offset = tx;
-  T src_data = 0;
+  T src_data = static_cast<T>(0);
   int src_offset = BLOCK_SIZE - offset - 1;
   if (src_offset < valid_item) {
     src_data = idata[src_base + src_offset];
   }
@@ -160,14 +174,18 @@ __global__ void BlockScanKernel(T* d_out,
                                 int scan_size,
                                 bool exclusive,
                                 Op op) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
   // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
   typedef cub::
-      BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
+      BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
           BlockLoadT;
-  typedef cub::
-      BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_TRANSPOSE>
-          BlockStoreT;
-  typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
+  typedef cub::BlockStore<MT,
+                          BLOCK_THREADS,
+                          ITEMS_PER_THREAD,
+                          cub::BLOCK_STORE_TRANSPOSE>
+      BlockStoreT;
+  typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
   // Allocate type-safe, repurposable shared memory for collectives
   __shared__ union {
     typename BlockLoadT::TempStorage load;
@@ -176,8 +194,7 @@ __global__ void BlockScanKernel(T* d_out,
   } temp_storage;
 
   int bx = blockIdx.x;
-
-  BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op);
+  BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
 
   // Obtain this block's segment of consecutive keys (blocked across threads)
   int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
@@ -192,7 +209,7 @@
     int offset = block_offset + bx * scan_size;
 
-    T thread_keys[ITEMS_PER_THREAD];
+    MT thread_keys[ITEMS_PER_THREAD];
     BlockLoadT(temp_storage.load)
         .Load(d_in + offset, thread_keys, valid_item, 0);
@@ -241,17 +258,22 @@ void ScanKernel(const Context& dev_ctx,
 
   // Use thrust for parallel acceleration when the input size is equal to the
   // length of the ‘axis’ dimension.
-  if (std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
+  if (!std::is_same<T, phi::dtype::float16>::value &&
+      std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
 #ifdef __HIPCC__
     const auto& policy = thrust::hip::par.on(dev_ctx.stream());
 #else
     const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
 #endif
+
+    using CumType = typename CumTypeTrait<T>::Type;
+    CumType* out_data_ptr = reinterpret_cast<CumType*>(out_data);
+    const CumType* in_data_ptr = reinterpret_cast<const CumType*>(in_data);
     if (reverse) {
-      thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
-          thrust::device_pointer_cast(in_data) + size);
-      thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
-          thrust::device_pointer_cast(out_data) + size);
+      thrust::reverse_iterator<thrust::device_ptr<const CumType>> reversed_in(
+          thrust::device_pointer_cast(in_data_ptr) + size);
+      thrust::reverse_iterator<thrust::device_ptr<CumType>> reversed_out(
+          thrust::device_pointer_cast(out_data_ptr) + size);
       if (exclusive) {
         thrust::exclusive_scan(
             policy, reversed_in, reversed_in + size, reversed_out);
@@ -261,11 +283,14 @@ void ScanKernel(const Context& dev_ctx,
       }
     } else {
       if (exclusive) {
-        thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+        thrust::exclusive_scan(
+            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
       } else {
-        thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+        thrust::inclusive_scan(
+            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
       }
     }
+
     return;
   }
@@ -305,7 +330,6 @@ void ScanKernel(const Context& dev_ctx,
   int outer_size = height / scan_size;
   int inner_size = width;
   // Consider the size of shared memory, here block size is 128
-
   dim3 scan_grid(outer_size, inner_size);
   dim3 reverse_grid = scan_grid;
   if (reverse) {
@@ -380,6 +404,7 @@ void LogcumsumexpKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
+#ifdef PADDLE_WITH_HIP
 PD_REGISTER_KERNEL(cumsum,
                    GPU,
                    ALL_LAYOUT,
@@ -392,3 +417,23 @@
 
 PD_REGISTER_KERNEL(
     logcumsumexp, GPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
+#else
+PD_REGISTER_KERNEL(cumsum,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CumsumKernel,
+                   float,
+                   double,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(logcumsumexp,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LogcumsumexpKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu
index 9f4633a1e021ac77d9e47677b10fdc7ef9f8a9d2..a2044cc3afe72e5f2160ed89dbc4d09dfca4c321 100644
--- a/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/logcumsumexp_grad_kernel.cu
@@ -20,9 +20,19 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h"
 
+#ifdef PADDLE_WITH_HIP
 PD_REGISTER_KERNEL(logcumsumexp_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::LogcumsumexpGradKernel,
                    float,
                    double) {}
+#else
+PD_REGISTER_KERNEL(logcumsumexp_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LogcumsumexpGradKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
+#endif
diff --git a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h
index 602f2248902ccb9b17d85dd184ab8fe94807a493..85a530b1b7559733a1393a9429969913da5eee23 100644
--- a/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h
+++ b/paddle/phi/kernels/impl/logcumsumexp_grad_impl.h
@@ -12,9 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <limits>
+#pragma once
+#include <limits>
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cum_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
@@ -55,32 +57,38 @@ void LogcumsumexpGradKernel(const Context& dev_ctx,
   auto eigen_d_out = EigenVector<T>::Flatten(d_out);
   auto& place = *dev_ctx.eigen_device();
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
 
   DenseTensor output_pos;
   output_pos.Resize(d_out.dims());
-  dev_ctx.template Alloc<T>(&output_pos);
-  auto eigen_output_pos = EigenVector<T>::Flatten(output_pos);
+  dev_ctx.template Alloc<MT>(&output_pos);
+  auto eigen_output_pos = EigenVector<MT>::Flatten(output_pos);
   DenseTensor output_neg;
   output_neg.Resize(d_out.dims());
-  dev_ctx.template Alloc<T>(&output_neg);
-  auto eigen_output_neg = EigenVector<T>::Flatten(output_neg);
+  dev_ctx.template Alloc<MT>(&output_neg);
+  auto eigen_output_neg = EigenVector<MT>::Flatten(output_neg);
   DenseTensor tmp;
   tmp.Resize(d_out.dims());
-  dev_ctx.template Alloc<T>(&tmp);
-  auto eigen_tmp = EigenVector<T>::Flatten(tmp);
+  dev_ctx.template Alloc<MT>(&tmp);
+  auto eigen_tmp = EigenVector<MT>::Flatten(tmp);
 
   eigen_tmp.device(place) =
-      eigen_d_out.unaryExpr(LogGradPositiveFunctor<T>()) - eigen_out;
-  LogcumsumexpKernel<T, Context>(
+      eigen_d_out.template cast<MT>().unaryExpr(LogGradPositiveFunctor<MT>()) -
+      eigen_out.template cast<MT>();
+  LogcumsumexpKernel<MT, Context>(
       dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_pos);
-  eigen_output_pos.device(place) = (eigen_output_pos + eigen_x).exp();
+  auto out_pos = eigen_output_pos + eigen_x.template cast<MT>();
+  eigen_output_pos.device(place) = out_pos.exp();
 
   eigen_tmp.device(place) =
-      eigen_d_out.unaryExpr(LogGradNegativeFunctor<T>()) - eigen_out;
-  LogcumsumexpKernel<T, Context>(
+      eigen_d_out.template cast<MT>().unaryExpr(LogGradNegativeFunctor<MT>()) -
+      eigen_out.template cast<MT>();
+  LogcumsumexpKernel<MT, Context>(
       dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg);
-  eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp();
+  auto out_neg = eigen_output_neg + eigen_x.template cast<MT>();
+  eigen_output_neg.device(place) = out_neg.exp();
 
   auto eigen_d_x = EigenVector<T>::Flatten(*d_x);
-  eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg;
+  eigen_d_x.device(place) =
+      (eigen_output_pos - eigen_output_neg).template cast<T>();
 }
 
 }  // namespace phi
diff --git a/python/paddle/fluid/tests/unittests/test_cumsum_op.py b/python/paddle/fluid/tests/unittests/test_cumsum_op.py
index 1c6c19f5a9e062a0b3cb617305aefe8f35ea2c89..e63252c2c0897d2347668b3112573de408115514 100644
--- a/python/paddle/fluid/tests/unittests/test_cumsum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cumsum_op.py
@@ -199,6 +199,32 @@ class TestSumOp7(OpTest):
         self.check_grad(['X'], 'Out')
 
 
+class TestCumsumFP16(unittest.TestCase):
+
+    def check_main(self, x_np, dtype):
+        paddle.disable_static()
+        x = paddle.to_tensor(x_np.astype(dtype))
+        x.stop_gradient = False
+        y = paddle.cumsum(x, dtype=dtype)
+        x_g = paddle.grad(y, [x])
+        y_np = y.numpy().astype('float32')
+        x_g_np = x_g[0].numpy().astype('float32')
+        paddle.enable_static()
+        return y_np, x_g_np
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        np.random.seed(20)
+        x_np = np.random.random([10, 12])
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
+
+        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=1e-03)
+
+
 class TestSumOpExclusive1(OpTest):
 
     def setUp(self):
@@ -289,6 +315,24 @@ class TestSumOpExclusive5(OpTest):
         self.check_output()
 
 
+class TestSumOpExclusiveFP16(OpTest):
+
+    def setUp(self):
+        self.op_type = "cumsum"
+        self.attrs = {'axis': 2, "exclusive": True, "dtype": "float16"}
+        a = np.random.random((4, 5, 3096)).astype("float64")
+        self.inputs = {'X': a}
+        self.outputs = {
+            'Out':
+            np.concatenate((np.zeros(
+                (4, 5, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)),
+                           axis=2)
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestSumOpReverseExclusive(OpTest):
 
     def setUp(self):
diff --git a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
index 566b65624e5729ada7d8953b7b88624e1e1594df..6064303a5217883d7521f302adcb4161c2a6c2ff 100644
--- a/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_logcumsumexp_op.py
@@ -210,6 +210,8 @@ class BaseTestCases:
             input, attrs = self.input_and_attrs()
             self.inputs = {'X': input}
             self.attrs = attrs
+            if "dtype" in attrs:
+                del attrs["dtype"]
             self.outputs = {'Out': np_logcumsumexp(input, **attrs)}
 
         def test_check_output(self):
@@ -264,5 +266,36 @@ class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest):
         }
 
 
+class TestLogcumsumexpFP16(unittest.TestCase):
+
+    def check_main(self, x_np, dtype, axis=None):
+        paddle.disable_static()
+        x = paddle.to_tensor(x_np.astype(dtype))
+        x.stop_gradient = False
+        y = paddle.logcumsumexp(x, dtype=dtype, axis=axis)
+        x_g = paddle.grad(y, [x])
+        y_np = y.numpy().astype('float32')
+        x_g_np = x_g[0].numpy().astype('float32')
+        paddle.enable_static()
+        return y_np, x_g_np
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        np.random.seed(20)
+        x_np = np.random.random([10, 12])
+
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
+        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=1e-03)
+
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16', axis=1)
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32', axis=1)
+        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=2e-03)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index ffe426876ad1cae1832a379ba93259a4e36b4634..a2720ec48f22bc78725ae867dd7ce50178ab4dc7 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -3173,7 +3173,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
     Args:
         x (Tensor): The input tensor needed to be cumsumed.
         axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array.
-        dtype (str, optional): The data type of the output tensor, can be float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
+        dtype (str, optional): The data type of the output tensor, can be float16, float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -3246,7 +3246,7 @@ def logcumsumexp(x, axis=None, dtype=None, name=None):
     Args:
         x (Tensor): The input tensor.
        axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array.
-        dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
+        dtype (str, optional): The data type of the output tensor, can be float16, float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -3295,7 +3295,8 @@ def logcumsumexp(x, axis=None, dtype=None, name=None):
         return _legacy_C_ops.logcumsumexp(x, 'axis', axis, 'flatten', flatten)
 
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], "logcumsumexp")
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             "logcumsumexp")
 
     helper = LayerHelper('logcumsumexp', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
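
Usage note (not part of the patch): a minimal sketch of the float16 paths this change enables, following the same pattern as the new unit tests above. It assumes a CUDA build of Paddle; the variable names and shapes are illustrative only.

import numpy as np
import paddle

if paddle.is_compiled_with_cuda():
    # float16 inputs are now accepted by both ops on GPU.
    x = paddle.to_tensor(np.random.random([10, 12]).astype('float16'))
    x.stop_gradient = False

    y = paddle.cumsum(x, axis=1, dtype='float16')
    z = paddle.logcumsumexp(x, axis=1, dtype='float16')

    # Gradients flow through the float16 kernels as well.
    x_grad = paddle.grad(y.sum() + z.sum(), [x])[0]
    print(y.dtype, z.dtype, x_grad.dtype)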