diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 4e538fed64547b65b75c0d1a556e7fdff68486a9..a3115cff6a3b3ba4196a83b0715bfc203e67bcf4 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1650,6 +1650,24 @@ def cross_entropy(input,
     if input_dims - 1 == label_dims:
         label = paddle.unsqueeze(label, axis=axis)
     if in_dygraph_mode():
+        if soft_label == False:
+            valid_label = paddle.where(
+                label == ignore_index,
+                paddle.zeros_like(label),
+                label)
+            if len(paddle.nonzero(valid_label < 0)) > 0:
+                invalid_label = paddle.gather_nd(
+                    valid_label, paddle.nonzero(valid_label < 0))
+                raise ValueError(
+                    "Target({}) is out of class_dimension's lower bound({})".
+                    format(invalid_label[0], 0))
+            if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0:
+                invalid_label = paddle.gather_nd(
+                    valid_label, paddle.nonzero(valid_label >= input.shape[-1]))
+                raise ValueError(
+                    "Target({}) is out of class_dimension's upper bound({})".
+                    format(invalid_label[0], input.shape[-1] - 1))
+
         _, out = _C_ops.softmax_with_cross_entropy(
             input, label, 'soft_label', soft_label, 'ignore_index',
             ignore_index, 'numeric_stable_mode', True, 'axis', axis,
@@ -1681,27 +1699,6 @@ def cross_entropy(input,
                         weight's class_dimension({}) \
                         when weight is provided"
                         .format(input.shape[-1], weight.shape[-1]))

-                valid_label = paddle.where(
-                    label == ignore_index,
-                    paddle.to_tensor(
-                        0, dtype=label.dtype),
-                    label)
-
-                if (len(paddle.nonzero(valid_label < 0)) > 0) or (
-                        len(paddle.nonzero(valid_label >= input.shape[-1])) > 0
-                ):
-                    invalid_label = paddle.gather_nd(
-                        input, paddle.nonzero(valid_label < 0))
-                    if invalid_label.numel() > 0:
-                        raise ValueError(
-                            "Target({}) is out of class_dimension's lower bound({})".
-                            format(invalid_label[0], 0))
-                    invalid_label = paddle.gather_nd(
-                        input, paddle.nonzero(valid_label >= input.shape[-1]))
-                    if invalid_label.numel() > 0:
-                        raise ValueError(
-                            "Target({}) is out of class_dimension's upper bound({})".
-                            format(invalid_label[0], input.shape[-1]))
                 ignore_weight_mask = paddle.cast((label != ignore_index),
                                                  out.dtype)
@@ -1709,7 +1706,7 @@ def cross_entropy(input,
                         -1] == 1:
                     ignore_weight_mask.squeeze_(-1)
                 weight_gather = _C_ops.gather_nd(
-                    weight, valid_label)  # ignore的位置暂时用label0的权重代替
+                    weight, valid_label)
                 weight_gather = _C_ops.elementwise_mul(weight_gather,
                                                        ignore_weight_mask)
                 input_shape = list(label.shape)
@@ -1804,6 +1801,12 @@ def cross_entropy(input,
             weight_gather_reshape = reshape(weight_gather, shape=out_shape)
             out = paddle.cast(out, weight_gather_reshape.dtype)
         else:
+            if input.shape[-1] != weight.shape[-1]:
+                raise ValueError("input's class_dimension({}) must equal to "\
+                    "weight's class_dimension({}) "\
+                    "when weight is provided"
+                    .format(input.shape[-1], weight.shape[-1]))
+
             weight_gather = paddle.gather_nd(
                 weight, label)  #trans weight from class to sample, shape:N
             input_shape = list(label.shape)
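
A minimal usage sketch of the behavior this patch targets (the shapes and values below are illustrative, not taken from the test suite): in dygraph mode with hard labels (soft_label=False), a target outside [0, input.shape[-1] - 1] should now hit the new up-front check and raise the ValueError added above before softmax_with_cross_entropy runs, whether or not a weight tensor is passed.

```python
import paddle
import paddle.nn.functional as F

# Illustrative only: 4 samples, 3 classes; label value 3 exceeds the
# valid range [0, 2], so the new up-front check should reject it.
logits = paddle.rand([4, 3])
bad_label = paddle.to_tensor([0, 1, 2, 3])

try:
    F.cross_entropy(logits, bad_label, soft_label=False)
except ValueError as err:
    print(err)  # "Target(...) is out of class_dimension's upper bound(2)"
```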