Unverified — Commit 8520a5b3, authored by: ShenLiang, committed by: GitHub

add check for cembedding (#55621)

Parent commit: 9daba606
......@@ -87,6 +87,10 @@ class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker {
"(int64, default 0), The starting index is indeed, "
"and the out-of-bounds will be set to 0 ")
.SetDefault(0);
AddAttr<int64_t>("vocab_size",
"(int64, default -1), The total vocabulary size to check "
"the out-of-bounds ids. If it is -1, no check will be "
"performed.")
.SetDefault(-1);
AddComment(R"DOC(
c_embedding Operator.
......
......@@ -42,21 +42,25 @@ __global__ void CEmbedding(T *out,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
const int64_t limit,
const int64_t vocab_size) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
PADDLE_ENFORCE(real_idx < N,
PADDLE_ENFORCE(
id >= 0 && (vocab_size < 0 || id < vocab_size),
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d], but received [%d]",
N,
real_idx);
"be less than [%d] and greater than or equal to 0, but received [%d]",
vocab_size,
id);
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
out[i] = table[real_idx * columns + col];
} else {
out[i] = static_cast<T>(0);
......@@ -95,6 +99,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
const auto &dev_ctx = context.template device_context<phi::GPUContext>();
const int64_t start_idx = context.Attr<int64_t>("start_index");
const int64_t vocab_size = context.Attr<int64_t>("vocab_size");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
......@@ -119,7 +125,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
N,
start_idx,
end_idx,
limit);
limit,
vocab_size);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbedding<T, int64_t>
......@@ -131,7 +138,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
N,
start_idx,
end_idx,
limit);
limit,
vocab_size);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"GPU c_embedding ids only support int32 or int64."));
......
......@@ -124,6 +124,7 @@ class VocabParallelEmbedding(paddle.nn.Layer):
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name
self.num_embeddings = num_embeddings
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
......@@ -151,6 +152,7 @@ class VocabParallelEmbedding(paddle.nn.Layer):
self.weight,
x,
start_index=self.vocab_start_index,
vocab_size=self.num_embeddings,
name=self._name,
)
output = mp_ops._mp_allreduce(
......
......@@ -295,7 +295,7 @@ def _mp_allreduce(
return out
def _c_lookup_table(table, index, start_index=0, name=None):
def _c_lookup_table(table, index, start_index=0, vocab_size=-1, name=None):
"""
Lookup table according to index.
......@@ -311,7 +311,7 @@ def _c_lookup_table(table, index, start_index=0, name=None):
"""
if in_dygraph_mode():
return _legacy_C_ops.c_embedding(
table, index, "start_index", start_index
table, index, "start_index", start_index, "vocab_size", vocab_size
)
else:
op_type = 'c_embedding'
......@@ -323,7 +323,7 @@ def _c_lookup_table(table, index, start_index=0, name=None):
type='c_embedding',
inputs={'Ids': index, 'W': table},
outputs={'Out': tmp},
attrs={"start_index": start_index},
attrs={"start_index": start_index, "vocab_size": vocab_size},
)
return tmp
......@@ -655,7 +655,11 @@ def _parallel_embedding(
main_block.vars[weight.name].is_distributed = True
output_parallel = _c_lookup_table(
weight, x, start_index=vocab_start_index, name=name
weight,
x,
start_index=vocab_start_index,
vocab_size=origin_size[0],
name=name,
)
out = _mp_allreduce(
output_parallel,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.