提交 de972c50 作者: HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 c61027e8
...@@ -1650,6 +1650,24 @@ def cross_entropy(input, ...@@ -1650,6 +1650,24 @@ def cross_entropy(input,
if input_dims - 1 == label_dims: if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=axis) label = paddle.unsqueeze(label, axis=axis)
if in_dygraph_mode(): if in_dygraph_mode():
if soft_label == False:
valid_label = paddle.where(
label == ignore_index,
paddle.zeros_like(label),
label)
if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0))
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label >= input.shape[-1]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[-1] - 1))
_, out = _C_ops.softmax_with_cross_entropy( _, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index', input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis, ignore_index, 'numeric_stable_mode', True, 'axis', axis,
...@@ -1681,27 +1699,6 @@ def cross_entropy(input, ...@@ -1681,27 +1699,6 @@ def cross_entropy(input,
weight's class_dimension({}) \ weight's class_dimension({}) \
when weight is provided" when weight is provided"
.format(input.shape[-1], weight.shape[-1])) .format(input.shape[-1], weight.shape[-1]))
valid_label = paddle.where(
label == ignore_index,
paddle.to_tensor(
0, dtype=label.dtype),
label)
if (len(paddle.nonzero(valid_label < 0)) > 0) or (
len(paddle.nonzero(valid_label >= input.shape[-1])) > 0
):
invalid_label = paddle.gather_nd(
input, paddle.nonzero(valid_label < 0))
if invalid_label.numel() > 0:
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
invalid_label = paddle.gather_nd(
input, paddle.nonzero(valid_label >= input.shape[-1]))
if invalid_label.numel() > 0:
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[-1]))
ignore_weight_mask = paddle.cast((label != ignore_index), ignore_weight_mask = paddle.cast((label != ignore_index),
out.dtype) out.dtype)
...@@ -1709,7 +1706,7 @@ def cross_entropy(input, ...@@ -1709,7 +1706,7 @@ def cross_entropy(input,
-1] == 1: -1] == 1:
ignore_weight_mask.squeeze_(-1) ignore_weight_mask.squeeze_(-1)
weight_gather = _C_ops.gather_nd( weight_gather = _C_ops.gather_nd(
weight, valid_label) # ignore的位置暂时用label0的权重代替 weight, valid_label)
weight_gather = _C_ops.elementwise_mul(weight_gather, weight_gather = _C_ops.elementwise_mul(weight_gather,
ignore_weight_mask) ignore_weight_mask)
input_shape = list(label.shape) input_shape = list(label.shape)
...@@ -1804,6 +1801,12 @@ def cross_entropy(input, ...@@ -1804,6 +1801,12 @@ def cross_entropy(input,
weight_gather_reshape = reshape(weight_gather, shape=out_shape) weight_gather_reshape = reshape(weight_gather, shape=out_shape)
out = paddle.cast(out, weight_gather_reshape.dtype) out = paddle.cast(out, weight_gather_reshape.dtype)
else: else:
if input.shape[-1] != weight.shape[-1]:
raise ValueError("input's class_dimension({}) must equal to "\
"weight's class_dimension({}) "\
"when weight is provided"
.format(input.shape[-1], weight.shape[-1]))
weight_gather = paddle.gather_nd( weight_gather = paddle.gather_nd(
weight, label) #trans weight from class to sample, shape:N weight, label) #trans weight from class to sample, shape:N
input_shape = list(label.shape) input_shape = list(label.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册