未验证 提交 7e21ac92 编写于 作者: H HydrogenSulfate 提交者: GitHub

Cherry-pick of "remove where" (#38849): replace paddle.where with a cast-and-multiply mask in cross_entropy

上级 4cd8a78a
...@@ -1650,25 +1650,16 @@ def cross_entropy(input, ...@@ -1650,25 +1650,16 @@ def cross_entropy(input,
label = paddle.unsqueeze(label, axis=axis) label = paddle.unsqueeze(label, axis=axis)
if in_dygraph_mode(): if in_dygraph_mode():
if soft_label == False: if soft_label == False:
valid_label = paddle.where(label == ignore_index, valid_label = paddle.cast(
paddle.zeros_like(label), label) label != ignore_index, dtype=label.dtype) * label
# TODO: Temporarily use paddle.nonzero instead of paddle.max label_min = paddle.min(valid_label)
# to detect and find out possible illegal label values label_max = paddle.max(valid_label)
if len(paddle.nonzero(valid_label < 0)) > 0: if label_min < 0:
invalid_label = paddle.gather_nd( raise ValueError("label should not out of bound, but got{}".
valid_label, paddle.nonzero(valid_label < 0)) format(label_min))
raise ValueError( if label_max >= input.shape[axis]:
"Target({}) is out of class_dimension's lower bound({})". raise ValueError("label should not out of bound, but got{}".
format(invalid_label[0], 0)) format(label_max))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
_, out = _C_ops.softmax_with_cross_entropy( _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index', input, label, 'soft_label', soft_label, 'ignore_index',
...@@ -1817,8 +1808,9 @@ def cross_entropy(input, ...@@ -1817,8 +1808,9 @@ def cross_entropy(input,
"when weight is provided"\ "when weight is provided"\
.format(input.shape[axis], weight.shape[-1])) .format(input.shape[axis], weight.shape[-1]))
valid_label = paddle.where(label == ignore_index, valid_label = paddle.multiply(
paddle.zeros_like(label), label) paddle.cast(
label != ignore_index, dtype=label.dtype), label)
ignore_weight_mask = paddle.cast((label != ignore_index), ignore_weight_mask = paddle.cast((label != ignore_index),
input.dtype) input.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册