diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index aee0366ab3093bb2baed4ff402e9a1d1239987e0..f571f5d30285f3d08a0f1525197aae449b50a46c 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1665,6 +1665,27 @@ def cross_entropy(input, if input_dims - 1 == label_dims: label = paddle.unsqueeze(label, axis=axis) if in_dygraph_mode(): + if not soft_label: + valid_label = paddle.cast( + label != ignore_index, dtype=label.dtype) * label + # TODO: Temporarily use paddle.nonzero instead of paddle.max + # to detect and find out possible illegal label values + if len(paddle.nonzero(valid_label < 0)) > 0: + invalid_label = paddle.gather_nd( + valid_label, paddle.nonzero(valid_label < 0)) + raise ValueError( + "Target({}) is out of class_dimension's lower bound({})". + format(invalid_label[0], 0)) + # TODO: Temporarily use paddle.nonzero instead of paddle.max + # to detect and find out possible illegal label values + if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: + invalid_label = paddle.gather_nd( + valid_label, + paddle.nonzero(valid_label >= input.shape[axis])) + raise ValueError( + "Target({}) is out of class_dimension's upper bound({})". + format(invalid_label[0], input.shape[axis] - 1)) + if core.is_compiled_with_npu(): _, _, out = _C_ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', @@ -1681,7 +1702,7 @@ def cross_entropy(input, # trans weight from class to sample, shape:N or [N,H,W] for 1d and 2d cases. if soft_label == True: # chajchaj: - # weight's shape is C, where C is class num. + # weight's shape is C, where C is class num. # for 1d case: label's shape is [N,C], weight_gather's shape is N. # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W]. weight_gather = paddle.matmul( @@ -1703,32 +1724,8 @@ def cross_entropy(input, "when weight is provided" \ .format(input.shape[axis], weight.shape[-1])) - ignore_weight_mask = ( - label != ignore_index) # ignored position will be False - - valid_label = paddle.cast( - ignore_weight_mask, - dtype=label.dtype) * label # ignored position will be 0 - - if len(paddle.nonzero(valid_label < 0)) > 0: - invalid_label = paddle.gather_nd( - valid_label, paddle.nonzero(valid_label < 0)) - raise ValueError( - "Target({}) is out of class_dimension's lower bound({})". - format(invalid_label[0], 0)) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: - invalid_label = paddle.gather_nd( - valid_label, - paddle.nonzero(valid_label >= input.shape[axis])) - raise ValueError( - "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[axis] - 1)) - - ignore_weight_mask = paddle.cast( - ignore_weight_mask, out.dtype) # convert from 0 to 0.0 - + ignore_weight_mask = paddle.cast((label != ignore_index), + out.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: # TODO: Temporarily use squeeze instead of squeeze_ @@ -1842,32 +1839,12 @@ def cross_entropy(input, "when weight is provided" \ .format(input.shape[axis], weight.shape[-1])) - ignore_weight_mask = ( - label != ignore_index) # ignored position will be False - - valid_label = paddle.cast( - ignore_weight_mask, - dtype=label.dtype) * label # ignored position will be 0 - - if len(paddle.nonzero(valid_label < 0)) > 0: - invalid_label = paddle.gather_nd( - valid_label, paddle.nonzero(valid_label < 0)) - raise ValueError( - "Target({}) is out of class_dimension's lower bound({})". - format(invalid_label[0], 0)) - # TODO: Temporarily use paddle.nonzero instead of paddle.max - # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: - invalid_label = paddle.gather_nd( - valid_label, - paddle.nonzero(valid_label >= input.shape[axis])) - raise ValueError( - "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[axis] - 1)) - - ignore_weight_mask = paddle.cast(ignore_weight_mask, - out.dtype) # convert from 0 to 0.0 + valid_label = paddle.multiply( + paddle.cast( + label != ignore_index, dtype=label.dtype), label) + ignore_weight_mask = paddle.cast((label != ignore_index), + input.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ axis] == 1: ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)