提交 de58898f 编写于 作者: J JiabinYang

polish code and fix python3 error

上级 ac301305
......@@ -3,6 +3,7 @@ import os
import paddle.fluid as fluid
import numpy as np
from Queue import PriorityQueue
import heapq
import logging
import argparse
import preprocess
......@@ -196,27 +197,12 @@ def topK(k, cosine_similarity_list, exclude_list, is_acc=False):
pass
return [[max, id]]
else:
pq = PriorityQueue(k + 1)
while not pq.empty():
try:
pq.get(False)
except Empty:
continue
pq.task_done()
if len(cosine_similarity_list) <= k:
for i in range(len(cosine_similarity_list)):
pq.put([cosine_similarity_list[i], i])
return pq
for i in range(len(cosine_similarity_list)):
if is_acc and (i in exclude_list):
pass
else:
if pq.full():
pq.get()
pq.put([cosine_similarity_list[i], i])
pq.get()
return pq
result = list()
result_index = np.argpartition(cosine_similarity_list, -k)[-k:]
for index in result_index:
result.append([cosine_similarity_list[index], index])
result.sort(reverse=True)
return result
def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False):
......@@ -233,16 +219,7 @@ def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False):
tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0],
is_acc)
result_queues.append(tmp_pq)
if is_acc and k == 1: # accelerate acc calculate
return result_queues
else:
for pq in result_queues:
tmp_result = list()
for i in range(k):
tmp_result.append(pq.get())
tmp_result.reverse()
results.append(tmp_result)
return results
def infer_during_train(args):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册