diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 3a9dd5953876aa7f2979574b500fc463bc88d433..004070d23c647869a0131f556437183e6392d88e 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2416,7 +2416,7 @@ def cross_entropy(input, out = paddle.squeeze(out, axis=axis) return out - check_variable_and_dtype(input, 'input', ['float32', 'float64'], + check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'], 'softmax_cross_entropy') check_variable_and_dtype( label, 'label',