diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index e90db0b67d78fa324f862e9c3ee0535960387eea..cc5ff8de5e99dcdffdcb2a501fa720fd1d4861ec 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -680,9 +680,9 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8): # [0.99806249 0.9817672 0.94987036] """ - w12 = sum(elementwise_mul(x1, x2), dim=dim) - w1 = sum(elementwise_mul(x1, x1), dim=dim) - w2 = sum(elementwise_mul(x2, x2), dim=dim) + w12 = sum(elementwise_mul(x1, x2), axis=dim) + w1 = sum(elementwise_mul(x1, x1), axis=dim) + w2 = sum(elementwise_mul(x2, x2), axis=dim) n12 = sqrt(clamp(w1 * w2, min=eps * eps)) cos_sim = w12 / n12 return cos_sim