未验证 提交 8db66fc3 编写于 作者: C ceci3 提交者: GitHub

fix cos_sim, test=develop (#25017)

上级 7a6f4d64
...@@ -35,6 +35,7 @@ struct CosSimFunctor { ...@@ -35,6 +35,7 @@ struct CosSimFunctor {
inline HOSTDEVICE void operator()(size_t row_id) const { inline HOSTDEVICE void operator()(size_t row_id) const {
auto* x = x_ + cols_ * row_id; auto* x = x_ + cols_ * row_id;
T xx = 0, xy = 0, yy = 0; T xx = 0, xy = 0, yy = 0;
T eps = 1e-8;
if (same_row) { if (same_row) {
auto* y = y_ + cols_ * row_id; auto* y = y_ + cols_ * row_id;
T tep_x, tep_y; T tep_x, tep_y;
...@@ -45,6 +46,8 @@ struct CosSimFunctor { ...@@ -45,6 +46,8 @@ struct CosSimFunctor {
yy += tep_y * tep_y; yy += tep_y * tep_y;
xy += tep_x * tep_y; xy += tep_x * tep_y;
} }
xx = xx > eps ? xx : eps;
yy = yy > eps ? yy : eps;
xx = sqrt(xx); xx = sqrt(xx);
yy = sqrt(yy); yy = sqrt(yy);
y_norm_[row_id] = yy; y_norm_[row_id] = yy;
...@@ -59,6 +62,8 @@ struct CosSimFunctor { ...@@ -59,6 +62,8 @@ struct CosSimFunctor {
yy += tep_y * tep_y; yy += tep_y * tep_y;
xy += tep_x * tep_y; xy += tep_x * tep_y;
} }
xx = xx > eps ? xx : eps;
yy = yy > eps ? yy : eps;
xx = sqrt(xx); xx = sqrt(xx);
yy = sqrt(yy); yy = sqrt(yy);
if (row_id == 0) y_norm_[0] = yy; if (row_id == 0) y_norm_[0] = yy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册