diff --git a/paddle/fluid/operators/collective/c_embedding_op.cu b/paddle/fluid/operators/collective/c_embedding_op.cu
index 8b521580c5cd51f958a963f2e74d2ff433661959..4861b5d26ab0f026563305dcda4fa32da1dd0409 100644
--- a/paddle/fluid/operators/collective/c_embedding_op.cu
+++ b/paddle/fluid/operators/collective/c_embedding_op.cu
@@ -18,8 +18,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/kernels/funcs/embedding_grad.h"
 
-DECLARE_bool(cudnn_deterministic);
+DECLARE_int64(embedding_deterministic);
 
 namespace paddle {
 namespace operators {
@@ -154,7 +155,6 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
     int D = d_table_t->dims()[1];
     int K = ids_t->numel();
 
-    const int64_t end_idx = start_idx + N;
     auto limit = K * D;
     int blocks = NumBlocks(limit);
     int threads = kNumCUDAThreads;
@@ -166,33 +166,64 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
     t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
 
     const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
-    if (FLAGS_cudnn_deterministic) {
-      VLOG(2) << "Run grad kernel of embedding with single thread.";
-      blocks = 1;
-    }
-    if (index_type == framework::proto::VarType::INT32) {
-      CEmbeddingGrad<T, int32_t>
-          <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
-                                                     d_output,
-                                                     ids_t->data<int32_t>(),
-                                                     K,
-                                                     D,
-                                                     N,
-                                                     start_idx,
-                                                     end_idx,
-                                                     limit);
-    } else if (index_type == framework::proto::VarType::INT64) {
-      CEmbeddingGrad<T, int64_t>
-          <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
-                                                     d_output,
-                                                     ids_t->data<int64_t>(),
-                                                     K,
-                                                     D,
-                                                     N,
-                                                     start_idx,
-                                                     end_idx,
-                                                     limit);
+    if (FLAGS_embedding_deterministic == 1) {
+      if (index_type == framework::proto::VarType::INT32) {
+        phi::funcs::LaunchEmbeddingGradDeterministicKernel(
+            dev_ctx,
+            ids_t->data<int32_t>(),
+            d_output,
+            d_table,
+            N,
+            D,
+            K,
+            start_idx);
+        return;
+      } else if (index_type == framework::proto::VarType::INT64) {
+        phi::funcs::LaunchEmbeddingGradDeterministicKernel(
+            dev_ctx,
+            ids_t->data<int64_t>(),
+            d_output,
+            d_table,
+            N,
+            D,
+            K,
+            start_idx);
+        return;
+      }
+    } else {
+      if (FLAGS_embedding_deterministic > 1) {
+        VLOG(2) << "Run grad kernel of embedding with single thread.";
+        blocks = 1;
+      }
+      const int64_t end_idx = start_idx + N;
+      if (index_type == framework::proto::VarType::INT32) {
+        CEmbeddingGrad<T, int32_t>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
+                                                       d_output,
+                                                       ids_t->data<int32_t>(),
+                                                       K,
+                                                       D,
+                                                       N,
+                                                       start_idx,
+                                                       end_idx,
+                                                       limit);
+        return;
+      } else if (index_type == framework::proto::VarType::INT64) {
+        CEmbeddingGrad<T, int64_t>
+            <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
+                                                       d_output,
+                                                       ids_t->data<int64_t>(),
+                                                       K,
+                                                       D,
+                                                       N,
+                                                       start_idx,
+                                                       end_idx,
+                                                       limit);
+        return;
+      }
     }
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "The data type of Input(Ids) must be int32 or int64."));
   }
 };
diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc
index 0b248a968785a5262b91a59eefaee55c309de9fa..c8da29d82b9fbe24bbdb77511e7bf3c1d6999bbb 100644
--- a/paddle/phi/core/flags.cc
+++ b/paddle/phi/core/flags.cc
@@ -236,17 +236,19 @@ PADDLE_DEFINE_EXPORTED_bool(
  * CUDA related FLAG
  * Name: FLAGS_embedding_deterministic
  * Since Version: 2.5
- * Value Range: bool, default=false
+ * Value Range: int64, default=0
  * Example:
  * Note: whether to use deterministic algorithm in embedding op.
- * If true, it will use deterministic CUDA kernel in embedding op.
+ * If it is 1, it will use the optimized deterministic CUDA kernel in
+ * embedding op. If it is 2, it will use the legacy deterministic
+ * CUDA kernel in embedding op.
 */
-PADDLE_DEFINE_EXPORTED_bool(
+PADDLE_DEFINE_EXPORTED_int64(
     embedding_deterministic,
-    false,
+    0,
     "Whether allow using an deterministic algorithm for embedding "
     "operator. The deterministic algorithm may be slower. If "
-    "true, the algorithm is deterministic.");
+    "it is larger than 0, the algorithm is deterministic.");
 
 /**
  * CUDNN related FLAG
diff --git a/paddle/phi/kernels/funcs/embedding_grad.h b/paddle/phi/kernels/funcs/embedding_grad.h
new file mode 100644
index 0000000000000000000000000000000000000000..3ad0f22c8e912d2124d589a8223a581e5aca88d2
--- /dev/null
+++ b/paddle/phi/kernels/funcs/embedding_grad.h
@@ -0,0 +1,167 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename T, typename IdT, int WarpSize, int BlockDimY, bool UseLimit>
+__global__ void EmbeddingGradDeterministicKernel(T* table,
+                                                 const T* output,
+                                                 const IdT* ids,
+                                                 const int64_t K,
+                                                 const int64_t D,
+                                                 const int64_t start_idx,
+                                                 const int64_t end_idx) {
+  using MT = typename dtype::MPTypeTrait<T>::Type;
+  constexpr int64_t kInvalidId = -1;
+  extern __shared__ char buf[];
+  MT* smem = reinterpret_cast<MT*>(buf);
+  MT* my_s = smem + WarpSize * threadIdx.y;
+  IdT* indices_batch =
+      reinterpret_cast<IdT*>(buf + sizeof(MT) * WarpSize * BlockDimY);
+
+  const int stride = static_cast<int>(D);
+
+  const int feature = threadIdx.x + blockIdx.x * WarpSize;
+
+  // To ensure determinism: if several warps pulled grad data targeting the
+  // same dst_row, we elect the first warp in each matching group as the
+  // leader. Each leader warp serializes the accumulates targeting dst_row
+  // in shared memory, then adds the accumulated buffer to dst_row in table.
+  for (int batch_start = 0; batch_start < K;
+       batch_start += WarpSize * BlockDimY) {
+    int tid = threadIdx.x + threadIdx.y * WarpSize;
+    if (batch_start + tid < K) {
+      int64_t cur_id = static_cast<int64_t>(ids[batch_start + tid]);
+      if (UseLimit) {
+        if (cur_id >= start_idx && cur_id < end_idx) {
+          cur_id -= start_idx;
+        } else {
+          cur_id = kInvalidId;
+        }
+      }
+      indices_batch[tid] = cur_id;
+    }
+
+    int batch_end =
+        min(static_cast<int64_t>(batch_start + WarpSize * BlockDimY), K);
+
+    // Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
+    for (int chunk_start = batch_start; chunk_start < batch_end;
+         chunk_start += BlockDimY) {
+      // This sync makes sure that indices_batch is ready and match-group
+      // leaders are done with their accumulates before other warps start
+      // loading again.
+      __syncthreads();
+
+      int n_this_chunk = min(batch_end - chunk_start, BlockDimY);
+
+      int64_t src_row = static_cast<int64_t>(chunk_start + threadIdx.y);
+      int64_t dst_row = indices_batch[src_row - batch_start];
+      if (src_row < K && feature < stride) {
+        if (UseLimit && dst_row == kInvalidId) {
+          my_s[threadIdx.x] = static_cast<MT>(0);
+        } else {
+          my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
+        }
+      }
+
+      __syncthreads();
+
+      if (src_row < K) {
+        int match_found_this_thread = 0;
+        if (threadIdx.x < n_this_chunk &&
+            (!UseLimit || dst_row != kInvalidId)) {
+          match_found_this_thread =
+              (dst_row ==
+               indices_batch[chunk_start - batch_start + threadIdx.x]);
+        }
+#ifdef PADDLE_WITH_HIP
+        unsigned long long int matchmask =      // NOLINT
+            __ballot(match_found_this_thread);  // NOLINT
+        int first_remaining_peer = __ffsll(matchmask) - 1;
+#else
+        // If and only if match_found_this_thread of the Nth thread is non-zero,
+        // set the Nth bit of matchmask to 1.
+        unsigned int matchmask =
+            __ballot_sync(0xffffffff, match_found_this_thread);
+        // Find the position of the first bit set to 1 in matchmask.
+        int first_remaining_peer = __ffs(matchmask) - 1;
+#endif
+
+        // Select the lowest-indexed warp as the leader.
+        if (threadIdx.y == first_remaining_peer) {
+          // Clear the lowest set bit in matchmask.
+          matchmask ^= (1 << first_remaining_peer);
+          while (matchmask) {
+#ifdef PADDLE_WITH_HIP
+            first_remaining_peer = __ffsll(matchmask) - 1;
+#else
+            first_remaining_peer = __ffs(matchmask) - 1;
+#endif
+            my_s[threadIdx.x] +=
+                smem[threadIdx.x + WarpSize * first_remaining_peer];
+            matchmask ^= (1 << first_remaining_peer);
+          }
+          if (feature < stride && (!UseLimit || dst_row != kInvalidId)) {
+            auto table_idx = dst_row * D + feature;
+            table[table_idx] = static_cast<T>(
+                static_cast<MT>(table[table_idx]) + my_s[threadIdx.x]);
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T, typename IdT>
+void LaunchEmbeddingGradDeterministicKernel(const GPUContext& ctx,
+                                            const IdT* ids,
+                                            const T* d_out,
+                                            T* d_table,
+                                            int64_t N,
+                                            int64_t D,
+                                            int64_t K,
+                                            int64_t start_idx = -1) {
+#ifdef PADDLE_WITH_HIP
+  constexpr int kWarpSize = 64;
+  constexpr int kBlockDimY = 16;
+#else
+  constexpr int kWarpSize = 32;
+  constexpr int kBlockDimY = 32;
+#endif
+  dim3 threads(kWarpSize, kBlockDimY);
+  dim3 grids(static_cast<int>((D + kWarpSize - 1) / kWarpSize));
+  using MT = typename dtype::MPTypeTrait<T>::Type;
+  constexpr auto kSharedMemSize = sizeof(MT) * kWarpSize * kBlockDimY +
+                                  sizeof(IdT) * kWarpSize * kBlockDimY;
+  if (start_idx < 0) {
+    EmbeddingGradDeterministicKernel<T, IdT, kWarpSize, kBlockDimY, false>
+        <<<grids, threads, kSharedMemSize, ctx.stream()>>>(
+            d_table, d_out, ids, K, D, -1, -1);
+  } else {
+    int64_t end_idx = start_idx + N;
+    EmbeddingGradDeterministicKernel<T, IdT, kWarpSize, kBlockDimY, true>
+        <<<grids, threads, kSharedMemSize, ctx.stream()>>>(
+            d_table, d_out, ids, K, D, start_idx, end_idx);
+  }
+}
+
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
index 4771dd15dd2967c8c566397d40ded6dc24178c56..80533c77e208cb8c886d83d13fdb9526de8b53f5 100644
--- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/phi/kernels/embedding_grad_kernel.h"
+#include "paddle/phi/kernels/funcs/embedding_grad.h"
 
 #include "gflags/gflags.h"
 #include "glog/logging.h"
@@ -26,7 +27,7 @@
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-DECLARE_bool(embedding_deterministic);
+DECLARE_int64(embedding_deterministic);
 
 namespace phi {
 
@@ -198,20 +199,18 @@ struct EmbeddingGradCUDAFunctor {
         cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
 #endif
 
-    if (FLAGS_embedding_deterministic) {
-      dim3 threads(WARP_SIZE, BLOCKDIMY);
-      dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
-      using MT = typename dtype::MPTypeTrait<T>::Type;
-      EmbeddingGradDeterministic<T, IdT>
-          <<<grids,
-             threads,
-             sizeof(MT) * WARP_SIZE * BLOCKDIMY,
-             dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
+    if (FLAGS_embedding_deterministic == 1) {
+      phi::funcs::LaunchEmbeddingGradDeterministicKernel(
+          dev_ctx_, ids, d_output, d_table, N, D, K);
     } else {
       const int gridx = 2 * dev_ctx_.GetSMCount();
       dim3 threads(128, 8);
       dim3 grids(gridx, 1);
+      if (FLAGS_embedding_deterministic > 1) {
+        VLOG(2) << "Run grad kernel of embedding with single thread.";
+        grids.x = 1;
+        threads.y = 1;
+      }
       EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
           d_table, d_output, ids, N, K, D);
     }
diff --git a/python/paddle/fluid/tests/unittests/test_embedding_deterministic.py b/python/paddle/fluid/tests/unittests/test_embedding_deterministic.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64b4aa07ef9c14bb7145f5bed7cc3321a94fd5a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_embedding_deterministic.py
@@ -0,0 +1,213 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import contextlib +import random +import sys +import unittest + +import numpy as np + +import paddle +from paddle.distributed.fleet.layers.mpu.mp_ops import _c_lookup_table + + +@contextlib.contextmanager +def deterministic_guard(value): + flag_name = 'FLAGS_embedding_deterministic' + old_value = paddle.get_flags(flag_name)[flag_name] + paddle.set_flags({flag_name: value}) + assert paddle.get_flags(flag_name)[flag_name] == value + yield + paddle.set_flags({flag_name: old_value}) + assert paddle.get_flags(flag_name)[flag_name] == old_value + + +def to_numpy(tensor): + if tensor.dtype in [paddle.float16, paddle.bfloat16]: + tensor = tensor.astype(paddle.float32) + return tensor.numpy() + + +def clone_weight(weight): + if weight.dtype == paddle.bfloat16: + weight = weight.astype(paddle.float32).numpy() + weight = paddle.to_tensor(weight, dtype=paddle.float32).astype( + paddle.bfloat16 + ) + else: + weight = paddle.to_tensor(weight.numpy()) + weight.stop_gradient = False + return weight + + +def embedding(ids, weight, out_grad, deterministic_level=0, rank=None): + weight = clone_weight(weight) + with deterministic_guard(deterministic_level): + if rank is not None: + vocab_size, _ = weight.shape + start_idx = vocab_size * rank + out = _c_lookup_table(weight, ids, start_index=start_idx) + else: + out = paddle.nn.functional.embedding(ids, weight) + out.backward(out_grad.clone()) + return to_numpy(out), to_numpy(weight.grad) + + +def embedding_ground_truth(ids, weight, out_grad, rank=None): + weight = clone_weight(weight.astype(paddle.float32)) + out_grad = out_grad.astype(paddle.float32) + return embedding(ids, weight, out_grad, deterministic_level=2, rank=rank) + + +def generate_input_data( + ids_shape, + vocab_size, + hidden_size, + weight_dtype, + ids_dtype, + allow_duplicate_id=True, + rank=None, + nranks=None, + allow_pure_random=False, +): + max_id = vocab_size if rank is None else vocab_size * nranks + if allow_duplicate_id: + ids = np.random.randint(low=0, high=max_id, size=ids_shape) + else: + sequence = list(range(max_id)) + numel = int(np.prod(ids_shape)) + if len(sequence) < numel: + return None, None, None + ids = np.array(random.sample(sequence, numel)).reshape(ids_shape) + + ids = paddle.to_tensor(ids).astype(ids_dtype) + ids.stop_gradient = True + + weight = paddle.randn([vocab_size, hidden_size]).astype(weight_dtype) + weight.stop_gradient = False + + out_grad_shape = list(ids_shape) + [hidden_size] + if allow_duplicate_id and not allow_pure_random: + out_grad = paddle.randint(low=-10, high=10, shape=out_grad_shape) + else: + out_grad = paddle.randn(out_grad_shape) + out_grad = out_grad.astype(weight.dtype) + return ids, weight, out_grad + + +def get_all_dtypes(): + if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm(): + return [] + + dtypes = [paddle.float32, paddle.float16] + if 'A100' in paddle.device.cuda.get_device_properties().name: + dtypes.append(paddle.bfloat16) + return dtypes + + +class TestEmbeddingBase(unittest.TestCase): + def setUp(self): + self.ids_shape = [32, 3] + self.vocab_size = 128 + self.hidden_size = 1024 + self.nranks = 8 + + def check_main( + self, + weight_dtype, + ids_dtype, + deterministic_level=0, + rank=None, + allow_duplicate_id=True, + allow_pure_random=False, + ): + if sys.platform == 'win32' and rank is not None: + return + + ids, weight, out_grad = generate_input_data( + ids_shape=self.ids_shape, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + weight_dtype=weight_dtype, + ids_dtype=ids_dtype, + 
allow_duplicate_id=allow_duplicate_id, + rank=rank, + nranks=self.nranks, + allow_pure_random=allow_pure_random, + ) + if ids is None: + return + + if allow_pure_random: + out_1, weight_grad_1 = embedding_ground_truth( + ids, weight, out_grad, rank + ) + out_2, weight_grad_2 = embedding_ground_truth( + ids, weight, out_grad, rank + ) + else: + out_1, weight_grad_1 = embedding_ground_truth( + ids, weight, out_grad, rank + ) + out_2, weight_grad_2 = embedding( + ids, + weight, + out_grad, + deterministic_level=deterministic_level, + rank=rank, + ) + np.testing.assert_equal(out_1, out_2) + np.testing.assert_equal(weight_grad_1, weight_grad_2) + + def test_main(self): + weight_dtypes = get_all_dtypes() + ids_dtypes = [paddle.int64, paddle.int32] + deterministic_levels = [0, 1] + ranks = [None, 0, 2, 4, 8] + allow_duplicate_ids = [False, True] + allow_pure_randoms = [False, True] + for weight_dtype in weight_dtypes: + for ids_dtype in ids_dtypes: + for deterministic_level in deterministic_levels: + for rank in ranks: + for allow_duplicate_id in allow_duplicate_ids: + for allow_pure_random in allow_pure_randoms: + self.check_main( + weight_dtype, + ids_dtype, + deterministic_level, + rank, + allow_duplicate_id, + allow_pure_random, + ) + + +class TestEmbedding2(TestEmbeddingBase): + def setUp(self): + self.ids_shape = [32, 16] + self.vocab_size = 128 + self.hidden_size = 1024 + self.nranks = 8 + + +class TestEmbeddingDeterministic(unittest.TestCase): + def setUp(self): + self.ids_shape = [32, 16] + self.vocab_size = 128 + self.hidden_size = 1024 + + +if __name__ == "__main__": + unittest.main()
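
Usage note (illustrative, not part of the patch): FLAGS_embedding_deterministic is an ordinary Paddle runtime flag, so it can be toggled from Python with paddle.set_flags, which is exactly what the unit test's deterministic_guard helper does. A minimal sketch, assuming a CUDA build of Paddle with this change (Paddle >= 2.5); the shapes and dtypes below are arbitrary:

    import paddle

    # 0 (default): non-deterministic backward kernel.
    # 1: the new optimized deterministic kernel from embedding_grad.h.
    # 2 (or any value > 1): the legacy single-thread deterministic path.
    paddle.set_flags({'FLAGS_embedding_deterministic': 1})

    ids = paddle.randint(low=0, high=128, shape=[32, 3])
    weight = paddle.randn([128, 1024])
    weight.stop_gradient = False

    out = paddle.nn.functional.embedding(ids, weight)
    out.backward(paddle.randn(out.shape))
    print(weight.grad.shape)  # [128, 1024]

    # Restore the default behaviour.
    paddle.set_flags({'FLAGS_embedding_deterministic': 0})

With the flag set to 1 or 2, repeated runs of the same backward pass should produce bit-identical weight gradients; test_embedding_deterministic.py checks this by comparing each configuration against a float32 run at level 2 as ground truth.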