未验证 提交 8db66fc3 编写于 作者: C ceci3 提交者: GitHub

fix cos_sim, test=develop (#25017)

上级 7a6f4d64
......@@ -35,6 +35,7 @@ struct CosSimFunctor {
inline HOSTDEVICE void operator()(size_t row_id) const {
auto* x = x_ + cols_ * row_id;
T xx = 0, xy = 0, yy = 0;
T eps = 1e-8;
if (same_row) {
auto* y = y_ + cols_ * row_id;
T tep_x, tep_y;
......@@ -45,6 +46,8 @@ struct CosSimFunctor {
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = xx > eps ? xx : eps;
yy = yy > eps ? yy : eps;
xx = sqrt(xx);
yy = sqrt(yy);
y_norm_[row_id] = yy;
......@@ -59,6 +62,8 @@ struct CosSimFunctor {
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = xx > eps ? xx : eps;
yy = yy > eps ? yy : eps;
xx = sqrt(xx);
yy = sqrt(yy);
if (row_id == 0) y_norm_[0] = yy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册