未验证 提交 8520a5b3 编写于 作者: S ShenLiang 提交者: GitHub

Add out-of-bounds id check for c_embedding (#55621)

上级 9daba606
...@@ -87,6 +87,10 @@ class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -87,6 +87,10 @@ class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker {
"(int64, default 0), The starting index is indeed, " "(int64, default 0), The starting index is indeed, "
"and the out-of-bounds will be set to 0 ") "and the out-of-bounds will be set to 0 ")
.SetDefault(0); .SetDefault(0);
AddAttr<int64_t>("vocab_size",
"(int64, default -1), The total vocabulary size to check"
"the out-of-bounds ids. If it is -1, no check will be ")
.SetDefault(-1);
AddComment(R"DOC( AddComment(R"DOC(
c_embedding Operator. c_embedding Operator.
......
...@@ -42,21 +42,25 @@ __global__ void CEmbedding(T *out, ...@@ -42,21 +42,25 @@ __global__ void CEmbedding(T *out,
const int64_t N, const int64_t N,
const int64_t start_idx, const int64_t start_idx,
const int64_t end_idx, const int64_t end_idx,
const int64_t limit) { const int64_t limit,
const int64_t vocab_size) {
CUDA_KERNEL_LOOP(i, limit) { CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns; size_t row = i / columns;
size_t col = i % columns; size_t col = i % columns;
auto id = ids[row]; auto id = ids[row];
PADDLE_ENFORCE(
id >= 0 && (vocab_size < 0 || id < vocab_size),
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]",
vocab_size,
id);
if (id >= start_idx && id < end_idx) { if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx; auto real_idx = id - start_idx;
PADDLE_ENFORCE(real_idx < N,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d], but received [%d]",
N,
real_idx);
out[i] = table[real_idx * columns + col]; out[i] = table[real_idx * columns + col];
} else { } else {
out[i] = static_cast<T>(0); out[i] = static_cast<T>(0);
...@@ -95,6 +99,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> { ...@@ -95,6 +99,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
const auto &dev_ctx = context.template device_context<phi::GPUContext>(); const auto &dev_ctx = context.template device_context<phi::GPUContext>();
const int64_t start_idx = context.Attr<int64_t>("start_index"); const int64_t start_idx = context.Attr<int64_t>("start_index");
const int64_t vocab_size = context.Attr<int64_t>("vocab_size");
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
size_t K = ids_t->numel(); size_t K = ids_t->numel();
...@@ -119,7 +125,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> { ...@@ -119,7 +125,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
N, N,
start_idx, start_idx,
end_idx, end_idx,
limit); limit,
vocab_size);
} else if (index_type == framework::proto::VarType::INT64) { } else if (index_type == framework::proto::VarType::INT64) {
CEmbedding<T, int64_t> CEmbedding<T, int64_t>
...@@ -131,7 +138,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> { ...@@ -131,7 +138,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
N, N,
start_idx, start_idx,
end_idx, end_idx,
limit); limit,
vocab_size);
} else { } else {
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"GPU c_embedding ids only support int32 or int64.")); "GPU c_embedding ids only support int32 or int64."));
......
...@@ -124,6 +124,7 @@ class VocabParallelEmbedding(paddle.nn.Layer): ...@@ -124,6 +124,7 @@ class VocabParallelEmbedding(paddle.nn.Layer):
self._size = [per_part_size, embedding_dim] self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._name = name self._name = name
self.num_embeddings = num_embeddings
if self.is_mp and paddle.in_dynamic_mode(): if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state(): with get_rng_state_tracker().rng_state():
...@@ -151,6 +152,7 @@ class VocabParallelEmbedding(paddle.nn.Layer): ...@@ -151,6 +152,7 @@ class VocabParallelEmbedding(paddle.nn.Layer):
self.weight, self.weight,
x, x,
start_index=self.vocab_start_index, start_index=self.vocab_start_index,
vocab_size=self.num_embeddings,
name=self._name, name=self._name,
) )
output = mp_ops._mp_allreduce( output = mp_ops._mp_allreduce(
......
...@@ -295,7 +295,7 @@ def _mp_allreduce( ...@@ -295,7 +295,7 @@ def _mp_allreduce(
return out return out
def _c_lookup_table(table, index, start_index=0, name=None): def _c_lookup_table(table, index, start_index=0, vocab_size=-1, name=None):
""" """
Lookup table according to index. Lookup table according to index.
...@@ -311,7 +311,7 @@ def _c_lookup_table(table, index, start_index=0, name=None): ...@@ -311,7 +311,7 @@ def _c_lookup_table(table, index, start_index=0, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _legacy_C_ops.c_embedding( return _legacy_C_ops.c_embedding(
table, index, "start_index", start_index table, index, "start_index", start_index, "vocab_size", vocab_size
) )
else: else:
op_type = 'c_embedding' op_type = 'c_embedding'
...@@ -323,7 +323,7 @@ def _c_lookup_table(table, index, start_index=0, name=None): ...@@ -323,7 +323,7 @@ def _c_lookup_table(table, index, start_index=0, name=None):
type='c_embedding', type='c_embedding',
inputs={'Ids': index, 'W': table}, inputs={'Ids': index, 'W': table},
outputs={'Out': tmp}, outputs={'Out': tmp},
attrs={"start_index": start_index}, attrs={"start_index": start_index, "vocab_size": vocab_size},
) )
return tmp return tmp
...@@ -655,7 +655,11 @@ def _parallel_embedding( ...@@ -655,7 +655,11 @@ def _parallel_embedding(
main_block.vars[weight.name].is_distributed = True main_block.vars[weight.name].is_distributed = True
output_parallel = _c_lookup_table( output_parallel = _c_lookup_table(
weight, x, start_index=vocab_start_index, name=name weight,
x,
start_index=vocab_start_index,
vocab_size=origin_size[0],
name=name,
) )
out = _mp_allreduce( out = _mp_allreduce(
output_parallel, output_parallel,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册