Commit 9eaab43a authored by: J JiabinYang

accelerate infer

Parent af6aebe8
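The heart of this change: the vocabulary embedding matrix and the analogy query vectors are both L2-normalized up front, so every cosine similarity comes out of a single matrix product (`np.dot(test_cases, x.T)`) instead of one `cosine_similarity` call per word pair. A minimal NumPy sketch of that equivalence, with hypothetical shapes and data (not code from this repo):

```python
import numpy as np

def norm(x):
    # Same idea as the norm() helper this commit adds:
    # divide each row by its L2 norm.
    return x / np.linalg.norm(x, axis=1, keepdims=True)

rng = np.random.default_rng(0)
emb = rng.normal(size=(1000, 64))   # hypothetical vocabulary embeddings
queries = rng.normal(size=(5, 64))  # hypothetical "a - b + c" analogy vectors

# One BLAS-backed multiply yields all (5, 1000) cosine similarities.
sims = np.dot(norm(queries), norm(emb).T)

# Spot-check one pair against the direct definition.
i, j = 2, 17
direct = np.dot(queries[i], emb[j]) / (
    np.linalg.norm(queries[i]) * np.linalg.norm(emb[j]))
assert np.allclose(sims[i, j], direct)
```

With N test cases and a vocabulary of V words, this replaces N x V Python-level similarity calls with one matrix multiply, which is where the speedup in this commit comes from.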
@@ -51,8 +51,8 @@ def parse_args():
         '--test_acc',
         action='store_true',
         required=False,
-        default=True,
-        help='if using test_files , (default: True)')
+        default=False,
+        help='if using test_files , (default: False)')
     parser.add_argument(
         '--test_files_dir',
         type=str,
@@ -85,6 +85,7 @@ def build_test_case_from_file(args, emb):
     logger.info("test files list: {}".format(current_list))
     test_cases = list()
     test_labels = list()
+    test_case_descs = list()
     exclude_lists = list()
     for file_dir in current_list:
         with open(args.test_files_dir + "/" + file_dir, 'r') as f:
@@ -102,14 +103,16 @@ def build_test_case_from_file(args, emb):
                     line.split()[2]]]
                 test_case_desc = line.split()[0] + " - " + line.split()[
                     1] + " + " + line.split()[2] + " = " + line.split()[3]
-                test_cases.append([test_case, test_case_desc])
+                test_cases.append(test_case)
+                test_case_descs.append(test_case_desc)
                 test_labels.append(word_to_id[line.split()[3]])
                 exclude_lists.append([
                     word_to_id[line.split()[0]],
                     word_to_id[line.split()[1]], word_to_id[line.split()[2]]
                 ])
                 count += 1
-    return test_cases, test_labels, exclude_lists
+    test_cases = norm(np.array(test_cases))
+    return test_cases, test_case_descs, test_labels, exclude_lists


 def build_small_test_case(emb):
@@ -133,8 +136,11 @@ def build_small_test_case(emb):
         'deeper']]
     desc5 = "old - older + deeper = deep"
     label5 = word_to_id["deep"]
-    return [[emb1, desc1], [emb2, desc2], [emb3, desc3], [emb4, desc4],
-            [emb5, desc5]], [label1, label2, label3, label4, label5]
+    test_cases = [emb1, emb2, emb3, emb4, emb5]
+    test_case_desc = [desc1, desc2, desc3, desc4, desc5]
+    test_labels = [label1, label2, label3, label4, label5]
+    return norm(np.array(test_cases)), test_case_desc, test_labels


 def build_test_case(args, emb):
@@ -144,86 +150,105 @@ def build_test_case(args, emb):
         return build_small_test_case(emb)


+def norm(x):
+    emb = np.linalg.norm(x, axis=1, keepdims=True)
+    return x / emb
+
+
 def inference_test(scope, model_dir, args):
     BuildWord_IdMap(args.dict_path)
     logger.info("model_dir is: {}".format(model_dir + "/"))
     emb = np.array(scope.find_var("embeding").get_tensor())
+    x = norm(emb)
     logger.info("inference result: ====================")
-    test_cases = list()
+    test_cases = None
+    test_case_desc = list()
     test_labels = list()
     exclude_lists = list()
     if args.test_acc:
-        test_cases, test_labels, exclude_lists = build_test_case(args, emb)
+        test_cases, test_case_desc, test_labels, exclude_lists = build_test_case(
+            args, emb)
     else:
-        test_cases, test_labels = build_test_case(args, emb)
+        test_cases, test_case_desc, test_labels = build_test_case(args, emb)
         exclude_lists = [[-1]]
     accual_rank = 1 if args.test_acc else args.rank_num
     correct_num = 0
+    cosine_similarity_matrix = np.dot(test_cases, x.T)
+    results = topKs(accual_rank, cosine_similarity_matrix, exclude_lists,
+                    args.test_acc)
     for i in range(len(test_labels)):
-        pq = None
-        if args.test_acc:
-            pq = topK(
-                accual_rank,
-                emb,
-                test_cases[i][0],
-                exclude_lists[i],
-                is_acc=True)
-        else:
-            pq = pq = topK(
-                accual_rank,
-                emb,
-                test_cases[i][0],
-                exclude_lists[0],
-                is_acc=False)
-        logger.info("Test result for {}".format(test_cases[i][1]))
+        logger.info("Test result for {}".format(test_case_desc[i]))
+        result = results[i]
         for j in range(accual_rank):
-            pq_tmps = pq.get()
             if (j == accual_rank - 1) and (
-                    pq_tmps.id == test_labels[i]
+                    result[j][1] == test_labels[i]
             ):  # if the nearest word is what we want
                 correct_num += 1
-            logger.info("{} nearest is {}, rate is {}".format(
-                accual_rank - j, id_to_word[pq_tmps.id], pq_tmps.priority))
-    acc = correct_num / len(test_labels)
-    logger.info("Test acc is: {}, there are {} / {}}".format(acc, correct_num,
-                                                             len(test_labels)))
+            logger.info("{} nearest is {}, rate is {}".format(j, id_to_word[
+                result[j][1]], result[j][0]))
+    logger.info("Test acc is: {}, there are {} / {}".format(correct_num / len(
+        test_labels), correct_num, len(test_labels)))


-class PQ_Entry(object):
-    def __init__(self, cos_similarity, id):
-        self.priority = cos_similarity
-        self.id = id
-
-    def __cmp__(self, other):
-        return cmp(self.priority, other.priority)
-
-
-def topK(k, emb, test_emb, exclude_list, is_acc=False):
-    pq = PriorityQueue(k + 1)
-    while not pq.empty():
-        try:
-            pq.get(False)
-        except Empty:
-            continue
-        pq.task_done()
-
-    if len(emb) <= k:
-        for i in range(len(emb)):
-            x = cosine_similarity([emb[i]], [test_emb])
-            pq.put(PQ_Entry(x, i))
-        return pq
-
-    for i in range(len(emb)):
-        if is_acc and (i in exclude_list):
-            pass
-        else:
-            x = cosine_similarity([emb[i]], [test_emb])
-            pq_e = PQ_Entry(x, i)
-            if pq.full():
-                pq.get()
-            pq.put(pq_e)
-    pq.get()
-    return pq
+def topK(k, cosine_similarity_list, exclude_list, is_acc=False):
+    if k == 1 and is_acc:  # accelerate acc calculate
+        max = cosine_similarity_list[0]
+        id = 0
+        for i in range(len(cosine_similarity_list)):
+            if cosine_similarity_list[i] >= max and (i not in exclude_list):
+                max = cosine_similarity_list[i]
+                id = i
+            else:
+                pass
+        return [[max, id]]
+    else:
+        pq = PriorityQueue(k + 1)
+        while not pq.empty():
+            try:
+                pq.get(False)
+            except Empty:
+                continue
+            pq.task_done()
+        if len(cosine_similarity_list) <= k:
+            for i in range(len(cosine_similarity_list)):
+                pq.put([cosine_similarity_list[i], i])
+            return pq
+        for i in range(len(cosine_similarity_list)):
+            if is_acc and (i in exclude_list):
+                pass
+            else:
+                if pq.full():
+                    pq.get()
+                pq.put([cosine_similarity_list[i], i])
+        pq.get()
+        return pq
+
+
+def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False):
+    results = list()
+    result_queues = list()
+    correct_num = 0
+
+    for i in range(cosine_similarity_matrix.shape[0]):
+        tmp_pq = None
+        if is_acc:
+            tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[i],
+                          is_acc)
+        else:
+            tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0],
+                          is_acc)
+        result_queues.append(tmp_pq)
+    if is_acc and k == 1:  # accelerate acc calculate
+        return result_queues
+    else:
+        for pq in result_queues:
+            tmp_result = list()
+            for i in range(k):
+                tmp_result.append(pq.get())
+            tmp_result.reverse()
+            results.append(tmp_result)
+        return results


 def infer_during_train(args):
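For context on the two ranking paths above: when only accuracy is measured (`k == 1` with `is_acc`), the new `topK` skips the priority queue entirely and does one linear scan over the precomputed similarities, excluding the analogy's own three words; for larger `k` it keeps a queue bounded at `k` entries of `[similarity, id]`. A self-contained sketch of both paths, with assumed names (`top1`, `topk`) and toy data, not the repo's code:

```python
from queue import PriorityQueue

def top1(similarities, exclude):
    # Fast path: one pass over precomputed similarities, skipping excluded ids.
    best_sim, best_id = float("-inf"), -1
    for i, s in enumerate(similarities):
        if s >= best_sim and i not in exclude:
            best_sim, best_id = s, i
    return [(best_sim, best_id)]

def topk(similarities, k):
    # General path: a queue bounded at k keeps the k largest (similarity, id)
    # pairs; PriorityQueue pops its smallest element first, so each overflow
    # discards the current worst candidate.
    pq = PriorityQueue()
    for i, s in enumerate(similarities):
        pq.put((s, i))
        if pq.qsize() > k:
            pq.get()  # discard the current minimum
    out = [pq.get() for _ in range(k)]
    out.reverse()  # most similar first
    return out

sims = [0.1, 0.9, 0.4, 0.8]
print(top1(sims, exclude={1}))  # [(0.8, 3)], since index 1 is excluded
print(topk(sims, 2))            # [(0.9, 1), (0.8, 3)]
```

One design note: the commit's fast path seeds its running maximum with index 0 even when index 0 is in the exclude list, so an excluded word that happens to be the global maximum can still be returned; the sketch above seeds with negative infinity to avoid that edge case.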
@@ -222,19 +222,6 @@ def preprocess(args):
                 word_count[item] = word_count[item] + 1
             else:
                 word_count[item] = 1
-    # with open(args.data_path + "/tmp.txt") as f:
-    #     for line in f:
-    #         print("line before strip is: {}".format(line))
-    #         line = strip_lines(line, word_count)
-    #         print("line after strip is: {}".format(line))
-    #         words = line.split()
-    #         print("words after split is: {}".format(words))
-    #         for item in words:
-    #             if item in word_count:
-    #                 word_count[item] = word_count[item] + 1
-    #             else:
-    #                 word_count[item] = 1
     item_to_remove = []
     for item in word_count:
         if word_count[item] <= args.freq: