From 6f684bd2ee5de97053292a7bf648419273e671c3 Mon Sep 17 00:00:00 2001
From: Shaojie WANG
Date: Tue, 25 Apr 2023 00:03:56 -0700
Subject: [PATCH] fix shared memory over usage in embedding grad kernel on
 deterministic mode (#53247)

* fix shared memory over usage in embedding grad kernel on deterministic mode

* use IdT as integer dtype
---
 .../phi/kernels/gpu/embedding_grad_kernel.cu | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
index cb34f5844b0..4771dd15dd2 100644
--- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -18,6 +18,7 @@
 #include "glog/logging.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -74,16 +75,14 @@ __global__ void EmbeddingGrad(T* table,
 }
 
 template <typename T, typename IdT>
-__global__ void EmbeddingGradDeterministic(T* table,
-                                           const T* output,
-                                           const IdT* ids,
-                                           const int64_t K,
-                                           const int64_t D) {
+__global__ void EmbeddingGradDeterministic(
+    T* table, const T* output, const IdT* ids, const IdT K, const IdT D) {
+  using MT = typename dtype::MPTypeTrait<T>::Type;
   extern __shared__ char buf[];
-  T* smem = reinterpret_cast<T*>(buf);
-  T* my_s = smem + WARP_SIZE * threadIdx.y;
-  int64_t* indices_batch =
-      reinterpret_cast<int64_t*>(buf + sizeof(T) * WARP_SIZE * BLOCKDIMY);
+  MT* smem = reinterpret_cast<MT*>(buf);
+  MT* my_s = smem + WARP_SIZE * threadIdx.y;
+  IdT* indices_batch =
+      reinterpret_cast<IdT*>(buf + sizeof(MT) * WARP_SIZE * BLOCKDIMY);
 
   const int stride = static_cast<int>(D);
 
@@ -97,10 +96,10 @@ __global__ void EmbeddingGradDeterministic(T* table,
        batch_start += WARP_SIZE * BLOCKDIMY) {
     int tid = threadIdx.x + threadIdx.y * WARP_SIZE;
     if (batch_start + tid < K)
-      indices_batch[tid] = static_cast<int64_t>(ids[batch_start + tid]);
+      indices_batch[tid] = static_cast<IdT>(ids[batch_start + tid]);
 
     int batch_end =
-        min(static_cast<int64_t>(batch_start + WARP_SIZE * BLOCKDIMY), K);
+        min(static_cast<IdT>(batch_start + WARP_SIZE * BLOCKDIMY), K);
 
     // Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
     for (int chunk_start = batch_start; chunk_start < batch_end;
@@ -112,10 +111,10 @@ __global__ void EmbeddingGradDeterministic(T* table,
 
       int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY);
 
-      int64_t src_row = static_cast<int64_t>(chunk_start + threadIdx.y);
-      int64_t dst_row = indices_batch[src_row - batch_start];
+      IdT src_row = static_cast<IdT>(chunk_start + threadIdx.y);
+      IdT dst_row = indices_batch[src_row - batch_start];
       if (src_row < K && feature < stride)
-        my_s[threadIdx.x] = static_cast<T>(output[src_row * D + feature]);
+        my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
 
       __syncthreads();
 
@@ -202,11 +201,12 @@ struct EmbeddingGradCUDAFunctor {
     if (FLAGS_embedding_deterministic) {
       dim3 threads(WARP_SIZE, BLOCKDIMY);
       dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
+      using MT = typename dtype::MPTypeTrait<T>::Type;
       EmbeddingGradDeterministic<T, IdT>
          <<<grids,
             threads,
-            sizeof(T) * WARP_SIZE * BLOCKDIMY +
-                sizeof(int64_t) * WARP_SIZE * BLOCKDIMY,
+            sizeof(MT) * WARP_SIZE * BLOCKDIMY +
+                sizeof(IdT) * WARP_SIZE * BLOCKDIMY,
             dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
     } else {
       const int gridx = 2 * dev_ctx_.GetSMCount();
--
GitLab
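
A note on the shared-memory layout this patch establishes (an illustrative aside, not part of the patch itself): inside EmbeddingGradDeterministic the dynamic shared-memory buffer is now laid out as WARP_SIZE * BLOCKDIMY accumulator slots of type MT (the MPTypeTrait accumulation type, e.g. float when T is float16) followed by WARP_SIZE * BLOCKDIMY index slots of type IdT, instead of T followed by int64_t. The C++ sketch below only works out the byte size of that layout; DeterministicSmemBytes, kWarpSize and kBlockDimY are hypothetical names and example values chosen for this note, not identifiers from the Paddle source.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Bytes of dynamic shared memory for one thread block under the patched
// layout: one MT accumulator slot per thread (warp_size * blockdim_y threads)
// followed by one IdT index slot per thread.
template <typename MT, typename IdT>
constexpr std::size_t DeterministicSmemBytes(int warp_size, int blockdim_y) {
  return sizeof(MT) * warp_size * blockdim_y +
         sizeof(IdT) * warp_size * blockdim_y;
}

int main() {
  constexpr int kWarpSize = 32;   // illustrative value, not taken from the file
  constexpr int kBlockDimY = 32;  // illustrative value, not taken from the file
  // float16 gradients accumulate in float (MT); ids may be int32 or int64.
  std::printf("MT=float, IdT=int32_t: %zu bytes\n",
              DeterministicSmemBytes<float, std::int32_t>(kWarpSize, kBlockDimY));
  std::printf("MT=float, IdT=int64_t: %zu bytes\n",
              DeterministicSmemBytes<float, std::int64_t>(kWarpSize, kBlockDimY));
  return 0;
}

With these example values the two calls print 8192 and 12288 bytes. The design point of the patch is that the reinterpret_cast offsets inside the kernel and the byte count requested at launch are now derived from the same element types (MT and IdT), so the two sides stay consistent.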