Unverified commit 35f5c245, authored by sneaxiy, committed by GitHub

Optimize c_embedding op in deterministic mode (#53197)

* optimize embedding deterministic mode

* fix compile error

* change FLAGS_cudnn_deterministic to int64

* fix 700 error

* add ut

* fix ut

* fix ut

* fix win32 ci

* fix flags with PHI_DEFINE_EXPORTED_int64
Parent 66fbfba8
@@ -18,8 +18,9 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"
DECLARE_bool(cudnn_deterministic);
DECLARE_int64(embedding_deterministic);
namespace paddle {
namespace operators {
@@ -154,7 +155,6 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t end_idx = start_idx + N;
auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
@@ -166,10 +166,36 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
if (FLAGS_cudnn_deterministic) {
if (FLAGS_embedding_deterministic == 1) {
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int32_t>(
dev_ctx,
ids_t->data<int32_t>(),
d_output,
d_table,
N,
D,
K,
start_idx);
return;
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int64_t>(
dev_ctx,
ids_t->data<int64_t>(),
d_output,
d_table,
N,
D,
K,
start_idx);
return;
}
} else {
if (FLAGS_embedding_deterministic > 1) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
blocks = 1;
}
const int64_t end_idx = start_idx + N;
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
@@ -181,6 +207,7 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
start_idx,
end_idx,
limit);
return;
} else if (index_type == framework::proto::VarType::INT64) {
CEmbeddingGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
@@ -192,7 +219,11 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
start_idx,
end_idx,
limit);
return;
}
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The data type of Input(Ids) must be int32 or int64."));
}
};
......
@@ -236,17 +236,19 @@ PHI_DEFINE_EXPORTED_bool(
* CUDA related FLAG
* Name: FLAGS_embedding_deterministic
* Since Version: 2.5
* Value Range: bool, default=false
* Value Range: int64, default=0
* Example:
* Note: whether to use deterministic algorithm in embedding op.
* If true, it will use deterministic CUDA kernel in embedding op.
* If it is 1, it will use the optimized deterministic CUDA kernel in
* embedding op. If it is 2, it will use the legacy deterministic
* CUDA kernel in embedding op.
*/
PHI_DEFINE_EXPORTED_bool(
PHI_DEFINE_EXPORTED_int64(
embedding_deterministic,
false,
0,
"Whether allow using an deterministic algorithm for embedding "
"operator. The deterministic algorithm may be slower. If "
"true, the algorithm is deterministic.");
"it is larger than 0, the algorithm is deterministic.");
/**
* CUDNN related FLAG
......
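The flag documented above is typically toggled from Python with paddle.set_flags, exactly as the new unit test at the bottom of this commit does. Below is a minimal, hedged usage sketch (not part of the commit) that assumes a CUDA build of Paddle so the flag actually governs the GPU backward kernel; it relies only on APIs that also appear in the test (paddle.set_flags, paddle.nn.functional.embedding, Tensor.backward).

# Hedged usage sketch, not part of the commit: with
# FLAGS_embedding_deterministic = 1, repeated backward passes over the same
# inputs should produce bit-identical weight gradients.
import numpy as np
import paddle

paddle.set_flags({'FLAGS_embedding_deterministic': 1})

ids = paddle.randint(low=0, high=128, shape=[32, 3])
out_grad = paddle.randn([32, 3, 1024])
weight_np = np.random.randn(128, 1024).astype('float32')

grads = []
for _ in range(2):
    weight = paddle.to_tensor(weight_np)
    weight.stop_gradient = False
    out = paddle.nn.functional.embedding(ids, weight)
    out.backward(out_grad.clone())
    grads.append(weight.grad.numpy())

np.testing.assert_equal(grads[0], grads[1])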
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
namespace phi {
namespace funcs {
template <typename T, typename IdT, int WarpSize, int BlockDimY, bool UseLimit>
__global__ void EmbeddingGradDeterministicKernel(T* table,
const T* output,
const IdT* ids,
const int64_t K,
const int64_t D,
const int64_t start_idx,
const int64_t end_idx) {
using MT = typename dtype::MPTypeTrait<T>::Type;
constexpr int64_t kInvalidId = -1;
extern __shared__ char buf[];
MT* smem = reinterpret_cast<MT*>(buf);
MT* my_s = smem + WarpSize * threadIdx.y;
IdT* indices_batch =
reinterpret_cast<IdT*>(buf + sizeof(MT) * WarpSize * BlockDimY);
const int stride = static_cast<int>(D);
const int feature = threadIdx.x + blockIdx.x * WarpSize;
// To ensure determinism: if several warps pull gradient data targeting the
// same dst_row, we elect the first warp in each matching group as the leader.
// Each leader warp serializes the accumulations targeting dst_row in shared
// memory, then adds the accumulated buffer to dst_row in table.
for (int batch_start = 0; batch_start < K;
batch_start += WarpSize * BlockDimY) {
int tid = threadIdx.x + threadIdx.y * WarpSize;
if (batch_start + tid < K) {
int64_t cur_id = static_cast<int64_t>(ids[batch_start + tid]);
if (UseLimit) {
if (cur_id >= start_idx && cur_id < end_idx) {
cur_id -= start_idx;
} else {
cur_id = kInvalidId;
}
}
indices_batch[tid] = cur_id;
}
int batch_end =
min(static_cast<int64_t>(batch_start + WarpSize * BlockDimY), K);
// Loop over the batch of up to WarpSize * BlockDimY loaded indices in
// chunks of BlockDimY.
for (int chunk_start = batch_start; chunk_start < batch_end;
chunk_start += BlockDimY) {
// This sync makes sure that indices_batch is ready and match-group
// leaders are done with their accumulates before other warps start
// loading again.
__syncthreads();
int n_this_chunk = min(batch_end - chunk_start, BlockDimY);
int64_t src_row = static_cast<int64_t>(chunk_start + threadIdx.y);
int64_t dst_row = indices_batch[src_row - batch_start];
if (src_row < K && feature < stride) {
if (UseLimit && dst_row == kInvalidId) {
my_s[threadIdx.x] = static_cast<MT>(0);
} else {
my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
}
}
__syncthreads();
if (src_row < K) {
int match_found_this_thread = 0;
if (threadIdx.x < n_this_chunk &&
(!UseLimit || dst_row != kInvalidId)) {
match_found_this_thread =
(dst_row ==
indices_batch[chunk_start - batch_start + threadIdx.x]);
}
#ifdef PADDLE_WITH_HIP
unsigned long long int matchmask = // NOLINT
__ballot(match_found_this_thread); // NOLINT
int first_remaining_peer = __ffsll(matchmask) - 1;
#else
// If and only if match_found_this_thread of the Nth thread is non-zero,
// set the Nth bit of matchmask to 1.
unsigned int matchmask =
__ballot_sync(0xffffffff, match_found_this_thread);
// Find the position of the first bit set to 1 in matchmask.
int first_remaining_peer = __ffs(matchmask) - 1;
#endif
// select lowest-indexed warp as the leader
if (threadIdx.y == first_remaining_peer) {
// Set the first bit 1 in matchmask to 0.
matchmask ^= (1 << first_remaining_peer);
while (matchmask) {
#ifdef PADDLE_WITH_HIP
first_remaining_peer = __ffsll(matchmask) - 1;
#else
first_remaining_peer = __ffs(matchmask) - 1;
#endif
my_s[threadIdx.x] +=
smem[threadIdx.x + WarpSize * first_remaining_peer];
matchmask ^= (1 << first_remaining_peer);
}
if (feature < stride && (!UseLimit || dst_row != kInvalidId)) {
auto table_idx = dst_row * D + feature;
table[table_idx] = static_cast<T>(
static_cast<MT>(table[table_idx]) + my_s[threadIdx.x]);
}
}
}
}
}
}
template <typename T, typename IdT>
void LaunchEmbeddingGradDeterministicKernel(const GPUContext& ctx,
const IdT* ids,
const T* d_out,
T* d_table,
int64_t N,
int64_t D,
int64_t K,
int64_t start_idx = -1) {
#ifdef PADDLE_WITH_HIP
constexpr int kWarpSize = 64;
constexpr int kBlockDimY = 16;
#else
constexpr int kWarpSize = 32;
constexpr int kBlockDimY = 32;
#endif
dim3 threads(kWarpSize, kBlockDimY);
dim3 grids(static_cast<int>((D + kWarpSize - 1) / kWarpSize));
using MT = typename dtype::MPTypeTrait<T>::Type;
constexpr auto kSharedMemSize = sizeof(MT) * kWarpSize * kBlockDimY +
sizeof(IdT) * kWarpSize * kBlockDimY;
if (start_idx < 0) {
EmbeddingGradDeterministicKernel<T, IdT, kWarpSize, kBlockDimY, false>
<<<grids, threads, kSharedMemSize, ctx.stream()>>>(
d_table, d_out, ids, K, D, -1, -1);
} else {
int64_t end_idx = start_idx + N;
EmbeddingGradDeterministicKernel<T, IdT, kWarpSize, kBlockDimY, true>
<<<grids, threads, kSharedMemSize, ctx.stream()>>>(
d_table, d_out, ids, K, D, start_idx, end_idx);
}
}
} // namespace funcs
} // namespace phi
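As a reference for what LaunchEmbeddingGradDeterministicKernel computes, here is a minimal NumPy sketch (not part of the commit): a scatter-add of d_out rows into d_table performed in a fixed index order, skipping ids outside the local vocabulary shard [start_idx, start_idx + N) when start_idx >= 0. The fixed accumulation order is what the leader-warp serialization above guarantees on the GPU.

# Hedged NumPy reference, not part of the commit: embedding_grad_reference is
# an illustrative helper, not a Paddle API.
import numpy as np

def embedding_grad_reference(ids, d_out, num_rows, start_idx=-1):
    # ids: int array of shape [K]; d_out: float array of shape [K, D].
    K, D = d_out.shape
    d_table = np.zeros((num_rows, D), dtype=d_out.dtype)
    end_idx = start_idx + num_rows  # mirrors end_idx = start_idx + N above
    for k in range(K):  # fixed order => deterministic floating-point result
        row = int(ids[k])
        if start_idx >= 0:  # vocabulary-sharded case (c_embedding)
            if not (start_idx <= row < end_idx):
                continue  # id belongs to another rank's shard
            row -= start_idx
        d_table[row] += d_out[k]
    return d_table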
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
@@ -26,20 +27,10 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_bool(embedding_deterministic);
DECLARE_int64(embedding_deterministic);
namespace phi {
#ifdef PADDLE_WITH_HIP
#define WARP_SIZE 64
#define BLOCKDIMY 16
#else
#define WARP_SIZE 32
#define BLOCKDIMY 32
#endif
#define MASK 0xffffffff
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT* in_ids,
const int64_t K,
@@ -74,91 +65,6 @@ __global__ void EmbeddingGrad(T* table,
}
}
template <typename T, typename IdT>
__global__ void EmbeddingGradDeterministic(
T* table, const T* output, const IdT* ids, const IdT K, const IdT D) {
using MT = typename dtype::MPTypeTrait<T>::Type;
extern __shared__ char buf[];
MT* smem = reinterpret_cast<MT*>(buf);
MT* my_s = smem + WARP_SIZE * threadIdx.y;
IdT* indices_batch =
reinterpret_cast<IdT*>(buf + sizeof(MT) * WARP_SIZE * BLOCKDIMY);
const int stride = static_cast<int>(D);
const int feature = threadIdx.x + blockIdx.x * WARP_SIZE;
// To ensure determinism. If any other warps pulled grad data targeting
// dst_row, we elect the first warp in each matching group as the leader.
// Each leader warp serializes the accumulates targeting dst_row in shared
// memory, then adding the accumulated buffer to dst_row in table.
for (int batch_start = 0; batch_start < K;
batch_start += WARP_SIZE * BLOCKDIMY) {
int tid = threadIdx.x + threadIdx.y * WARP_SIZE;
if (batch_start + tid < K)
indices_batch[tid] = static_cast<IdT>(ids[batch_start + tid]);
int batch_end =
min(static_cast<IdT>(batch_start + WARP_SIZE * BLOCKDIMY), K);
// Loop over the batch of <= 1024 loaded indices in chunks of BLOCKDIMY
for (int chunk_start = batch_start; chunk_start < batch_end;
chunk_start += BLOCKDIMY) {
// This sync makes sure that indices_batch is ready and match-group
// leaders are done with their accumulates before other warps start
// loading again.
__syncthreads();
int n_this_chunk = min(batch_end - chunk_start, BLOCKDIMY);
IdT src_row = static_cast<IdT>(chunk_start + threadIdx.y);
IdT dst_row = indices_batch[src_row - batch_start];
if (src_row < K && feature < stride)
my_s[threadIdx.x] = static_cast<MT>(output[src_row * D + feature]);
__syncthreads();
if (src_row < K) {
int match_found_this_thread = 0;
if (threadIdx.x < n_this_chunk) {
match_found_this_thread =
(dst_row ==
indices_batch[chunk_start - batch_start + threadIdx.x]);
}
#ifdef PADDLE_WITH_HIP
unsigned long long int matchmask = // NOLINT
__ballot(match_found_this_thread); // NOLINT
int first_remaining_peer = __ffsll(matchmask) - 1;
#else
// If and only if match_found_this_thread of the Nth thread is non-zero,
// set the Nth bit of matchmask to 1.
unsigned int matchmask = __ballot_sync(MASK, match_found_this_thread);
// Find the position of the first bit set to 1 in matchmask.
int first_remaining_peer = __ffs(matchmask) - 1;
#endif
// select lowest-indexed warp as the leader
if (threadIdx.y == first_remaining_peer) {
// Set the first bit 1 in matchmask to 0.
matchmask ^= (1 << first_remaining_peer);
while (matchmask) {
#ifdef PADDLE_WITH_HIP
first_remaining_peer = __ffsll(matchmask) - 1;
#else
first_remaining_peer = __ffs(matchmask) - 1;
#endif
my_s[threadIdx.x] +=
smem[threadIdx.x + WARP_SIZE * first_remaining_peer];
matchmask ^= (1 << first_remaining_peer);
}
if (feature < stride)
table[dst_row * D + feature] += static_cast<T>(my_s[threadIdx.x]);
}
}
}
}
}
template <typename T, typename Context>
struct EmbeddingGradCUDAFunctor {
EmbeddingGradCUDAFunctor(const Context& dev_ctx,
@@ -198,20 +104,18 @@ struct EmbeddingGradCUDAFunctor {
cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
#endif
if (FLAGS_embedding_deterministic) {
dim3 threads(WARP_SIZE, BLOCKDIMY);
dim3 grids(static_cast<int>((D + WARP_SIZE - 1) / WARP_SIZE));
using MT = typename dtype::MPTypeTrait<T>::Type;
EmbeddingGradDeterministic<T, IdT>
<<<grids,
threads,
sizeof(MT) * WARP_SIZE * BLOCKDIMY +
sizeof(IdT) * WARP_SIZE * BLOCKDIMY,
dev_ctx_.stream()>>>(d_table, d_output, ids, K, D);
if (FLAGS_embedding_deterministic == 1) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, IdT>(
dev_ctx_, ids, d_output, d_table, N, D, K);
} else {
const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 threads(128, 8);
dim3 grids(gridx, 1);
if (FLAGS_embedding_deterministic > 1) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
grids.x = 1;
threads.y = 1;
}
EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D);
}
......
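To make the dispatch in EmbeddingGradCUDAFunctor explicit, here is a hedged Python paraphrase (not part of the commit) of how the flag value selects the backward kernel; select_embedding_grad_kernel is an illustrative helper, not a Paddle API.

# Hedged paraphrase of the branch above, not part of the commit.
def select_embedding_grad_kernel(flag_value: int) -> str:
    if flag_value == 1:
        # optimized deterministic kernel introduced by this commit
        return "LaunchEmbeddingGradDeterministicKernel"
    if flag_value > 1:
        # legacy deterministic mode: EmbeddingGrad launched with a single
        # thread (grids.x = 1, threads.y = 1)
        return "EmbeddingGrad (single thread)"
    # default: EmbeddingGrad with 2 * GetSMCount() blocks; fast, but not
    # guaranteed to be deterministic
    return "EmbeddingGrad (default)"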
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import random
import sys
import unittest
import numpy as np
import paddle
from paddle.distributed.fleet.layers.mpu.mp_ops import _c_lookup_table
@contextlib.contextmanager
def deterministic_guard(value):
flag_name = 'FLAGS_embedding_deterministic'
old_value = paddle.get_flags(flag_name)[flag_name]
paddle.set_flags({flag_name: value})
assert paddle.get_flags(flag_name)[flag_name] == value
yield
paddle.set_flags({flag_name: old_value})
assert paddle.get_flags(flag_name)[flag_name] == old_value
def to_numpy(tensor):
if tensor.dtype in [paddle.float16, paddle.bfloat16]:
tensor = tensor.astype(paddle.float32)
return tensor.numpy()
def clone_weight(weight):
if weight.dtype == paddle.bfloat16:
weight = weight.astype(paddle.float32).numpy()
weight = paddle.to_tensor(weight, dtype=paddle.float32).astype(
paddle.bfloat16
)
else:
weight = paddle.to_tensor(weight.numpy())
weight.stop_gradient = False
return weight
def embedding(ids, weight, out_grad, deterministic_level=0, rank=None):
weight = clone_weight(weight)
with deterministic_guard(deterministic_level):
if rank is not None:
vocab_size, _ = weight.shape
start_idx = vocab_size * rank
out = _c_lookup_table(weight, ids, start_index=start_idx)
else:
out = paddle.nn.functional.embedding(ids, weight)
out.backward(out_grad.clone())
return to_numpy(out), to_numpy(weight.grad)
def embedding_ground_truth(ids, weight, out_grad, rank=None):
weight = clone_weight(weight.astype(paddle.float32))
out_grad = out_grad.astype(paddle.float32)
return embedding(ids, weight, out_grad, deterministic_level=2, rank=rank)
def generate_input_data(
ids_shape,
vocab_size,
hidden_size,
weight_dtype,
ids_dtype,
allow_duplicate_id=True,
rank=None,
nranks=None,
allow_pure_random=False,
):
max_id = vocab_size if rank is None else vocab_size * nranks
if allow_duplicate_id:
ids = np.random.randint(low=0, high=max_id, size=ids_shape)
else:
sequence = list(range(max_id))
numel = int(np.prod(ids_shape))
if len(sequence) < numel:
return None, None, None
ids = np.array(random.sample(sequence, numel)).reshape(ids_shape)
ids = paddle.to_tensor(ids).astype(ids_dtype)
ids.stop_gradient = True
weight = paddle.randn([vocab_size, hidden_size]).astype(weight_dtype)
weight.stop_gradient = False
out_grad_shape = list(ids_shape) + [hidden_size]
if allow_duplicate_id and not allow_pure_random:
out_grad = paddle.randint(low=-10, high=10, shape=out_grad_shape)
else:
out_grad = paddle.randn(out_grad_shape)
out_grad = out_grad.astype(weight.dtype)
return ids, weight, out_grad
def get_all_dtypes():
if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
return []
dtypes = [paddle.float32, paddle.float16]
if 'A100' in paddle.device.cuda.get_device_properties().name:
dtypes.append(paddle.bfloat16)
return dtypes
class TestEmbeddingBase(unittest.TestCase):
def setUp(self):
self.ids_shape = [32, 3]
self.vocab_size = 128
self.hidden_size = 1024
self.nranks = 8
def check_main(
self,
weight_dtype,
ids_dtype,
deterministic_level=0,
rank=None,
allow_duplicate_id=True,
allow_pure_random=False,
):
if sys.platform == 'win32' and rank is not None:
return
ids, weight, out_grad = generate_input_data(
ids_shape=self.ids_shape,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
weight_dtype=weight_dtype,
ids_dtype=ids_dtype,
allow_duplicate_id=allow_duplicate_id,
rank=rank,
nranks=self.nranks,
allow_pure_random=allow_pure_random,
)
if ids is None:
return
if allow_pure_random:
out_1, weight_grad_1 = embedding_ground_truth(
ids, weight, out_grad, rank
)
out_2, weight_grad_2 = embedding_ground_truth(
ids, weight, out_grad, rank
)
else:
out_1, weight_grad_1 = embedding_ground_truth(
ids, weight, out_grad, rank
)
out_2, weight_grad_2 = embedding(
ids,
weight,
out_grad,
deterministic_level=deterministic_level,
rank=rank,
)
np.testing.assert_equal(out_1, out_2)
np.testing.assert_equal(weight_grad_1, weight_grad_2)
def test_main(self):
weight_dtypes = get_all_dtypes()
ids_dtypes = [paddle.int64, paddle.int32]
deterministic_levels = [0, 1]
ranks = [None, 0, 2, 4, 8]
allow_duplicate_ids = [False, True]
allow_pure_randoms = [False, True]
for weight_dtype in weight_dtypes:
for ids_dtype in ids_dtypes:
for deterministic_level in deterministic_levels:
for rank in ranks:
for allow_duplicate_id in allow_duplicate_ids:
for allow_pure_random in allow_pure_randoms:
self.check_main(
weight_dtype,
ids_dtype,
deterministic_level,
rank,
allow_duplicate_id,
allow_pure_random,
)
class TestEmbedding2(TestEmbeddingBase):
def setUp(self):
self.ids_shape = [32, 16]
self.vocab_size = 128
self.hidden_size = 1024
self.nranks = 8
class TestEmbeddingDeterministic(unittest.TestCase):
def setUp(self):
self.ids_shape = [32, 16]
self.vocab_size = 128
self.hidden_size = 1024
if __name__ == "__main__":
unittest.main()