提交 3ab9ace5 编写于 作者: H HydrogenSulfate 提交者: chajchaj

update code

上级 09d4a3a4
......@@ -1665,6 +1665,27 @@ def cross_entropy(input,
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=axis)
if in_dygraph_mode():
if not soft_label:
valid_label = paddle.cast(
label != ignore_index, dtype=label.dtype) * label
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
if core.is_compiled_with_npu():
_, _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
......@@ -1703,32 +1724,8 @@ def cross_entropy(input,
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))
ignore_weight_mask = (
label != ignore_index) # ignored position will be False
valid_label = paddle.cast(
ignore_weight_mask,
dtype=label.dtype) * label # ignored position will be 0
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
ignore_weight_mask = paddle.cast(
ignore_weight_mask, out.dtype) # convert from 0 to 0.0
ignore_weight_mask = paddle.cast((label != ignore_index),
out.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
# TODO: Temporarily use squeeze instead of squeeze_
......@@ -1842,32 +1839,12 @@ def cross_entropy(input,
"when weight is provided" \
.format(input.shape[axis], weight.shape[-1]))
ignore_weight_mask = (
label != ignore_index) # ignored position will be False
valid_label = paddle.cast(
ignore_weight_mask,
dtype=label.dtype) * label # ignored position will be 0
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[axis] - 1))
ignore_weight_mask = paddle.cast(ignore_weight_mask,
out.dtype) # convert from 0 to 0.0
valid_label = paddle.multiply(
paddle.cast(
label != ignore_index, dtype=label.dtype), label)
ignore_weight_mask = paddle.cast((label != ignore_index),
input.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
axis] == 1:
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册