From fa4805b479813c5009823d9100f01a10210ed278 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 15 Aug 2021 17:11:18 +0800 Subject: [PATCH] Update loss.py --- python/paddle/nn/functional/loss.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 3fd9f45d962..f12f897dae2 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1812,10 +1812,11 @@ def cross_entropy(input, valid_label = paddle.where( label == ignore_index, paddle.zeros( - [1], dtype=label.dtype), - label) + [1], dtype=label.dtype), + label) if (paddle.numel(paddle.nonzero(valid_label < 0)) > 0) or ( - paddle.numel(paddle.nonzero(valid_label >= input.shape[-1])) > 0): + paddle.numel( + paddle.nonzero(valid_label >= input.shape[-1])) > 0): invalid_label = paddle.gather_nd( input, paddle.nonzero(valid_label < 0)) if paddle.numel(invalid_label) > 0: -- GitLab