From 84273aaa4e35eac3c6594770fa6cc35778ace83d Mon Sep 17 00:00:00 2001
From: Zhang Ting <zhangting_2017@163.com>
Date: Mon, 24 Oct 2022 15:51:11 +0800
Subject: [PATCH] fix cumsum compilation error for GPU architecture that does
 not support fast FP16 (#47277)

---
 paddle/phi/kernels/gpu/cum_kernel.cu | 90 +++++++++++++++-------------
 1 file changed, 47 insertions(+), 43 deletions(-)

diff --git a/paddle/phi/kernels/gpu/cum_kernel.cu b/paddle/phi/kernels/gpu/cum_kernel.cu
index a471790771e..0c6cd8b5562 100644
--- a/paddle/phi/kernels/gpu/cum_kernel.cu
+++ b/paddle/phi/kernels/gpu/cum_kernel.cu
@@ -34,18 +34,6 @@ namespace cub = hipcub;
 
 namespace phi {
 
-template <typename T>
-class CumTypeTrait {
- public:
-  using Type = T;
-};
-
-template <>
-class CumTypeTrait<phi::dtype::float16> {
- public:
-  using Type = __half;
-};
-
 template <typename T, int BLOCK_SIZE>
 __device__ void BlockReverse(
     const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
@@ -228,6 +216,51 @@ __global__ void BlockScanKernel(T* d_out,
   }
 }
 
+template <typename Context, typename T>
+typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value>::type
+ThrustCumsumKernel(const Context& dev_ctx,
+                   const T* in_data,
+                   T* out_data,
+                   int64_t size,
+                   bool reverse,
+                   bool exclusive) {
+#ifdef __HIPCC__
+  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
+#else
+  const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
+#endif
+  if (reverse) {
+    thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
+        thrust::device_pointer_cast(in_data) + size);
+    thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
+        thrust::device_pointer_cast(out_data) + size);
+    if (exclusive) {
+      thrust::exclusive_scan(
+          policy, reversed_in, reversed_in + size, reversed_out);
+    } else {
+      thrust::inclusive_scan(
+          policy, reversed_in, reversed_in + size, reversed_out);
+    }
+  } else {
+    if (exclusive) {
+      thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+    } else {
+      thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+    }
+  }
+
+  return;
+}
+
+template <typename Context, typename T>
+typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
+ThrustCumsumKernel(const Context& dev_ctx,
+                   const 
phi::dtype::float16* in_data,
+                   phi::dtype::float16* out_data,
+                   int64_t size,
+                   bool reverse,
+                   bool exclusive) {}
+
 template <typename T, typename Context, typename Op>
 void ScanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
@@ -260,37 +293,8 @@ void ScanKernel(const Context& dev_ctx,
   // length of the ‘axis’ dimension.
   if (!std::is_same<T, phi::dtype::float16>::value &&
       std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
-#ifdef __HIPCC__
-    const auto& policy = thrust::hip::par.on(dev_ctx.stream());
-#else
-    const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
-#endif
-
-    using CumType = typename CumTypeTrait<T>::Type;
-    CumType* out_data_ptr = reinterpret_cast<CumType*>(out_data);
-    const CumType* in_data_ptr = reinterpret_cast<const CumType*>(in_data);
-    if (reverse) {
-      thrust::reverse_iterator<thrust::device_ptr<const CumType>> reversed_in(
-          thrust::device_pointer_cast(in_data_ptr) + size);
-      thrust::reverse_iterator<thrust::device_ptr<CumType>> reversed_out(
-          thrust::device_pointer_cast(out_data_ptr) + size);
-      if (exclusive) {
-        thrust::exclusive_scan(
-            policy, reversed_in, reversed_in + size, reversed_out);
-      } else {
-        thrust::inclusive_scan(
-            policy, reversed_in, reversed_in + size, reversed_out);
-      }
-    } else {
-      if (exclusive) {
-        thrust::exclusive_scan(
-            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
-      } else {
-        thrust::inclusive_scan(
-            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
-      }
-    }
-
+    ThrustCumsumKernel<Context, T>(
+        dev_ctx, in_data, out_data, size, reverse, exclusive);
     return;
   }
 
-- 
GitLab