未验证 提交 7e71ae92 编写于 作者: C Chen Weihang 提交者: GitHub

fix test_cosine_similarity_api failed (#26467)

上级 a7cd61fd
...@@ -680,9 +680,9 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8): ...@@ -680,9 +680,9 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
# [0.99806249 0.9817672 0.94987036] # [0.99806249 0.9817672 0.94987036]
""" """
w12 = sum(elementwise_mul(x1, x2), dim=dim) w12 = sum(elementwise_mul(x1, x2), axis=dim)
w1 = sum(elementwise_mul(x1, x1), dim=dim) w1 = sum(elementwise_mul(x1, x1), axis=dim)
w2 = sum(elementwise_mul(x2, x2), dim=dim) w2 = sum(elementwise_mul(x2, x2), axis=dim)
n12 = sqrt(clamp(w1 * w2, min=eps * eps)) n12 = sqrt(clamp(w1 * w2, min=eps * eps))
cos_sim = w12 / n12 cos_sim = w12 / n12
return cos_sim return cos_sim
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册