From ef7e5fc76e42b997e61db91941ffe98898f29232 Mon Sep 17 00:00:00 2001 From: root <12272008@bjtu.edu.cn> Date: Tue, 27 Apr 2021 07:19:30 +0000 Subject: [PATCH] imporve efficiency --- python/paddle/nn/functional/loss.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index cf4d5b1ed35..eeb00625876 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1411,11 +1411,13 @@ def cross_entropy(input, out = core.ops.elementwise_mul(out, weight_gather_reshape) else: - for label_val in label.flatten(): - if label_val < 0 or label_val >= input.shape[-1]: - raise ValueError( - 'Expected 0 <= label_value < class_dimension({}), but got label_value {}'. - format(input.shape[-1], label_val.numpy())) + label_min = paddle.min(label) + label_max = paddle.max(label) + if label_min < 0 or label_max >= input.shape[-1]: + raise ValueError( + 'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {} '. + format(input.shape[-1], + label_min.numpy(), label_max.numpy())) weight_gather = core.ops.gather_nd(weight, label) input_shape = list(label.shape) weight_gather_reshape = reshape( -- GitLab