From 6cd41cec2146da2f5008a42e972a4627a4deb26d Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 11 Oct 2021 22:15:05 +0800 Subject: [PATCH] Update loss.py --- python/paddle/nn/functional/loss.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index eb043c0056..38d4da17cb 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1712,11 +1712,12 @@ def cross_entropy(input, out.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: - ignore_weight_mask.squeeze_(axis) + # TODO: Temporarily use squeeze instead of squeeze_ + ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis) if axis != -1 and axis != valid_label.ndim - 1: temp_perm = list(range(axis % valid_label.ndim)) \ + list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \ - + [axis%valid_label.ndim] + + [axis % valid_label.ndim] weight_gather = _C_ops.gather_nd( weight, valid_label.transpose(temp_perm)) else: @@ -1828,9 +1829,9 @@ def cross_entropy(input, if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis) - if axis != -1: + if axis != -1 and axis != valid_label.ndim - 1: temp_perm = list(range(axis % valid_label.ndim)) \ - + list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \ + + list(range((axis % valid_label.ndim + 1), valid_label.ndim)) \ + [axis % valid_label.ndim] weight_gather = paddle.gather_nd( weight, paddle.transpose(valid_label, temp_perm)) -- GitLab