From c9a334e1b386b8d40cbca15562132e07aba623a0 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Fri, 15 Jan 2021 18:00:37 +0800 Subject: [PATCH] add VecCastCUDAKernel (#30296) --- paddle/fluid/operators/cast_op.cu | 49 +++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 55cc5a675b..13759633d0 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -19,6 +19,43 @@ limitations under the License. */ namespace paddle { namespace operators { +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; +}; + +template +inline int VectorizedSize(const T* pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec4 = std::alignment_of>::value; // NOLINT + if (address % vec4 == 0) { + return 4; + } + return 1; +} + +template +__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + using LoadT = AlignedVector; + using StoreT = AlignedVector; + for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) { + InT in_vec[VecSize]; + LoadT* in_value = reinterpret_cast(&in_vec); + *in_value = *reinterpret_cast(&in[i]); + + OutT out_vec[VecSize]; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + out_vec[ii] = static_cast(in_vec[ii]); + } + + *(reinterpret_cast(&out[i])) = + *reinterpret_cast(&out_vec[0]); + } +} + template __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast(in[index]); } @@ -40,8 +77,16 @@ struct CastOpFunctor { auto* out = out_->mutable_data(ctx_.GetPlace()); platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(ctx_, size); - CastCUDAKernel<<>>(in, size, out); + int vec_size = VectorizedSize(out); + if (!std::is_same::value && vec_size == 4 && size % 4 == 0) { + VecCastCUDAKernel<<< + config.block_per_grid, config.thread_per_block, 0, ctx_.stream()>>>( + in, size, out); + } else { + CastCUDAKernel<<>>( + in, size, out); + } } }; -- GitLab