提交 de58898f 编写于 作者: J JiabinYang

polish code and fix python3 error

上级 ac301305
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
from Queue import PriorityQueue from Queue import PriorityQueue
import heapq
import logging import logging
import argparse import argparse
import preprocess import preprocess
...@@ -196,27 +197,12 @@ def topK(k, cosine_similarity_list, exclude_list, is_acc=False): ...@@ -196,27 +197,12 @@ def topK(k, cosine_similarity_list, exclude_list, is_acc=False):
pass pass
return [[max, id]] return [[max, id]]
else: else:
pq = PriorityQueue(k + 1) result = list()
while not pq.empty(): result_index = np.argpartition(cosine_similarity_list, -k)[-k:]
try: for index in result_index:
pq.get(False) result.append([cosine_similarity_list[index], index])
except Empty: result.sort(reverse=True)
continue return result
pq.task_done()
if len(cosine_similarity_list) <= k:
for i in range(len(cosine_similarity_list)):
pq.put([cosine_similarity_list[i], i])
return pq
for i in range(len(cosine_similarity_list)):
if is_acc and (i in exclude_list):
pass
else:
if pq.full():
pq.get()
pq.put([cosine_similarity_list[i], i])
pq.get()
return pq
def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False): def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False):
...@@ -233,16 +219,7 @@ def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False): ...@@ -233,16 +219,7 @@ def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False):
tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0], tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0],
is_acc) is_acc)
result_queues.append(tmp_pq) result_queues.append(tmp_pq)
if is_acc and k == 1: # accelerate acc calculate
return result_queues return result_queues
else:
for pq in result_queues:
tmp_result = list()
for i in range(k):
tmp_result.append(pq.get())
tmp_result.reverse()
results.append(tmp_result)
return results
def infer_during_train(args): def infer_during_train(args):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册