From f5ca2db2cc006b9211df85080c5d818c9c45f81f Mon Sep 17 00:00:00 2001 From: chajchaj <57249073+chajchaj@users.noreply.github.com> Date: Tue, 9 Feb 2021 20:39:33 +0800 Subject: [PATCH] support label with float input of cross_entropy, test=develop (#30929) * support label with float input of cross_entropy, test=develop * fix code style in nn/functional/loss.py, test=develop --- python/paddle/nn/functional/loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 90a3ebc679..c223addc26 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1275,7 +1275,8 @@ def cross_entropy(input, fluid.data_feeder.check_variable_and_dtype( input, 'input', ['float32', 'float64'], 'softmax_cross_entropy') fluid.data_feeder.check_variable_and_dtype( - label, 'label', ['int32', 'int64'], 'softmax_cross_entropy') + label, 'label', ['int32', 'int64', 'float32', 'float64'], + 'softmax_cross_entropy') out = softmax_with_cross_entropy( input, label, -- GitLab