From dd0140bd66639c76f27c4a5f415bb731d6247cde Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 15 Aug 2021 21:58:19 +0800 Subject: [PATCH] Update loss.py --- python/paddle/nn/functional/loss.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 21b1e444ef8..fd4e83a6e87 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1651,10 +1651,8 @@ def cross_entropy(input, label = paddle.unsqueeze(label, axis=axis) if in_dygraph_mode(): if soft_label == False: - valid_label = paddle.where( - label == ignore_index, - paddle.zeros_like(label), - label) + valid_label = paddle.where(label == ignore_index, + paddle.zeros_like(label), label) if len(paddle.nonzero(valid_label < 0)) > 0: invalid_label = paddle.gather_nd( valid_label, paddle.nonzero(valid_label < 0)) @@ -1705,8 +1703,7 @@ def cross_entropy(input, if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ -1] == 1: ignore_weight_mask.squeeze_(-1) - weight_gather = _C_ops.gather_nd( - weight, valid_label) + weight_gather = _C_ops.gather_nd(weight, valid_label) weight_gather = _C_ops.elementwise_mul(weight_gather, ignore_weight_mask) input_shape = list(label.shape) -- GitLab