提交 34c57120 编写于 作者: K Kaipeng Deng 提交者: Tao Luo

polish cross_entropy ENFORCE (#22056)

上级 1c39efb7
......@@ -53,7 +53,21 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_remain; j++) {
int lbl = label_data[i * num_remain + j];
PADDLE_ENFORCE((lbl >= 0 && lbl < axis_dim) || lbl == ignore_index);
if (lbl != ignore_index) {
PADDLE_ENFORCE_GE(lbl, 0,
platform::errors::OutOfRange(
"label value should >= 0 when label "
"value(%f) not equal to ignore_index(%f)",
lbl, ignore_index));
PADDLE_ENFORCE_LT(
lbl, axis_dim,
platform::errors::OutOfRange(
"label value should less than the shape of axis dimension "
"when label value(%f) not equal to ignore_index(%f), But "
"received label value as %ld and shape of axis dimension "
"is %d",
lbl, ignore_index, lbl, axis_dim));
}
int index = i * num_classes + lbl * num_remain + j;
int loss_idx = i * num_remain + j;
loss_data[loss_idx] =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册