Unverified commit 6f684bd2, authored by shaojie_wang, committed by GitHub

fix shared memory over-usage in embedding grad kernel in deterministic mode (#53247)

* fix shared memory over-usage in embedding grad kernel in deterministic mode

* use IdT as integer dtype
Parent ed45ecc6
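For context on the fix: the deterministic kernel carves its dynamic shared memory into two regions, an accumulation buffer followed by an index buffer. The launch previously requested sizeof(T) bytes per accumulation slot and sizeof(int) bytes per index slot, while the kernel body stored int64_t indices (and now accumulates in the wider MT type from MPTypeTrait), so the index region could overrun the allocation. The host-side sketch below is illustrative only, not part of the patch; it reuses the kernel's WARP_SIZE/BLOCKDIMY names with the CUDA values (WARP_SIZE is 64 on HIP) and picks float/int64_t element types purely for illustration.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Tile shape used by the kernel: a block of WARP_SIZE x BLOCKDIMY threads
// loads up to WARP_SIZE * BLOCKDIMY indices per batch (1024 on CUDA).
constexpr int WARP_SIZE = 32;
constexpr int BLOCKDIMY = 32;

int main() {
  // Old launch parameter: one T slot plus one int slot per thread.
  std::size_t old_request = sizeof(float) * WARP_SIZE * BLOCKDIMY +
                            sizeof(int) * WARP_SIZE * BLOCKDIMY;

  // What the old kernel body actually wrote: indices_batch was int64_t,
  // so the index region needed twice the bytes that were requested.
  std::size_t old_usage = sizeof(float) * WARP_SIZE * BLOCKDIMY +
                          sizeof(int64_t) * WARP_SIZE * BLOCKDIMY;

  // New scheme: size the request from the types the kernel actually stores,
  // MT (MPTypeTrait<T>::Type, e.g. float for float16) and IdT (the ids dtype).
  using MT = float;
  using IdT = int64_t;
  std::size_t new_request = sizeof(MT) * WARP_SIZE * BLOCKDIMY +
                            sizeof(IdT) * WARP_SIZE * BLOCKDIMY;

  std::printf("old request = %zu B, actual use = %zu B, new request = %zu B\n",
              old_request, old_usage, new_request);
  return 0;
}

After the change, both the launch-side byte count and the kernel's reinterpret_casts are derived from the same MT/IdT types, so the requested and actually used shared memory can no longer diverge.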
@@ -18,6 +18,7 @@
#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
@@ -74,16 +75,14 @@ __global__ void EmbeddingGrad(T* table,
}
template <typename T, typename IdT>
-__global__ void EmbeddingGradDeterministic(T* table,
-                                           const T* output,
-                                           const IdT* ids,
-                                           const int64_t K,
-                                           const int64_t D) {
+__global__ void EmbeddingGradDeterministic(
+    T* table, const T* output, const IdT* ids, const IdT K, const IdT D) {
+  using MT = typename dtype::MPTypeTrait<T>::Type;
extern __shared__ char buf[];
-  T* smem = reinterpret_cast<T*>(buf);
-  T* my_s = smem + WARP_SIZE * threadIdx.y;
-  int64_t* indices_batch =
-      reinterpret_cast<int64_t*>(buf + sizeof(T) * WARP_SIZE * BLOCKDIMY);
+  MT* smem = reinterpret_cast<MT*>(buf);
+  MT* my_s = smem + WARP_SIZE * threadIdx.y;
+  IdT* indices_batch =
+      reinterpret_cast<IdT*>(buf + sizeof(MT) * WARP_SIZE * BLOCKDIMY);
const int stride = static_cast<int>(D);
@@ -97,10 +96,10 @@ __global__ void EmbeddingGradDeterministic(T* table,
batch_start += WARP_SIZE * BLOCKDIMY) {
int tid = threadIdx.x + threadIdx.y * WARP_SIZE;
if (batch_start + tid < K)
-      indices_batch[tid] = static_cast<int64_t>(ids[batch_start + tid]);
+      indices_batch[tid] = static_cast<IdT>(ids[batch_start + tid]);
int batch_end =
-        min(static_cast<int64_t>(batch_start + WARP_SIZE * BLOCKDIMY), K);
+        min(static_cast<IdT>(batch_start + WARP_SIZE * BLOCKDIMY), K);
// Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
for (int chunk_start = batch_start; chunk_start < batch_end;
@@ -112,10 +111,10 @@ __global__ void EmbeddingGradDeterministic(T* table,
int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY);
-      int64_t src_row = static_cast<int64_t>(chunk_start + threadIdx.y);
-      int64_t dst_row = indices_batch[src_row - batch_start];
+      IdT src_row = static_cast<IdT>(chunk_start + threadIdx.y);
+      IdT dst_row = indices_batch[src_row - batch_start];
if (src_row < K && feature < stride)
-        my_s[threadIdx.x] = static_cast<T>(output[src_row * D + feature]);
+        my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
__syncthreads();
@@ -202,11 +201,12 @@ struct EmbeddingGradCUDAFunctor {
if (FLAGS_embedding_deterministic) {
dim3 threads(WARP_SIZE, BLOCKDIMY);
dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
+      using MT = typename dtype::MPTypeTrait<T>::Type;
EmbeddingGradDeterministic<T, IdT>
<<<grids,
threads,
-                 sizeof(T) * WARP_SIZE * BLOCKDIMY +
-                     sizeof(int) * WARP_SIZE * BLOCKDIMY,
+                 sizeof(MT) * WARP_SIZE * BLOCKDIMY +
+                     sizeof(IdT) * WARP_SIZE * BLOCKDIMY,
dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
} else {
const int gridx = 2 * dev_ctx_.GetSMCount();