未验证 提交 3ba75d4a 编写于 作者: Q qingqing01 提交者: GitHub

Check label range in cross entropy calculation. (#10954)

上级 0c44efb8
...@@ -46,7 +46,10 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> { ...@@ -46,7 +46,10 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
const int64_t* label_data = labels->data<int64_t>(); const int64_t* label_data = labels->data<int64_t>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int lbl = label_data[i];
PADDLE_ENFORCE_GE(lbl, 0);
PADDLE_ENFORCE_LT(lbl, class_num);
int index = i * class_num + lbl;
loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index])); loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册