diff --git a/PaddleCV/metric_learning/utility.py b/PaddleCV/metric_learning/utility.py index f093f56d2dfa9a0ce147a525c143813d5ed59942..d45715900942f2b817116e4f3fc1faaa4f73d82b 100644 --- a/PaddleCV/metric_learning/utility.py +++ b/PaddleCV/metric_learning/utility.py @@ -74,9 +74,11 @@ def fmt_time(): now_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) return now_str -def recall_topk(fea, lab, k = 1): +def recall_topk(fea, lab, k=1): fea = np.array(fea) + lab = np.array(lab) fea = fea.reshape(fea.shape[0], -1) + lab = lab.reshape(-1) n = np.sqrt(np.sum(fea**2, 1)).reshape(-1, 1) fea = fea / n a = np.sum(fea ** 2, 1).reshape(-1, 1) @@ -87,8 +89,8 @@ def recall_topk(fea, lab, k = 1): sorted_index = np.argsort(d, 1) res = 0 for i in range(len(fea)): - pred = lab[sorted_index[i][0]] - if lab[i] == pred: + pred = lab[sorted_index[i][:k]] + if lab[i] in pred: res += 1.0 res = res / len(fea) return res