From de58898f963de12cce836b934bbdc69aecbb64c2 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Fri, 28 Dec 2018 03:39:54 +0000 Subject: [PATCH] polish code and fix python3 error --- fluid/PaddleRec/word2vec/infer.py | 39 +++++++------------------------ 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/fluid/PaddleRec/word2vec/infer.py b/fluid/PaddleRec/word2vec/infer.py index 57875610..9ed42d1c 100644 --- a/fluid/PaddleRec/word2vec/infer.py +++ b/fluid/PaddleRec/word2vec/infer.py @@ -3,6 +3,7 @@ import os import paddle.fluid as fluid import numpy as np from Queue import PriorityQueue +import heapq import logging import argparse import preprocess @@ -196,27 +197,12 @@ def topK(k, cosine_similarity_list, exclude_list, is_acc=False): pass return [[max, id]] else: - pq = PriorityQueue(k + 1) - while not pq.empty(): - try: - pq.get(False) - except Empty: - continue - pq.task_done() - if len(cosine_similarity_list) <= k: - for i in range(len(cosine_similarity_list)): - pq.put([cosine_similarity_list[i], i]) - return pq - - for i in range(len(cosine_similarity_list)): - if is_acc and (i in exclude_list): - pass - else: - if pq.full(): - pq.get() - pq.put([cosine_similarity_list[i], i]) - pq.get() - return pq + result = list() + result_index = np.argpartition(cosine_similarity_list, -k)[-k:] + for index in result_index: + result.append([cosine_similarity_list[index], index]) + result.sort(reverse=True) + return result def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False): @@ -233,16 +219,7 @@ def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False): tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0], is_acc) result_queues.append(tmp_pq) - if is_acc and k == 1: # accelerate acc calculate - return result_queues - else: - for pq in result_queues: - tmp_result = list() - for i in range(k): - tmp_result.append(pq.get()) - tmp_result.reverse() - results.append(tmp_result) - return results + return result_queues def infer_during_train(args): -- GitLab