From e30150dd45514edaf59e29d4ef5e841360973233 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 10 Jan 2022 00:45:56 +0800 Subject: [PATCH] replace where with min and max --- python/paddle/nn/functional/loss.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index f571f5d3028..90ada8c3c5e 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1665,27 +1665,17 @@ def cross_entropy(input, if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=axis) if in_dygraph_mode(): - if not soft_label: + if soft_label == False: valid_label = paddle.cast( label != ignore_index, dtype=label.dtype) * label - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label < 0)) > 0: - invalid_label = paddle.gather_nd( - valid_label, paddle.nonzero(valid_label < 0)) - raise ValueError( - "Target({}) is out of class_dimension's lower bound({})". - format(invalid_label[0], 0)) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: - invalid_label = paddle.gather_nd( - valid_label, - paddle.nonzero(valid_label >= input.shape[axis])) - raise ValueError( - "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[axis] - 1)) - + label_min = paddle.min(valid_label) + label_max = paddle.max(valid_label) + if label_min < 0: + raise ValueError("label should not out of bound, but got{}". + format(label_min)) + if label_max >= input.shape[axis]: + raise ValueError("label should not out of bound, but got{}". + format(label_max)) if core.is_compiled_with_npu(): _, _, out = _C_ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', @@ -1842,7 +1832,6 @@ def cross_entropy(input, valid_label = paddle.multiply( paddle.cast( label != ignore_index, dtype=label.dtype), label) - ignore_weight_mask = paddle.cast((label != ignore_index), input.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ -- GitLab