提交 8c2fbc31 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 1d660eb6
...@@ -1668,12 +1668,12 @@ def cross_entropy(input, ...@@ -1668,12 +1668,12 @@ def cross_entropy(input,
format(invalid_label[0], 0)) format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max # TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values # to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0: if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd( invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label >= input.shape[-1])) valid_label, paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError( raise ValueError(
"Target({}) is out of class_dimension's upper bound({})". "Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[-1] - 1)) format(invalid_label[0], input.shape[axis] - 1))
_, out = _C_ops.softmax_with_cross_entropy( _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index', input, label, 'soft_label', soft_label, 'ignore_index',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册