Unverified commit 1f45b313, authored by shaojie_wang, committed by GitHub

cherry pick dev branch for embedding grad (#53332)

Parent 3f2f4040
......@@ -232,6 +232,22 @@ PADDLE_DEFINE_EXPORTED_bool(
"operator. The autotuning algorithm may be non-deterministic. If "
"true, the algorithm is deterministic.");
/**
* CUDA related FLAG
* Name: FLAGS_embedding_deterministic
* Since Version: 2.5
* Value Range: bool, default=false
* Example:
 * Note: whether to use a deterministic algorithm in the embedding op.
 * If true, a deterministic CUDA kernel is used in the embedding op.
*/
PADDLE_DEFINE_EXPORTED_bool(
embedding_deterministic,
false,
"Whether allow using an deterministic algorithm for embedding "
"operator. The deterministic algorithm may be slower. If "
"true, the algorithm is deterministic.");
/**
* CUDNN related FLAG
* Name: FLAGS_conv_workspace_size_limit
......
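Usage note (not part of this diff): flags declared with PADDLE_DEFINE_EXPORTED_bool can normally be toggled at runtime, for example by exporting FLAGS_embedding_deterministic=1 in the environment before launching, or from Python via paddle.set_flags({'FLAGS_embedding_deterministic': True}); the exact mechanism may vary across Paddle versions.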
......@@ -18,6 +18,7 @@
#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -25,10 +26,20 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_bool(cudnn_deterministic);
DECLARE_bool(embedding_deterministic);
namespace phi {
#ifdef PADDLE_WITH_HIP
#define WARP_SIZE 64
#define BLOCKDIMY 16
#else
#define WARP_SIZE 32
#define BLOCKDIMY 32
#endif
#define MASK 0xffffffff
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT* in_ids,
const int64_t K,
......@@ -63,6 +74,91 @@ __global__ void EmbeddingGrad(T* table,
}
}
template <typename T, typename IdT>
__global__ void EmbeddingGradDeterministic(
T* table, const T* output, const IdT* ids, const IdT K, const IdT D) {
using MT = typename dtype::MPTypeTrait<T>::Type;
extern __shared__ char buf[];
MT* smem = reinterpret_cast<MT*>(buf);
MT* my_s = smem + WARP_SIZE * threadIdx.y;
IdT* indices_batch =
reinterpret_cast<IdT*>(buf + sizeof(MT) * WARP_SIZE * BLOCKDIMY);
const int stride = static_cast<int>(D);
const int feature = threadIdx.x + blockIdx.x * WARP_SIZE;
// To ensure determinism: if multiple warps pulled grad data targeting the
// same dst_row, we elect the first warp in each matching group as the
// leader. Each leader warp serializes the accumulations targeting dst_row
// in shared memory, then adds the accumulated buffer to dst_row in table.
for (int batch_start = 0; batch_start < K;
batch_start += WARP_SIZE * BLOCKDIMY) {
int tid = threadIdx.x + threadIdx.y * WARP_SIZE;
if (batch_start + tid < K)
indices_batch[tid] = static_cast<IdT>(ids[batch_start + tid]);
int batch_end =
min(static_cast<IdT>(batch_start + WARP_SIZE * BLOCKDIMY), K);
// Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
for (int chunk_start = batch_start; chunk_start < batch_end;
chunk_start += BLOCKDIMY) {
      // This sync makes sure that indices_batch is ready and that match-group
      // leaders have finished their accumulations before other warps start
      // loading again.
__syncthreads();
int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY);
IdT src_row = static_cast<IdT>(chunk_start + threadIdx.y);
IdT dst_row = indices_batch[src_row - batch_start];
if (src_row < K && feature < stride)
my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
__syncthreads();
if (src_row < K) {
int match_found_this_thread = 0;
if (threadIdx.x < n_this_chunk) {
match_found_this_thread =
(dst_row ==
indices_batch[chunk_start - batch_start + threadIdx.x]);
}
#ifdef PADDLE_WITH_HIP
unsigned long long int matchmask = // NOLINT
__ballot(match_found_this_thread); // NOLINT
int first_remaining_peer = __ffsll(matchmask) - 1;
#else
// If and only if match_found_this_thread of the Nth thread is non-zero,
// set the Nth bit of matchmask to 1.
unsigned int matchmask = __ballot_sync(MASK, match_found_this_thread);
// Find the position of the first bit set to 1 in matchmask.
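        // Example: if lanes 2 and 5 report a match, matchmask == 0b100100;
        // __ffs returns 3, so first_remaining_peer == 2.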
int first_remaining_peer = __ffs(matchmask) - 1;
#endif
// select lowest-indexed warp as the leader
if (threadIdx.y == first_remaining_peer) {
          // Clear the leader's own (lowest) set bit in matchmask.
matchmask ^= (1 << first_remaining_peer);
while (matchmask) {
#ifdef PADDLE_WITH_HIP
first_remaining_peer = __ffsll(matchmask) - 1;
#else
first_remaining_peer = __ffs(matchmask) - 1;
#endif
my_s[threadIdx.x] +=
smem[threadIdx.x + WARP_SIZE * first_remaining_peer];
matchmask ^= (1 << first_remaining_peer);
}
if (feature < stride)
table[dst_row * D + feature] += static_cast<T>(my_s[threadIdx.x]);
}
}
}
}
}
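For orientation (not part of the patch): the short host-side sketch below is a minimal single-threaded reference for what the deterministic backward pass computes. Every row ids[k] of the table receives the sum of the matching upstream-gradient rows, accumulated in a fixed order; the kernel above reaches the same result with warp-level leader election instead of this loop order. The name ReferenceEmbeddingGrad and the assumption that table is pre-zeroed (the surrounding functor does this with cudaMemsetAsync) are illustrative only.

#include <cstdint>

// Reference only: `table` is [N, D] and pre-zeroed, `output` is the [K, D]
// upstream gradient, `ids` holds K row indices into `table`.
template <typename T, typename IdT>
void ReferenceEmbeddingGrad(
    T* table, const T* output, const IdT* ids, int64_t K, int64_t D) {
  for (int64_t k = 0; k < K; ++k) {
    const IdT row = ids[k];
    for (int64_t d = 0; d < D; ++d) {
      // Rows repeated in `ids` accumulate in index order, so the result is
      // reproducible from run to run.
      table[row * D + d] += output[k * D + d];
    }
  }
}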
template <typename T, typename Context>
struct EmbeddingGradCUDAFunctor {
EmbeddingGradCUDAFunctor(const Context& dev_ctx,
......@@ -102,17 +198,23 @@ struct EmbeddingGradCUDAFunctor {
cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
#endif
-    const int gridx = 2 * dev_ctx_.GetSMCount();
-    dim3 threads(128, 8);
-    dim3 grids(gridx, 1);
-    if (FLAGS_cudnn_deterministic) {
-      VLOG(2) << "Run grad kernel of embedding with single thread.";
-      grids.x = 1;
-      threads.y = 1;
+    if (FLAGS_embedding_deterministic) {
+      dim3 threads(WARP_SIZE, BLOCKDIMY);
+      dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
+      using MT = typename dtype::MPTypeTrait<T>::Type;
+      EmbeddingGradDeterministic<T, IdT>
+          <<<grids,
+             threads,
+             sizeof(MT) * WARP_SIZE * BLOCKDIMY +
+                 sizeof(IdT) * WARP_SIZE * BLOCKDIMY,
+             dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
+    } else {
+      const int gridx = 2 * dev_ctx_.GetSMCount();
+      dim3 threads(128, 8);
+      dim3 grids(gridx, 1);
+      EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
+          d_table, d_output, ids, N, K, D);
    }
-    EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
-        d_table, d_output, ids, N, K, D);
}
}
......
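A worked example of the deterministic launch geometry used above (not part of the patch; the gradient dtype, index type, and embedding width below are assumptions for illustration): the grid tiles the embedding width D in WARP_SIZE-wide slices, and dynamic shared memory holds one WARP_SIZE x BLOCKDIMY accumulation tile in MT plus the WARP_SIZE x BLOCKDIMY indices_batch buffer.

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // CUDA values from the patch; a HIP build would use 64 and 16 instead.
  constexpr int kWarpSize = 32;
  constexpr int kBlockDimY = 32;
  // Assumed for illustration: float16 gradients (MPTypeTrait maps them to
  // float) and int64_t ids, with embedding width D = 1024.
  constexpr int64_t D = 1024;
  constexpr std::size_t kMTBytes = sizeof(float);
  constexpr std::size_t kIdBytes = sizeof(int64_t);

  const int grid_x = static_cast<int>((D + kWarpSize - 1) / kWarpSize);
  const std::size_t smem_bytes =
      kMTBytes * kWarpSize * kBlockDimY +  // per-block accumulation tile
      kIdBytes * kWarpSize * kBlockDimY;   // indices_batch buffer
  // Prints grid_x=32 and smem_bytes=12288 (4096 + 8192), well under the
  // 48 KB default per-block shared-memory limit.
  std::printf("grid_x=%d smem_bytes=%zu\n", grid_x, smem_bytes);
  return 0;
}

Keeping the accumulation tile in MT (float for half-precision tables) is why MPTypeTrait appears both in the kernel and in the launch's shared-memory sizing.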