未验证 提交 4341ebd9 编写于 作者: G Ghost Screaming 提交者: GitHub

Fix ignore index of c_softmax_with_cross_entropy_op. (#52835)

* Fix a bug in the reduce_sum op: when input.numel() > INT32_MAX, the
result was wrong.

* Remove climits.

* Fix a bug in c_softmax_with_cross_entropy_op: support a negative
ignore_index.
上级 0f2dc4ca
...@@ -40,6 +40,7 @@ template <typename T, typename IndexT> ...@@ -40,6 +40,7 @@ template <typename T, typename IndexT>
__global__ void MaskLabelByIndex(T* predicted_logits, __global__ void MaskLabelByIndex(T* predicted_logits,
const T* logit, const T* logit,
const IndexT* label, const IndexT* label,
const IndexT ignore_index,
const int start_index, const int start_index,
const int end_index, const int end_index,
const int64_t N, const int64_t N,
...@@ -47,13 +48,15 @@ __global__ void MaskLabelByIndex(T* predicted_logits, ...@@ -47,13 +48,15 @@ __global__ void MaskLabelByIndex(T* predicted_logits,
const int nranks) { const int nranks) {
CUDA_KERNEL_LOOP(i, N) { CUDA_KERNEL_LOOP(i, N) {
auto real_label = label[i]; auto real_label = label[i];
PADDLE_ENFORCE((real_label < D * nranks) && (real_label >= 0), PADDLE_ENFORCE(((real_label < D * nranks) && (real_label >= 0)) ||
(real_label == ignore_index),
"The index is out of bounds, " "The index is out of bounds, "
"please check whether the value of label and " "please check whether the value of label and "
"input meet the class number. It should " "input meet the class number. It should "
"be less than [%d], but received [%d]", "be less than [%ld] or equal to [%ld], but received [%ld]",
D * nranks, static_cast<int64_t>(D * nranks),
real_label); static_cast<int64_t>(ignore_index),
static_cast<int64_t>(real_label));
if (real_label >= start_index && real_label < end_index) { if (real_label >= start_index && real_label < end_index) {
predicted_logits[i] = logit[i * D + real_label - start_index]; predicted_logits[i] = logit[i * D + real_label - start_index];
...@@ -204,20 +207,22 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> { ...@@ -204,20 +207,22 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
const auto& label_type = framework::TransToProtoVarType(labels->dtype()); const auto& label_type = framework::TransToProtoVarType(labels->dtype());
if (label_type == framework::proto::VarType::INT32) { if (label_type == framework::proto::VarType::INT32) {
MaskLabelByIndex<T, int32_t> MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
<<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(), predicted_logits.data<T>(),
softmax_2d.data<T>(), softmax_2d.data<T>(),
labels->data<int32_t>(), labels->data<int32_t>(),
start_index, static_cast<int32_t>(ignore_index),
end_index, start_index,
N, end_index,
D, N,
nranks); D,
nranks);
} else if (label_type == framework::proto::VarType::INT64) { } else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndex<T, int64_t> MaskLabelByIndex<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(), <<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(),
softmax_2d.data<T>(), softmax_2d.data<T>(),
labels->data<int64_t>(), labels->data<int64_t>(),
ignore_index,
start_index, start_index,
end_index, end_index,
N, N,
...@@ -362,25 +367,27 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> { ...@@ -362,25 +367,27 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
const auto& label_type = framework::TransToProtoVarType(labels->dtype()); const auto& label_type = framework::TransToProtoVarType(labels->dtype());
if (label_type == framework::proto::VarType::INT32) { if (label_type == framework::proto::VarType::INT32) {
MaskLabelByIndex<T, int32_t> MaskLabelByIndex<T, int32_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
<<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(), predicted_logits.data<T>(),
softmax_2d.data<T>(), softmax_2d.data<T>(),
labels->data<int32_t>(), labels->data<int32_t>(),
start_index, static_cast<int32_t>(ignore_index),
end_index, start_index,
N, end_index,
D, N,
nranks); D,
nranks);
} else if (label_type == framework::proto::VarType::INT64) { } else if (label_type == framework::proto::VarType::INT64) {
MaskLabelByIndex<T, int64_t> MaskLabelByIndex<T, int64_t><<<blocks, threads, 0, dev_ctx.stream()>>>(
<<<blocks, threads, 0, dev_ctx.stream()>>>(predicted_logits.data<T>(), predicted_logits.data<T>(),
softmax_2d.data<T>(), softmax_2d.data<T>(),
labels->data<int64_t>(), labels->data<int64_t>(),
start_index, static_cast<int32_t>(ignore_index),
end_index, start_index,
N, end_index,
D, N,
nranks); D,
nranks);
} }
in_out.clear(); in_out.clear();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册