From b4a3f21ce6ef8c4134fb5530307105dc0bb14f4c Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 15 Aug 2021 17:31:04 +0800 Subject: [PATCH] Update loss.py --- python/paddle/nn/functional/loss.py | 35 +---------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index f12f897dae2..4e538fed645 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1804,41 +1804,8 @@ def cross_entropy(input, weight_gather_reshape = reshape(weight_gather, shape=out_shape) out = paddle.cast(out, weight_gather_reshape.dtype) else: - if input.shape[-1] != weight.shape[-1]: - raise ValueError("input's class_dimension({}) must equal to \ - weight's class_dimension({}) \ - when weight is provided" - .format(input.shape[-1], weight.shape[-1])) - valid_label = paddle.where( - label == ignore_index, - paddle.zeros( - [1], dtype=label.dtype), - label) - if (paddle.numel(paddle.nonzero(valid_label < 0)) > 0) or ( - paddle.numel( - paddle.nonzero(valid_label >= input.shape[-1])) > 0): - invalid_label = paddle.gather_nd( - input, paddle.nonzero(valid_label < 0)) - if paddle.numel(invalid_label) > 0: - raise ValueError( - "Target({}) is out of class_dimension's lower bound({})". - format(invalid_label[0], 0)) - invalid_label = paddle.gather_nd( - input, paddle.nonzero(valid_label >= input.shape[-1])) - if paddle.numel(invalid_label) > 0: - raise ValueError( - "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[-1])) - - ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype) - if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ - -1] == 1: - ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1) - weight_gather = paddle.gather_nd( - weight, - valid_label) #trans weight from class to sample, shape:N - weight_gather = paddle.multiply(weight_gather, ignore_weight_mask) + weight, label) #trans weight from class to sample, shape:N input_shape = list(label.shape) weight_gather_reshape = reshape(weight_gather, shape=input_shape) out = paddle.multiply(out, weight_gather_reshape, name=weight_name) -- GitLab