diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 8eb6e05fc04e6bded9c335a87cdc4b541005347d..f13f14cdde118ec24161f5a722edca384b4f684a 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1696,6 +1696,13 @@ def cross_entropy(input,
                 out = _C_ops.elementwise_mul(out, weight_gather_reshape)
 
             else:
+                if input.shape[axis] != weight.shape[-1]:
+                    raise ValueError(
+                        "input's class_dimension({}) must equal to "
+                        "weight's class_dimension({}) "
+                        "when weight is provided" \
+                        .format(input.shape[axis], weight.shape[-1]))
+
                 valid_label = paddle.where(label == ignore_index,
                                            paddle.zeros_like(label), label)
                 # TODO: Temporarily use paddle.nonzero instead of paddle.max
@@ -1715,12 +1722,6 @@ def cross_entropy(input,
                     raise ValueError(
                         "Target({}) is out of class_dimension's upper bound({})".
                         format(invalid_label[0], input.shape[axis] - 1))
-                if input.shape[axis] != weight.shape[-1]:
-                    raise ValueError(
-                        "input's class_dimension({}) must equal to "
-                        "weight's class_dimension({}) "
-                        "when weight is provided" \
-                        .format(input.shape[axis], weight.shape[-1]))
                 ignore_weight_mask = paddle.cast((label != ignore_index),
                                                  out.dtype)
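
The patch moves the weight/class-dimension check to the top of the hard-label branch, so a `weight` tensor whose length does not match the input's class dimension now fails fast, before the label-range validation that goes through `paddle.nonzero`. The snippet below is not part of the patch; it is a minimal sketch of when the relocated check fires, assuming a dygraph-mode PaddlePaddle build that contains this change (tensor names, shapes, and values are illustrative only).

import paddle
import paddle.nn.functional as F

logits = paddle.rand([4, 3])                       # 4 samples, 3 classes (class axis is -1)
labels = paddle.randint(0, 3, [4], dtype='int64')  # hard labels, so the patched `else:` branch runs

# Matching weight: one entry per class, length 3 == input.shape[-1].
good_weight = paddle.to_tensor([1.0, 2.0, 0.5])
loss = F.cross_entropy(logits, labels, weight=good_weight)

# Mismatched weight: length 4 != input.shape[-1] (3). With this patch the
# ValueError is raised up front, before the label-range checks run.
bad_weight = paddle.to_tensor([1.0, 2.0, 0.5, 0.25])
try:
    F.cross_entropy(logits, labels, weight=bad_weight)
except ValueError as err:
    print(err)  # "input's class_dimension(3) must equal to weight's class_dimension(4) ..."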