diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index b2b305fdf8161e6e8ad1f5e387a0bc18165ceba0..21b1e444ef844ed587c483a6565b87060008497a 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1801,7 +1801,7 @@ def cross_entropy(input, weight_gather_reshape = reshape(weight_gather, shape=out_shape) out = paddle.cast(out, weight_gather_reshape.dtype) else: - if input.shape[-1] != weight.shape[-1]: + if input.shape[-1] != weight.shape[-1]: raise ValueError("input's class_dimension({}) must equal to "\ "weight's class_dimension({}) "\ "when weight is provided"