From b4eec5d5adaa151479720dbc0e49b6408e3c7f95 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 28 Dec 2021 16:39:32 +0800 Subject: [PATCH] replace .where to '==' --- python/paddle/nn/functional/loss.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c1800a781d4..4d09f1d5c38 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1703,11 +1703,16 @@ def cross_entropy(input, "when weight is provided" \ .format(input.shape[axis], weight.shape[-1])) - valid_label = paddle.where(label == ignore_index, - paddle.zeros_like(label), label) + ignore_weight_mask = ( + label != ignore_index) # ignored position will be False + + valid_label = paddle.cast( + ignore_weight_mask, + dtype=label.dtype) * label # ignored position will be 0 + + ignore_weight_mask = paddle.cast( + ignore_weight_mask, out.dtype) # convert from 0 to 0.0 - ignore_weight_mask = paddle.cast((label != ignore_index), - out.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: # TODO: Temporarily use squeeze instead of squeeze_ @@ -1821,10 +1826,16 @@ def cross_entropy(input, "when weight is provided" \ .format(input.shape[axis], weight.shape[-1])) - valid_label = paddle.where(label == ignore_index, - paddle.zeros_like(label), label) - ignore_weight_mask = paddle.cast((label != ignore_index), - input.dtype) + ignore_weight_mask = ( + label != ignore_index) # ignored position will be False + + valid_label = paddle.cast( + ignore_weight_mask, + dtype=label.dtype) * label # ignored position will be 0 + + ignore_weight_mask = paddle.cast(ignore_weight_mask, + out.dtype) # convert from 0 to 0.0 + if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis) -- GitLab