From 8db66fc3f64bbdf3dc8adad26a4a1f4a42828199 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Fri, 12 Jun 2020 11:38:38 +0800 Subject: [PATCH] fix cos_sim, test=develop (#25017) --- paddle/fluid/operators/math/cos_sim_functor.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/paddle/fluid/operators/math/cos_sim_functor.h b/paddle/fluid/operators/math/cos_sim_functor.h index 30ea5e60e87..d74662e68e7 100644 --- a/paddle/fluid/operators/math/cos_sim_functor.h +++ b/paddle/fluid/operators/math/cos_sim_functor.h @@ -35,6 +35,7 @@ struct CosSimFunctor { inline HOSTDEVICE void operator()(size_t row_id) const { auto* x = x_ + cols_ * row_id; T xx = 0, xy = 0, yy = 0; + T eps = 1e-8; if (same_row) { auto* y = y_ + cols_ * row_id; T tep_x, tep_y; @@ -45,6 +46,8 @@ struct CosSimFunctor { yy += tep_y * tep_y; xy += tep_x * tep_y; } + xx = xx > eps ? xx : eps; + yy = yy > eps ? yy : eps; xx = sqrt(xx); yy = sqrt(yy); y_norm_[row_id] = yy; @@ -59,6 +62,8 @@ struct CosSimFunctor { yy += tep_y * tep_y; xy += tep_x * tep_y; } + xx = xx > eps ? xx : eps; + yy = yy > eps ? yy : eps; xx = sqrt(xx); yy = sqrt(yy); if (row_id == 0) y_norm_[0] = yy; -- GitLab