提交 cf6e543b 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 11e9d4e3
......@@ -1653,12 +1653,16 @@ def cross_entropy(input,
if soft_label == False:
valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label)
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label >= input.shape[-1]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册