From 2e6548a9cd2224e1a4b89c1351f1078273f98328 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 2 Mar 2022 18:40:00 +0800
Subject: [PATCH] vec scale kernel (#40011)

---
 .../optimizers/distributed_fused_lamb_op.cu   | 49 +++++++++++++++----
 1 file changed, 39 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
index ca0828a6f6a..8bb4606ffff 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
@@ -304,14 +304,30 @@ struct AndFunctor {
   HOSTDEVICE bool operator()(bool x, bool y) const { return x && y; }
 };
 
-template <typename T1, typename T2>
+template <typename T1, typename T2, int VecSize>
 static __global__ void ScaleCUDAKernel(const T1 *__restrict__ x,
                                        const T2 *__restrict__ scale,
                                        T1 *__restrict__ y, int num) {
   static_assert(sizeof(T1) <= sizeof(T2),
                 "sizeof(T1) must be not greater than sizeof(T2).");
   T2 s = scale[0];
-  CUDA_KERNEL_LOOP(i, num) {
+
+  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  int stride = blockDim.x * gridDim.x * VecSize;
+
+  for (; i + VecSize <= num; i += stride) {
+    platform::AlignedVector<T1, VecSize> x_vec;
+    platform::AlignedVector<T1, VecSize> y_vec;
+
+    platform::Load(x + i, &x_vec);
+#pragma unroll
+    for (int j = 0; j < VecSize; ++j) {
+      y_vec[j] = static_cast<T1>(static_cast<T2>(x_vec[j]) * s);
+    }
+    platform::Store(y_vec, y + i);
+  }
+
+  for (; i < num; ++i) {
     y[i] = static_cast<T1>(static_cast<T2>(x[i]) * s);
   }
 }
@@ -396,7 +412,6 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
   for (; i + VecSize <= num; i += stride) {
     platform::AlignedVector<T, VecSize> param_vec;
     platform::AlignedVector<GradT, VecSize> grad_vec;
-    platform::AlignedVector<T, VecSize> weight_decay_vec;
     platform::AlignedVector<T, VecSize> mom1_vec;
     platform::AlignedVector<T, VecSize> mom2_vec;
     platform::AlignedVector<T, VecSize> trust_ratio_div_vec;
@@ -760,6 +775,24 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
   return false;
 }
 
+template <typename T1, typename T2>
+static void LaunchScaleKernel(const platform::CUDADeviceContext &dev_ctx,
+                              const T1 *x, const T2 *scale, T1 *y, int n,
+                              gpuStream_t stream) {
+  int vec_size = std::min(GetChunkedVecSize(x, 0), GetChunkedVecSize(y, 0));
+  auto config = platform::GetGpuLaunchConfig1D(dev_ctx, n, vec_size);
+
+#define PD_LAMB_VEC_SCALE_KERNEL_CASE                                    \
+  do {                                                                   \
+    ScaleCUDAKernel<T1, T2, kVecSize><<<                                 \
+        config.block_per_grid, config.thread_per_block, 0, stream>>>(    \
+        x, scale, y, n);                                                 \
+  } while (0)
+
+  PD_VEC_LAUNCH_KERNEL(vec_size, PD_LAMB_VEC_SCALE_KERNEL_CASE);
+#undef PD_LAMB_VEC_SCALE_KERNEL_CASE
+}
+
 template <typename T>
 static void NCCLReduceScatterWithScale(
     const T *sendbuff, T *recvbuff, size_t recvcount, size_t nranks,
@@ -775,10 +808,8 @@ static void NCCLReduceScatterWithScale(
       PADDLE_ENFORCE_EQ(nranks, 1,
                         platform::errors::InvalidArgument(
                             "nranks must be 1 when scale != nullptr."));
-      auto numel = recvcount * nranks;
-      auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-      ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                        stream>>>(sendbuff, scale, recvbuff, numel);
+      LaunchScaleKernel(dev_ctx, sendbuff, scale, recvbuff, recvcount * nranks,
+                        stream);
     }
     return;
   }
@@ -792,9 +823,7 @@ static void NCCLReduceScatterWithScale(
   if (scale && !should_destroy_op) {
     size_t numel = recvcount * nranks;
     T *new_sendbuff = buffer.Alloc<T>(numel);
-    auto config = platform::GetGpuLaunchConfig1D(dev_ctx, numel);
-    ScaleCUDAKernel<<<config.block_per_grid, config.thread_per_block, 0,
-                      stream>>>(sendbuff, scale, new_sendbuff, numel);
+    LaunchScaleKernel(dev_ctx, sendbuff, scale, new_sendbuff, numel, stream);
     sendbuff = new_sendbuff;
   }
 
-- 
GitLab
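
Note on the technique: the heart of this patch is the hand-vectorized grid-stride loop in ScaleCUDAKernel. Each thread loads VecSize elements with one aligned vector access, scales them, stores them back, and a scalar tail loop covers the remainder when num is not a multiple of VecSize. Below is a minimal, self-contained CUDA sketch of the same pattern for readers without the Paddle tree at hand. It is hypothetical stand-in code, not the patch itself: a plain alignas struct replaces platform::AlignedVector / platform::Load / platform::Store, a fixed launch configuration replaces platform::GetGpuLaunchConfig1D, and ScaleKernel, kN, and kVecSize are invented names.

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for platform::AlignedVector: an aligned array lets the compiler
// emit one wide load/store (e.g. 128 bits for float with VecSize = 4).
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) AlignedVector {
  T val[VecSize];
};

template <typename T, int VecSize>
__global__ void ScaleKernel(const T *__restrict__ x,
                            const T *__restrict__ scale,
                            T *__restrict__ y, int num) {
  T s = scale[0];
  int i = (threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
  int stride = blockDim.x * gridDim.x * VecSize;

  // Vectorized main loop: one aligned load and one aligned store per
  // VecSize elements instead of VecSize scalar accesses.
  for (; i + VecSize <= num; i += stride) {
    auto x_vec = *reinterpret_cast<const AlignedVector<T, VecSize> *>(x + i);
    AlignedVector<T, VecSize> y_vec;
#pragma unroll
    for (int j = 0; j < VecSize; ++j) {
      y_vec.val[j] = x_vec.val[j] * s;
    }
    *reinterpret_cast<AlignedVector<T, VecSize> *>(y + i) = y_vec;
  }

  // Scalar tail: at most one thread leaves the loop above with i < num,
  // and it finishes the last (num % VecSize) elements one by one.
  for (; i < num; ++i) {
    y[i] = x[i] * s;
  }
}

int main() {
  constexpr int kN = 1003;  // deliberately not a multiple of kVecSize
  constexpr int kVecSize = 4;
  float *x, *y, *scale;
  cudaMallocManaged(&x, kN * sizeof(float));
  cudaMallocManaged(&y, kN * sizeof(float));
  cudaMallocManaged(&scale, sizeof(float));
  for (int i = 0; i < kN; ++i) x[i] = 1.0f;
  scale[0] = 2.5f;

  int threads = 256;
  int blocks = (kN + threads * kVecSize - 1) / (threads * kVecSize);
  ScaleKernel<float, kVecSize><<<blocks, threads>>>(x, scale, y, kN);
  cudaDeviceSynchronize();
  printf("y[0] = %f, y[%d] = %f\n", y[0], kN - 1, y[kN - 1]);  // both 2.500000
  cudaFree(x);
  cudaFree(y);
  cudaFree(scale);
  return 0;
}

The sketch drops one detail of the original for brevity: ScaleCUDAKernel performs the multiply in the wider type T2 (static_cast<T2>(x_vec[j]) * s) before casting back to T1, which matters when T1 is a reduced-precision type such as half.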
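A second point worth calling out: VecSize is a compile-time template parameter, but LaunchScaleKernel picks vec_size at run time from pointer alignment (via GetChunkedVecSize), so PD_VEC_LAUNCH_KERNEL has to bridge the two by instantiating the kernel once per supported width. That macro's definition is outside this diff; the following is a plausible, hypothetical shape for the same bridge, written as a plain switch over the ScaleKernel sketch above (LaunchScale and max_vec are invented names).

#include <algorithm>  // std::min
#include <cstdint>    // uintptr_t

template <typename T>
void LaunchScale(const T *x, const T *scale, T *y, int n, cudaStream_t stream) {
  // A pointer only supports a given vector width if it is aligned to it;
  // GetChunkedVecSize in the patch performs a similar alignment check.
  auto max_vec = [](const void *p) {
    auto addr = reinterpret_cast<uintptr_t>(p);
    if (addr % (4 * sizeof(T)) == 0) return 4;
    if (addr % (2 * sizeof(T)) == 0) return 2;
    return 1;
  };
  int vec_size = std::min(max_vec(x), max_vec(y));

  int threads = 256;
  int blocks = (n + threads * vec_size - 1) / (threads * vec_size);
  // The run-time value selects one of several compile-time instantiations.
  switch (vec_size) {
    case 4:
      ScaleKernel<T, 4><<<blocks, threads, 0, stream>>>(x, scale, y, n);
      break;
    case 2:
      ScaleKernel<T, 2><<<blocks, threads, 0, stream>>>(x, scale, y, n);
      break;
    default:
      ScaleKernel<T, 1><<<blocks, threads, 0, stream>>>(x, scale, y, n);
      break;
  }
}

Each case is a separate kernel instantiation, which is presumably why the patch funnels both launch sites in NCCLReduceScatterWithScale through the single LaunchScaleKernel helper rather than repeating the dispatch inline.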