From 6bb873a7fd38be08d9c53a77319412c19fc3550f Mon Sep 17 00:00:00 2001 From: kebinC <1678073914@qq.com> Date: Thu, 22 Aug 2019 16:43:33 +0800 Subject: [PATCH] metric learning fic utility.py --- PaddleCV/metric_learning/utility.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/PaddleCV/metric_learning/utility.py b/PaddleCV/metric_learning/utility.py index f093f56d..d4571590 100644 --- a/PaddleCV/metric_learning/utility.py +++ b/PaddleCV/metric_learning/utility.py @@ -74,9 +74,11 @@ def fmt_time(): now_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) return now_str -def recall_topk(fea, lab, k = 1): +def recall_topk(fea, lab, k=1): fea = np.array(fea) + lab = np.array(lab) fea = fea.reshape(fea.shape[0], -1) + lab = lab.reshape(-1) n = np.sqrt(np.sum(fea**2, 1)).reshape(-1, 1) fea = fea / n a = np.sum(fea ** 2, 1).reshape(-1, 1) @@ -87,8 +89,8 @@ def recall_topk(fea, lab, k = 1): sorted_index = np.argsort(d, 1) res = 0 for i in range(len(fea)): - pred = lab[sorted_index[i][0]] - if lab[i] == pred: + pred = lab[sorted_index[i][:k]] + if lab[i] in pred: res += 1.0 res = res / len(fea) return res -- GitLab