未验证 提交 2105d146 编写于 作者: W wanghuancoder 提交者: GitHub

fix api sigmoid_focal_loss to final state (#45207)

上级 a79d4a75
......@@ -2616,7 +2616,6 @@ def sigmoid_focal_loss(logit,
"Expected one dimension of normalizer in sigmoid_focal_loss but got {}."
.format(normalizer_dims))
if _non_static_mode():
if in_dygraph_mode():
place = _current_expected_place()
one = _C_ops.final_state_full(logit.shape, float(1.0), logit.dtype,
......@@ -2625,14 +2624,46 @@ def sigmoid_focal_loss(logit,
loss = _C_ops.final_state_sigmoid_cross_entropy_with_logits(
logit, label, False, -100)
pred = _C_ops.final_state_sigmoid(logit)
p_t = _C_ops.final_state_add(
_C_ops.final_state_multiply(pred, label),
_C_ops.final_state_multiply(_C_ops.final_state_subtract(one, pred),
_C_ops.final_state_subtract(one,
label)))
alpha = fluid.dygraph.base.to_variable([alpha], dtype=loss.dtype)
alpha_t = _C_ops.final_state_add(
_C_ops.final_state_multiply(alpha, label),
_C_ops.final_state_multiply(_C_ops.final_state_subtract(one, alpha),
_C_ops.final_state_subtract(one,
label)))
loss = _C_ops.final_state_multiply(alpha_t, loss)
gamma = fluid.dygraph.base.to_variable([gamma], dtype=loss.dtype)
gamma_t = _C_ops.final_state_pow(_C_ops.elementwise_sub(one, p_t),
gamma)
loss = _C_ops.final_state_multiply(gamma_t, loss)
if normalizer is not None:
loss = _C_ops.final_state_divide(loss, normalizer)
if reduction == "sum":
return _C_ops.final_state_sum(loss, [], None, False)
elif reduction == "mean":
return _C_ops.final_state_mean_all(loss)
return loss
elif _in_legacy_dygraph():
one = _varbase_creator(dtype=logit.dtype)
_C_ops.fill_constant(one, 'value', float(1.0), 'force_cpu', False,
'dtype', one.dtype, 'str_value', '1.0',
'shape', logit.shape)
'dtype', one.dtype, 'str_value', '1.0', 'shape',
logit.shape)
loss = _C_ops.sigmoid_cross_entropy_with_logits(logit, label)
pred = _C_ops.sigmoid(logit)
p_t = _C_ops.elementwise_add(
_C_ops.elementwise_mul(pred, label),
_C_ops.elementwise_mul(_C_ops.elementwise_sub(one, pred),
......@@ -2656,8 +2687,6 @@ def sigmoid_focal_loss(logit,
if reduction == "sum":
return _C_ops.reduce_sum(loss, 'reduce_all', True)
elif reduction == "mean":
if in_dygraph_mode():
return _C_ops.final_state_mean_all(loss)
return _C_ops.mean(loss)
return loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册