提交 dd0140bd 编写于 作者: H HydrogenSulfate 提交者: chajchaj

Update loss.py

上级 3ca813e6
...@@ -1651,10 +1651,8 @@ def cross_entropy(input, ...@@ -1651,10 +1651,8 @@ def cross_entropy(input,
label = paddle.unsqueeze(label, axis=axis) label = paddle.unsqueeze(label, axis=axis)
if in_dygraph_mode(): if in_dygraph_mode():
if soft_label == False: if soft_label == False:
valid_label = paddle.where( valid_label = paddle.where(label == ignore_index,
label == ignore_index, paddle.zeros_like(label), label)
paddle.zeros_like(label),
label)
if len(paddle.nonzero(valid_label < 0)) > 0: if len(paddle.nonzero(valid_label < 0)) > 0:
invalid_label = paddle.gather_nd( invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label < 0)) valid_label, paddle.nonzero(valid_label < 0))
...@@ -1705,8 +1703,7 @@ def cross_entropy(input, ...@@ -1705,8 +1703,7 @@ def cross_entropy(input,
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
-1] == 1: -1] == 1:
ignore_weight_mask.squeeze_(-1) ignore_weight_mask.squeeze_(-1)
weight_gather = _C_ops.gather_nd( weight_gather = _C_ops.gather_nd(weight, valid_label)
weight, valid_label)
weight_gather = _C_ops.elementwise_mul(weight_gather, weight_gather = _C_ops.elementwise_mul(weight_gather,
ignore_weight_mask) ignore_weight_mask)
input_shape = list(label.shape) input_shape = list(label.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册