From 2105d146238dfe67f2b3b08503d09a902ab2f152 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 17 Aug 2022 16:38:14 +0800 Subject: [PATCH] fix api sigmoid_focal_loss to final state (#45207) --- python/paddle/nn/functional/loss.py | 59 +++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 4e4f968e68..3a9dd59538 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2616,23 +2616,54 @@ def sigmoid_focal_loss(logit, "Expected one dimension of normalizer in sigmoid_focal_loss but got {}." .format(normalizer_dims)) - if _non_static_mode(): - if in_dygraph_mode(): - place = _current_expected_place() - one = _C_ops.final_state_full(logit.shape, float(1.0), logit.dtype, - place) + if in_dygraph_mode(): + place = _current_expected_place() + one = _C_ops.final_state_full(logit.shape, float(1.0), logit.dtype, + place) - loss = _C_ops.final_state_sigmoid_cross_entropy_with_logits( - logit, label, False, -100) + loss = _C_ops.final_state_sigmoid_cross_entropy_with_logits( + logit, label, False, -100) - elif _in_legacy_dygraph(): - one = _varbase_creator(dtype=logit.dtype) - _C_ops.fill_constant(one, 'value', float(1.0), 'force_cpu', False, - 'dtype', one.dtype, 'str_value', '1.0', - 'shape', logit.shape) - loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label) + pred = _C_ops.final_state_sigmoid(logit) + + p_t = _C_ops.final_state_add( + _C_ops.final_state_multiply(pred, label), + _C_ops.final_state_multiply(_C_ops.final_state_subtract(one, pred), + _C_ops.final_state_subtract(one, + label))) + + alpha = fluid.dygraph.base.to_variable([alpha], dtype=loss.dtype) + alpha_t = _C_ops.final_state_add( + _C_ops.final_state_multiply(alpha, label), + _C_ops.final_state_multiply(_C_ops.final_state_subtract(one, alpha), + _C_ops.final_state_subtract(one, + label))) + loss = _C_ops.final_state_multiply(alpha_t, loss) + + gamma = fluid.dygraph.base.to_variable([gamma], dtype=loss.dtype) + gamma_t = _C_ops.final_state_pow(_C_ops.elementwise_sub(one, p_t), + gamma) + loss = _C_ops.final_state_multiply(gamma_t, loss) + + if normalizer is not None: + loss = _C_ops.final_state_divide(loss, normalizer) + + if reduction == "sum": + return _C_ops.final_state_sum(loss, [], None, False) + elif reduction == "mean": + return _C_ops.final_state_mean_all(loss) + + return loss + + elif _in_legacy_dygraph(): + one = _varbase_creator(dtype=logit.dtype) + _C_ops.fill_constant(one, 'value', float(1.0), 'force_cpu', False, + 'dtype', one.dtype, 'str_value', '1.0', 'shape', + logit.shape) + loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label) pred = _C_ops.sigmoid(logit) + p_t = _C_ops.elementwise_add( _C_ops.elementwise_mul(pred, label), _C_ops.elementwise_mul(_C_ops.elementwise_sub(one, pred), @@ -2656,8 +2687,6 @@ def sigmoid_focal_loss(logit, if reduction == "sum": return _C_ops.reduce_sum(loss, 'reduce_all', True) elif reduction == "mean": - if in_dygraph_mode(): - return _C_ops.final_state_mean_all(loss) return _C_ops.mean(loss) return loss -- GitLab