未验证 提交 6f684bd2 编写于 作者: shaojie_wang's avatar shaojie_wang 提交者: GitHub

fix shared memory over usage in embedding grad kernel on deterministic mode (#53247)

* fix shared memory over usage in embedding grad kernel on determistic mode

* use IdT as interger dtype
上级 ed45ecc6
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
...@@ -74,16 +75,14 @@ __global__ void EmbeddingGrad(T* table, ...@@ -74,16 +75,14 @@ __global__ void EmbeddingGrad(T* table,
} }
template <typename T, typename IdT> template <typename T, typename IdT>
__global__ void EmbeddingGradDeterministic(T* table, __global__ void EmbeddingGradDeterministic(
const T* output, T* table, const T* output, const IdT* ids, const IdT K, const IdT D) {
const IdT* ids, using MT = typename dtype::MPTypeTrait<T>::Type;
const int64_t K,
const int64_t D) {
extern __shared__ char buf[]; extern __shared__ char buf[];
T* smem = reinterpret_cast<T*>(buf); MT* smem = reinterpret_cast<MT*>(buf);
T* my_s = smem + WARP_SIZE * threadIdx.y; MT* my_s = smem + WARP_SIZE * threadIdx.y;
int64_t* indices_batch = IdT* indices_batch =
reinterpret_cast<int64_t*>(buf + sizeof(T) * WARP_SIZE * BLOCKDIMY); reinterpret_cast<IdT*>(buf + sizeof(MT) * WARP_SIZE * BLOCKDIMY);
const int stride = static_cast<int>(D); const int stride = static_cast<int>(D);
...@@ -97,10 +96,10 @@ __global__ void EmbeddingGradDeterministic(T* table, ...@@ -97,10 +96,10 @@ __global__ void EmbeddingGradDeterministic(T* table,
batch_start += WARP_SIZE * BLOCKDIMY) { batch_start += WARP_SIZE * BLOCKDIMY) {
int tid = threadIdx.x + threadIdx.y * WARP_SIZE; int tid = threadIdx.x + threadIdx.y * WARP_SIZE;
if (batch_start + tid < K) if (batch_start + tid < K)
indices_batch[tid] = static_cast<int64_t>(ids[batch_start + tid]); indices_batch[tid] = static_cast<IdT>(ids[batch_start + tid]);
int batch_end = int batch_end =
min(static_cast<int64_t>(batch_start + WARP_SIZE * BLOCKDIMY), K); min(static_cast<IdT>(batch_start + WARP_SIZE * BLOCKDIMY), K);
// Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY // Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
for (int chunk_start = batch_start; chunk_start < batch_end; for (int chunk_start = batch_start; chunk_start < batch_end;
...@@ -112,10 +111,10 @@ __global__ void EmbeddingGradDeterministic(T* table, ...@@ -112,10 +111,10 @@ __global__ void EmbeddingGradDeterministic(T* table,
int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY); int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY);
int64_t src_row = static_cast<int64_t>(chunk_start + threadIdx.y); IdT src_row = static_cast<IdT>(chunk_start + threadIdx.y);
int64_t dst_row = indices_batch[src_row - batch_start]; IdT dst_row = indices_batch[src_row - batch_start];
if (src_row < K && feature < stride) if (src_row < K && feature < stride)
my_s[threadIdx.x] = static_cast<T>(output[src_row * D + feature]); my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
__syncthreads(); __syncthreads();
...@@ -202,11 +201,12 @@ struct EmbeddingGradCUDAFunctor { ...@@ -202,11 +201,12 @@ struct EmbeddingGradCUDAFunctor {
if (FLAGS_embedding_deterministic) { if (FLAGS_embedding_deterministic) {
dim3 threads(WARP_SIZE, BLOCKDIMY); dim3 threads(WARP_SIZE, BLOCKDIMY);
dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE)); dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
using MT = typename dtype::MPTypeTrait<T>::Type;
EmbeddingGradDeterministic<T, IdT> EmbeddingGradDeterministic<T, IdT>
<<<grids, <<<grids,
threads, threads,
sizeof(T) * WARP_SIZE * BLOCKDIMY + sizeof(MT) * WARP_SIZE * BLOCKDIMY +
sizeof(int) * WARP_SIZE * BLOCKDIMY, sizeof(IdT) * WARP_SIZE * BLOCKDIMY,
dev_ctx_.stream()>>>(d_table, d_output, ids, K, D); dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
} else { } else {
const int gridx = 2 * dev_ctx_.GetSMCount(); const int gridx = 2 * dev_ctx_.GetSMCount();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册