diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc
index ad2e38b70d0ed654df26aef51227e6cf9e08bd16..0b248a968785a5262b91a59eefaee55c309de9fa 100644
--- a/paddle/phi/core/flags.cc
+++ b/paddle/phi/core/flags.cc
@@ -232,6 +232,22 @@ PADDLE_DEFINE_EXPORTED_bool(
     "operator. The autotuning algorithm may be non-deterministic. If "
     "true, the algorithm is deterministic.");
 
+/**
+ * CUDA related FLAG
+ * Name: FLAGS_embedding_deterministic
+ * Since Version: 2.5
+ * Value Range: bool, default=false
+ * Example:
+ * Note: whether to use a deterministic algorithm in the embedding op.
+ *       If true, a deterministic CUDA kernel is used in the embedding op.
+ */
+PADDLE_DEFINE_EXPORTED_bool(
+    embedding_deterministic,
+    false,
+    "Whether to allow using a deterministic algorithm for the embedding "
+    "operator. The deterministic algorithm may be slower. If "
+    "true, the algorithm is deterministic.");
+
 /**
  * CUDNN related FLAG
  * Name: FLAGS_conv_workspace_size_limit
diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
index e2bcfa4d19eb0407575b5737d3f697a7bfb59ecf..4771dd15dd2967c8c566397d40ded6dc24178c56 100644
--- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -18,6 +18,7 @@
 #include "glog/logging.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -25,10 +26,20 @@
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-DECLARE_bool(cudnn_deterministic);
+DECLARE_bool(embedding_deterministic);
 
 namespace phi {
 
+#ifdef PADDLE_WITH_HIP
+#define WARP_SIZE 64
+#define BLOCKDIMY 16
+#else
+#define WARP_SIZE 32
+#define BLOCKDIMY 32
+#endif
+
+#define MASK 0xffffffff
+
 template <typename InT, typename OutT>
 __global__ void InputTypeConvert(const InT* in_ids,
                                  const int64_t K,
@@ -63,6 +74,91 @@ __global__ void EmbeddingGrad(T* table,
   }
 }
 
+template <typename T, typename IdT>
+__global__ void EmbeddingGradDeterministic(
+    T* table, const T* output, const IdT* ids, const IdT K, const IdT D) {
+  using MT = typename dtype::MPTypeTrait<T>::Type;
+  extern __shared__ char buf[];
+  MT* smem = reinterpret_cast<MT*>(buf);
+  MT* my_s = smem + WARP_SIZE * threadIdx.y;
+  IdT* indices_batch =
+      reinterpret_cast<IdT*>(buf + sizeof(MT) * WARP_SIZE * BLOCKDIMY);
+
+  const int stride = static_cast<int>(D);
+
+  const int feature = threadIdx.x + blockIdx.x * WARP_SIZE;
+
+  // To ensure determinism: if several warps pulled grad data targeting the
+  // same dst_row, we elect the first warp in each matching group as the
+  // leader. Each leader warp serializes the accumulates targeting dst_row
+  // in shared memory, then adds the accumulated buffer to dst_row in table.
+  for (int batch_start = 0; batch_start < K;
+       batch_start += WARP_SIZE * BLOCKDIMY) {
+    int tid = threadIdx.x + threadIdx.y * WARP_SIZE;
+    if (batch_start + tid < K)
+      indices_batch[tid] = static_cast<IdT>(ids[batch_start + tid]);
+
+    int batch_end =
+        min(static_cast<IdT>(batch_start + WARP_SIZE * BLOCKDIMY), K);
+
+    // Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
+    for (int chunk_start = batch_start; chunk_start < batch_end;
+         chunk_start += BLOCKDIMY) {
+      // This sync makes sure that indices_batch is ready and match-group
+      // leaders are done with their accumulates before other warps start
+      // loading again.
+      __syncthreads();
+
+      int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY);
+
+      IdT src_row = static_cast<IdT>(chunk_start + threadIdx.y);
+      IdT dst_row = indices_batch[src_row - batch_start];
+      if (src_row < K && feature < stride)
+        my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
+
+      __syncthreads();
+
+      if (src_row < K) {
+        int match_found_this_thread = 0;
+        if (threadIdx.x < n_this_chunk) {
+          match_found_this_thread =
+              (dst_row ==
+               indices_batch[chunk_start - batch_start + threadIdx.x]);
+        }
+#ifdef PADDLE_WITH_HIP
+        unsigned long long int matchmask =      // NOLINT
+            __ballot(match_found_this_thread);  // NOLINT
+        int first_remaining_peer = __ffsll(matchmask) - 1;
+#else
+        // If and only if match_found_this_thread of the Nth thread is
+        // non-zero, set the Nth bit of matchmask to 1.
+        unsigned int matchmask = __ballot_sync(MASK, match_found_this_thread);
+        // Find the position of the first bit set to 1 in matchmask.
+        int first_remaining_peer = __ffs(matchmask) - 1;
+#endif
+
+        // Select the lowest-indexed warp as the leader.
+        if (threadIdx.y == first_remaining_peer) {
+          // Clear the first set bit in matchmask.
+          matchmask ^= (1 << first_remaining_peer);
+          while (matchmask) {
+#ifdef PADDLE_WITH_HIP
+            first_remaining_peer = __ffsll(matchmask) - 1;
+#else
+            first_remaining_peer = __ffs(matchmask) - 1;
+#endif
+            my_s[threadIdx.x] +=
+                smem[threadIdx.x + WARP_SIZE * first_remaining_peer];
+            matchmask ^= (1 << first_remaining_peer);
+          }
+          if (feature < stride)
+            table[dst_row * D + feature] += static_cast<T>(my_s[threadIdx.x]);
+        }
+      }
+    }
+  }
+}
+
 template <typename T, typename Context>
 struct EmbeddingGradCUDAFunctor {
   EmbeddingGradCUDAFunctor(const Context& dev_ctx,
@@ -102,17 +198,23 @@ struct EmbeddingGradCUDAFunctor {
           cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
 #endif
 
-      const int gridx = 2 * dev_ctx_.GetSMCount();
-      dim3 threads(128, 8);
-      dim3 grids(gridx, 1);
-
-      if (FLAGS_cudnn_deterministic) {
-        VLOG(2) << "Run grad kernel of embedding with single thread.";
-        grids.x = 1;
-        threads.y = 1;
+      if (FLAGS_embedding_deterministic) {
+        dim3 threads(WARP_SIZE, BLOCKDIMY);
+        dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
+        using MT = typename dtype::MPTypeTrait<T>::Type;
+        EmbeddingGradDeterministic<T, IdT>
+            <<<grids,
+               threads,
+               sizeof(MT) * WARP_SIZE * BLOCKDIMY +
+                   sizeof(IdT) * WARP_SIZE * BLOCKDIMY,
+               dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
+      } else {
+        const int gridx = 2 * dev_ctx_.GetSMCount();
+        dim3 threads(128, 8);
+        dim3 grids(gridx, 1);
+        EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
+            d_table, d_output, ids, N, K, D);
       }
-      EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
-          d_table, d_output, ids, N, K, D);
     }
   }
 
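A minimal way to exercise the new flag from Python (a sketch, not part of the patch: it assumes the flag is reachable through `paddle.set_flags()` like other `PADDLE_DEFINE_EXPORTED` flags, and that setting the `FLAGS_embedding_deterministic` environment variable before launch works as well). With the deterministic kernel selected, two backward passes over identical data should produce bit-identical gradients:

```python
import paddle
import paddle.nn.functional as F

# Hypothetical repro: route embedding's backward through the deterministic kernel.
paddle.set_flags({"FLAGS_embedding_deterministic": True})

def embedding_grad(seed):
    paddle.seed(seed)
    ids = paddle.randint(0, 100, shape=[4096], dtype="int64")
    table = paddle.randn([100, 64], dtype="float32")
    table.stop_gradient = False
    F.embedding(ids, table).sum().backward()
    return table.grad.numpy()

# Same data, two runs: the gradients should match exactly, not just approximately.
assert (embedding_grad(0) == embedding_grad(0)).all()
```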