Unverified commit 4341ebd9 authored by Ghost Screaming, committed by GitHub

Fix ignore index of c_softmax_with_cross_entropy_op. (#52835)

* Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, its result was wrong.

* Remove climits.

* Fix a bug in c_softmax_with_cross_entropy_op: support a negative ignore_index.
Parent 0f2dc4ca
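The core of the fix: MaskLabelByIndex previously asserted that every label lies in [0, D * nranks), so a negative ignore_index sentinel (e.g. -100) tripped the device-side check. The kernel now also accepts labels equal to ignore_index. Below is a minimal standalone CUDA sketch of the new validity predicate, with illustrative names (not Paddle code):

// Demo of the fixed predicate: a label is valid if it is in [0, num_classes)
// or equals ignore_index, which may be negative. Illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void mask_label_demo(const long long* label, long long ignore_index,
                                long long num_classes, int n, int* ok) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    long long l = label[i];
    ok[i] = ((l >= 0 && l < num_classes) || l == ignore_index) ? 1 : 0;
  }
}

int main() {
  const int n = 3;
  long long h_label[n] = {2, -100, 7};  // 7 is out of range for 5 classes
  long long* d_label;
  int *d_ok, h_ok[n];
  cudaMalloc(&d_label, n * sizeof(long long));
  cudaMalloc(&d_ok, n * sizeof(int));
  cudaMemcpy(d_label, h_label, sizeof(h_label), cudaMemcpyHostToDevice);
  mask_label_demo<<<1, 32>>>(d_label, /*ignore_index=*/-100,
                             /*num_classes=*/5, n, d_ok);
  cudaMemcpy(h_ok, d_ok, sizeof(h_ok), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i)
    printf("label %lld -> %s\n", h_label[i], h_ok[i] ? "valid" : "invalid");
  cudaFree(d_label);
  cudaFree(d_ok);
  return 0;
}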
@@ -40,6 +40,7 @@ template <typename T, typename IndexT>
 __global__ void MaskLabelByIndex(T* predicted_logits,
                                  const T* logit,
                                  const IndexT* label,
+                                 const IndexT ignore_index,
                                  const int start_index,
                                  const int end_index,
                                  const int64_t N,
@@ -47,13 +48,15 @@ __global__ void MaskLabelByIndex(T* predicted_logits,
                                  const int nranks) {
   CUDA_KERNEL_LOOP(i, N) {
     auto real_label = label[i];
-    PADDLE_ENFORCE((real_label < D * nranks) && (real_label >= 0),
+    PADDLE_ENFORCE(((real_label < D * nranks) && (real_label >= 0)) ||
+                       (real_label == ignore_index),
                    "The index is out of bounds, "
                    "please check whether the value of label and "
                    "input meet the class number. It should "
-                   "be less than [%d], but received [%d]",
-                   D * nranks,
-                   real_label);
+                   "be less than [%ld] or equal to [%ld], but received [%ld]",
+                   static_cast<int64_t>(D * nranks),
+                   static_cast<int64_t>(ignore_index),
+                   static_cast<int64_t>(real_label));
     if (real_label >= start_index && real_label < end_index) {
       predicted_logits[i] = logit[i * D + real_label - start_index];
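Beyond the relaxed predicate, the hunk above also fixes the enforce message: the old format string printed potentially 64-bit values through %d, which is undefined behavior in printf-style formatting when IndexT is int64_t; every variadic argument is now widened explicitly to int64_t and printed with %ld. A minimal host-side sketch of why the widening matters (assuming an LP64 platform, where int64_t is long and so matches %ld):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t D = 50000, nranks = 8;
  const int64_t real_label = -100;
  // Wrong: printf("received [%d]", real_label);  // UB: %d expects int
  printf("should be less than [%ld] or equal to [%ld], but received [%ld]\n",
         static_cast<int64_t>(D * nranks),
         static_cast<int64_t>(-100),  // the ignore_index sentinel
         static_cast<int64_t>(real_label));
  return 0;
}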
@@ -204,20 +207,22 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
   const auto& label_type = framework::TransToProtoVarType(labels->dtype());
   if (label_type == framework::proto::VarType::INT32) {
-    MaskLabelByIndex<T, int32_t>
-        <<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(),
-                                                   softmax_2d.data<T>(),
-                                                   labels->data<int32_t>(),
-                                                   start_index,
-                                                   end_index,
-                                                   N,
-                                                   D,
-                                                   nranks);
+    MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
+        predicted_logits.data<T>(),
+        softmax_2d.data<T>(),
+        labels->data<int32_t>(),
+        static_cast<int32_t>(ignore_index),
+        start_index,
+        end_index,
+        N,
+        D,
+        nranks);
   } else if (label_type == framework::proto::VarType::INT64) {
     MaskLabelByIndex<T, int64_t>
         <<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(),
                                                    softmax_2d.data<T>(),
                                                    labels->data<int64_t>(),
+                                                   ignore_index,
                                                    start_index,
                                                    end_index,
                                                    N,
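Host-side, the launch is dispatched on the label tensor's dtype, and ignore_index is converted to the kernel's IndexT up front so the device comparison real_label == ignore_index compares like with like. (Note that in the ProcessGroup hunk below, even the int64 branch narrows through static_cast<int32_t>; that is harmless for the usual small sentinels such as -100, but would truncate an ignore_index outside the int32 range.) A minimal sketch of the dispatch pattern, with assumed names rather than Paddle's API:

// Illustrative dtype dispatch: convert ignore_index to the kernel's index
// type once on the host, then launch the typed kernel.
#include <cstdint>
#include <cstdio>

template <typename IndexT>
void LaunchMaskLabelByIndex(const IndexT* labels, IndexT ignore_index) {
  // A real implementation would launch the CUDA kernel here.
  printf("index width: %zu bytes, ignore_index: %lld\n", sizeof(IndexT),
         static_cast<long long>(ignore_index));
}

void Dispatch(bool labels_are_int64, const void* labels, int64_t ignore_index) {
  if (labels_are_int64) {
    LaunchMaskLabelByIndex(static_cast<const int64_t*>(labels), ignore_index);
  } else {
    LaunchMaskLabelByIndex(static_cast<const int32_t*>(labels),
                           static_cast<int32_t>(ignore_index));
  }
}

int main() {
  int32_t l32[1] = {0};
  int64_t l64[1] = {0};
  Dispatch(false, l32, -100);
  Dispatch(true, l64, -100);
  return 0;
}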
@@ -362,25 +367,27 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
   const auto& label_type = framework::TransToProtoVarType(labels->dtype());
   if (label_type == framework::proto::VarType::INT32) {
-    MaskLabelByIndex<T, int32_t>
-        <<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(),
-                                                   softmax_2d.data<T>(),
-                                                   labels->data<int32_t>(),
-                                                   start_index,
-                                                   end_index,
-                                                   N,
-                                                   D,
-                                                   nranks);
+    MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
+        predicted_logits.data<T>(),
+        softmax_2d.data<T>(),
+        labels->data<int32_t>(),
+        static_cast<int32_t>(ignore_index),
+        start_index,
+        end_index,
+        N,
+        D,
+        nranks);
   } else if (label_type == framework::proto::VarType::INT64) {
-    MaskLabelByIndex<T, int64_t>
-        <<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(),
-                                                   softmax_2d.data<T>(),
-                                                   labels->data<int64_t>(),
-                                                   start_index,
-                                                   end_index,
-                                                   N,
-                                                   D,
-                                                   nranks);
+    MaskLabelByIndex<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
+        predicted_logits.data<T>(),
+        softmax_2d.data<T>(),
+        labels->data<int64_t>(),
+        static_cast<int32_t>(ignore_index),
+        start_index,
+        end_index,
+        N,
+        D,
+        nranks);
   }
   in_out.clear();
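Net effect: a negative ignore_index lies below every rank's start_index, so no rank writes a predicted logit for an ignored row, and such rows can be masked out of the loss downstream. A hedged host-side sketch of that masking idea, with made-up numbers (not Paddle's implementation):

// Ignored rows contribute nothing to the mean cross-entropy loss.
#include <cmath>
#include <cstdio>

int main() {
  const long long ignore_index = -100;
  const long long labels[3] = {1, -100, 0};
  const double p_true[3] = {0.7, 0.2, 0.9};  // softmax prob of the true class
  double total = 0.0;
  int counted = 0;
  for (int i = 0; i < 3; ++i) {
    if (labels[i] == ignore_index) continue;  // ignored row: skip entirely
    total += -std::log(p_true[i]);
    ++counted;
  }
  printf("mean loss over %d unmasked rows: %f\n", counted, total / counted);
  return 0;
}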