提交 b4a3f21c 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 fa4805b4
......@@ -1804,41 +1804,8 @@ def cross_entropy(input,
weight_gather_reshape = reshape(weight_gather, shape=out_shape)
out = paddle.cast(out, weight_gather_reshape.dtype)
else:
if input.shape[-1] != weight.shape[-1]:
raise ValueError("input's class_dimension({}) must equal to \
weight's class_dimension({}) \
when weight is provided"
.format(input.shape[-1], weight.shape[-1]))
valid_label = paddle.where(
label == ignore_index,
paddle.zeros(
[1], dtype=label.dtype),
label)
if (paddle.numel(paddle.nonzero(valid_label < 0)) > 0) or (
paddle.numel(
paddle.nonzero(valid_label >= input.shape[-1])) > 0):
invalid_label = paddle.gather_nd(
input, paddle.nonzero(valid_label < 0))
if paddle.numel(invalid_label) > 0:
raise ValueError(
"Target({}) is out of class_dimension's lower bound({})".
format(invalid_label[0], 0))
invalid_label = paddle.gather_nd(
input, paddle.nonzero(valid_label >= input.shape[-1]))
if paddle.numel(invalid_label) > 0:
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[-1]))
ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
-1] == 1:
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1)
weight_gather = paddle.gather_nd(
weight,
valid_label) #trans weight from class to sample, shape:N
weight_gather = paddle.multiply(weight_gather, ignore_weight_mask)
weight, label) #trans weight from class to sample, shape:N
input_shape = list(label.shape)
weight_gather_reshape = reshape(weight_gather, shape=input_shape)
out = paddle.multiply(out, weight_gather_reshape, name=weight_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册