diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
index d04e221a53816aa852ea42bfd55529994e9b763d..5a3715f454980e56e06f697c7335104eb761d28d 100644
--- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -35,7 +35,7 @@ __global__ void InputTypeConvert(const InT* in_ids,
   }
 }
 
-template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
+template <typename T, typename IdT>
 __global__ void LookupTableV2Grad(T* table,
                                   const T* output,
                                   const IdT* ids,
@@ -43,16 +43,20 @@ __global__ void LookupTableV2Grad(T* table,
                                   const int64_t K,
                                   const int64_t D) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
   while (idy < K) {
     auto id = static_cast<int64_t>(ids[idy]);
     const T* out = output + idy * D;
     T* tab = table + id * D;
-    for (int i = idx; i < D; i += BlockDimX) {
+#ifdef PADDLE_WITH_CUDA
+    paddle::platform::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab);
+#else
+    for (int i = idx; i < D; i += blockDim.x) {
       paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
     }
-    idy += BlockDimY * GridDimX;
+#endif
+    idy += blockDim.y * gridDim.x;
   }
 }
 
@@ -83,20 +87,22 @@ struct LookupTableV2GradCUDAFunctor {
       int D = weight_grad_->dims()[1];
       int K = input_.numel();
 
-      dim3 threads(128, 8);
-      dim3 grids(8, 1);
       const T* d_output = d_output_t.template data<T>();
       const auto* ids = input_.template data<IdT>();
       T* d_table = d_table_t->mutable_data<T>(dev_ctx_.GetPlace());
 
-      auto t = EigenVector<T>::Flatten(*d_table_t);
-      t.device(*dev_ctx_.eigen_device()) = t.constant(static_cast<T>(0));
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          hipMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#endif
 
-      LookupTableV2Grad<T,
-                        IdT,
-                        128,
-                        8,
-                        8><<<grids, threads, 0, dev_ctx_.stream()>>>(
+      const int gridx = 2 * dev_ctx_.GetSMCount();
+      dim3 threads(128, 8);
+      dim3 grids(gridx, 1);
+      LookupTableV2Grad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
          d_table, d_output, ids, N, K, D);
     }
   }
diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu
index 6830af163e2ce8699dceec6f6e8a1b3ccd6b35f3..0f66dbf59151b8991eb7e70a064e62b16287466a 100644
--- a/paddle/phi/kernels/gpu/embedding_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_kernel.cu
@@ -23,12 +23,7 @@
 
 namespace phi {
 
-template <typename T,
-          typename IdT,
-          int BlockDimX,
-          int BlockDimY,
-          int GridDimX,
-          bool PaddingFlag>
+template <typename T, typename IdT, bool PaddingFlag>
 __global__ void LookupTableV2(T *output,
                               const T *table,
                               const IdT *ids,
@@ -37,13 +32,13 @@ __global__ void LookupTableV2(T *output,
                               const int64_t D,
                               const int64_t padding_idx) {
   int idx = threadIdx.x;
-  int idy = blockIdx.x + threadIdx.y * GridDimX;
+  int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
   while (idy < K) {
     auto id = static_cast<int64_t>(ids[idy]);
     T *out = output + idy * D;
     const T *tab = table + id * D;
-    for (int i = idx; i < D; i += BlockDimX) {
+    for (int i = idx; i < D; i += blockDim.x) {
       if (PaddingFlag) {
         if (id == padding_idx)
           out[i] = static_cast<T>(0);
@@ -53,7 +48,7 @@ __global__ void LookupTableV2(T *output,
          out[i] = tab[i];
      }
    }
-    idy += BlockDimY * GridDimX;
+    idy += blockDim.y * gridDim.x;
  }
 }
 
@@ -76,19 +71,20 @@ struct LookupTableV2CUDAFunctor {
     size_t D = weight_.dims()[1];
     size_t K = input_.numel();
 
+    const int gridx = 2 * dev_ctx_.GetSMCount();
     dim3 threads(256, 4);
-    dim3 grids(80, 1);
+    dim3 grids(gridx, 1);
 
-    const auto *table = weight_.template data<T>();
-    const auto *ids = input_.template data<IdT>();
+    const T *table = weight_.template data<T>();
+    const IdT *ids = input_.template data<IdT>();
     auto *output = out_->template mutable_data<T>(dev_ctx_.GetPlace());
     auto stream = dev_ctx_.stream();
 
     if (padding_idx_ == -1) {
-      LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
+      LookupTableV2<T, IdT, false><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
    } else {
-      LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
+      LookupTableV2<T, IdT, true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
    }
  }
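
For context on the pattern above (this note and the sketch below are not part of the patch): the rewritten kernels read their launch geometry from blockDim and gridDim at runtime instead of the former BlockDimX/BlockDimY/GridDimX template parameters, and the host side now sizes gridDim.x to twice the SM count via dev_ctx_.GetSMCount() rather than a hard-coded 8 or 80 blocks. A minimal standalone CUDA sketch of that scheme, where each row is handled by a (blockIdx.x, threadIdx.y) pair and columns are strided by threadIdx.x, is given below; the kernel name Copy2D, the sizes, and the host driver are illustrative assumptions, not Paddle APIs.

#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Row-wise grid-stride copy using runtime blockDim/gridDim, mirroring the
// indexing used by LookupTableV2/LookupTableV2Grad after this change.
__global__ void Copy2D(float *dst, const float *src, int64_t rows, int64_t cols) {
  int idx = threadIdx.x;                            // column lane within the block
  int idy = blockIdx.x + threadIdx.y * gridDim.x;   // starting row for this thread
  while (idy < rows) {
    const float *in = src + idy * cols;
    float *out = dst + idy * cols;
    for (int i = idx; i < cols; i += blockDim.x) {  // stride over columns by blockDim.x
      out[i] = in[i];
    }
    idy += blockDim.y * gridDim.x;                  // grid-stride step over rows
  }
}

int main() {
  const int64_t rows = 1000, cols = 64;
  std::vector<float> h(rows * cols, 1.0f);
  float *src, *dst;
  cudaMalloc(&src, rows * cols * sizeof(float));
  cudaMalloc(&dst, rows * cols * sizeof(float));
  cudaMemcpy(src, h.data(), rows * cols * sizeof(float), cudaMemcpyHostToDevice);

  // Mirror the patch: grid.x = 2 * SM count instead of a fixed block count.
  int dev = 0, sm_count = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  dim3 threads(128, 8);
  dim3 grids(2 * sm_count, 1);
  Copy2D<<<grids, threads>>>(dst, src, rows, cols);

  cudaMemcpy(h.data(), dst, rows * cols * sizeof(float), cudaMemcpyDeviceToHost);
  printf("h[0] = %f\n", h[0]);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

Sizing the grid from the SM count keeps enough blocks in flight to occupy the device while letting the while loop absorb any number of rows, which is the same trade-off the patch makes for both the forward and gradient kernels.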