未验证 提交 944c3165 编写于 作者: F FlyingQianMM 提交者: GitHub

fix type error of std::pow in sigmoid_focal_loss_op.cu and sigmoid_focal_loss_op.h (#18152)

* test=develop
fix type error of std::pow in sigmoid_focal_loss_op.cu and sigmoid_focal_loss_op.h

* test=develop
fix wrong code stype in sigmoid_focal_loss_op.cu and sigmoid_focal_loss_op.h
上级 25f3cd64
......@@ -61,8 +61,8 @@ __global__ void GPUSigmoidFocalLossForward(const T *x_data,
T p = 1. / (1. + real_exp(-x));
// (1 - p)**gamma * log(p)
T term_pos =
std::pow((1. - p), gamma) * real_log(p > FLT_MIN ? p : FLT_MIN);
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
real_log(p > FLT_MIN ? p : FLT_MIN);
// p**gamma * log(1 - p)
T term_neg =
std::pow(p, gamma) *
......@@ -97,7 +97,7 @@ __global__ void GPUSigmoidFocalLossBackward(
T p = 1. / (1. + real_exp(-x));
// (1-p)**g * (1 - p - g*p*log(p))
T term_pos = std::pow((1. - p), gamma) *
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
(1. - p - (p * gamma * real_log(p > FLT_MIN ? p : FLT_MIN)));
// (p**g) * (g*(1-p)*log(1-p) - p)
T term_neg =
......
......@@ -59,12 +59,13 @@ class SigmoidFocalLossKernel : public framework::OpKernel<T> {
T p = 1. / (1. + std::exp(-x));
// (1 - p)**gamma * log(p) where
T term_pos =
std::pow((1. - p), gamma) * std::log(p > FLT_MIN ? p : FLT_MIN);
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
std::log(p > FLT_MIN ? p : FLT_MIN);
// p**gamma * log(1 - p)
float term_neg =
T term_neg =
std::pow(p, gamma) *
(-1. * x * (x >= 0) - std::log(1. + std::exp(x - 2. * x * (x >= 0))));
out_data[idx] = 0.0;
out_data[idx] += -c_pos * term_pos * s_pos;
out_data[idx] += -c_neg * term_neg * s_neg;
......@@ -107,7 +108,7 @@ class SigmoidFocalLossGradKernel : public framework::OpKernel<T> {
T p = 1. / (1. + std::exp(-x));
// (1-p)**g * (1 - p - g*p*log(p))
T term_pos = std::pow((1. - p), gamma) *
T term_pos = std::pow(static_cast<T>(1. - p), gamma) *
(1. - p - (p * gamma * std::log(p > FLT_MIN ? p : FLT_MIN)));
// (p**g) * (g*(1-p)*log(1-p) - p)
T term_neg = std::pow(p, gamma) *
......@@ -115,7 +116,6 @@ class SigmoidFocalLossGradKernel : public framework::OpKernel<T> {
std::log(1. + std::exp(x - 2. * x * (x >= 0)))) *
(1. - p) * gamma -
p);
dx_data[idx] = 0.0;
dx_data[idx] += -c_pos * s_pos * term_pos;
dx_data[idx] += -c_neg * s_neg * term_neg;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册