Commit 9eaab43a authored by JiabinYang

accelerate infer

Parent af6aebe8
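This commit speeds up the word2vec analogy inference path: instead of scoring each test case against every word with repeated cosine_similarity() calls ranked through a PQ_Entry priority queue, it L2-normalizes the embedding table and the test-case vectors once (the new norm() helper) and computes the whole cosine-similarity matrix with a single np.dot(test_cases, x.T). A minimal sketch of that idea, using illustrative names and random data rather than code from this repo:

import numpy as np

def l2_normalize(x):
    # Divide each row by its L2 norm, mirroring the norm() helper
    # this commit adds to the inference script.
    return x / np.linalg.norm(x, axis=1, keepdims=True)

vocab = l2_normalize(np.random.rand(1000, 64))  # word embedding table
queries = l2_normalize(np.random.rand(5, 64))   # analogy test vectors

# For unit-length rows, the dot product equals cosine similarity, so one
# BLAS-backed matmul replaces 5 * 1000 separate similarity calls.
scores = np.dot(queries, vocab.T)               # shape (5, 1000)
nearest = scores.argmax(axis=1)                 # best-matching word id per query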
@@ -51,8 +51,8 @@ def parse_args():
         '--test_acc',
         action='store_true',
         required=False,
-        default=True,
-        help='if using test_files , (default: True)')
+        default=False,
+        help='if using test_files , (default: False)')
     parser.add_argument(
         '--test_files_dir',
         type=str,
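Note the flipped default: accuracy evaluation against the analogy test files is now opt-in (--test_acc must be passed explicitly) rather than running by default.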
@@ -85,6 +85,7 @@ def build_test_case_from_file(args, emb):
     logger.info("test files list: {}".format(current_list))
     test_cases = list()
     test_labels = list()
+    test_case_descs = list()
     exclude_lists = list()
     for file_dir in current_list:
         with open(args.test_files_dir + "/" + file_dir, 'r') as f:
@@ -102,14 +103,16 @@
                             line.split()[2]]]
                     test_case_desc = line.split()[0] + " - " + line.split()[
                         1] + " + " + line.split()[2] + " = " + line.split()[3]
-                    test_cases.append([test_case, test_case_desc])
+                    test_cases.append(test_case)
+                    test_case_descs.append(test_case_desc)
                     test_labels.append(word_to_id[line.split()[3]])
                     exclude_lists.append([
                         word_to_id[line.split()[0]],
                         word_to_id[line.split()[1]], word_to_id[line.split()[2]]
                     ])
                 count += 1
-    return test_cases, test_labels, exclude_lists
+    test_cases = norm(np.array(test_cases))
+    return test_cases, test_case_descs, test_labels, exclude_lists
 
 
 def build_small_test_case(emb):
@@ -133,8 +136,11 @@ def build_small_test_case(emb):
         'deeper']]
     desc5 = "old - older + deeper = deep"
     label5 = word_to_id["deep"]
-    return [[emb1, desc1], [emb2, desc2], [emb3, desc3], [emb4, desc4],
-            [emb5, desc5]], [label1, label2, label3, label4, label5]
+    test_cases = [emb1, emb2, emb3, emb4, emb5]
+    test_case_desc = [desc1, desc2, desc3, desc4, desc5]
+    test_labels = [label1, label2, label3, label4, label5]
+    return norm(np.array(test_cases)), test_case_desc, test_labels
 
 
 def build_test_case(args, emb):
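Note the new contract shared by both test-case builders (the file-based one above and the small built-in set): they now return an L2-normalized (n, dim) matrix of test-case vectors plus parallel lists of human-readable descriptions and label ids, instead of packing (embedding, description) pairs into a single list. Normalizing once at build time is what lets inference_test below reduce scoring to a single matrix product.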
@@ -144,86 +150,105 @@ def build_test_case(args, emb):
         return build_small_test_case(emb)
 
 
+def norm(x):
+    emb = np.linalg.norm(x, axis=1, keepdims=True)
+    return x / emb
+
+
 def inference_test(scope, model_dir, args):
     BuildWord_IdMap(args.dict_path)
     logger.info("model_dir is: {}".format(model_dir + "/"))
     emb = np.array(scope.find_var("embeding").get_tensor())
+    x = norm(emb)
     logger.info("inference result: ====================")
-    test_cases = list()
+    test_cases = None
+    test_case_desc = list()
     test_labels = list()
     exclude_lists = list()
     if args.test_acc:
-        test_cases, test_labels, exclude_lists = build_test_case(args, emb)
+        test_cases, test_case_desc, test_labels, exclude_lists = build_test_case(
+            args, emb)
     else:
-        test_cases, test_labels = build_test_case(args, emb)
+        test_cases, test_case_desc, test_labels = build_test_case(args, emb)
         exclude_lists = [[-1]]
     accual_rank = 1 if args.test_acc else args.rank_num
     correct_num = 0
+    cosine_similarity_matrix = np.dot(test_cases, x.T)
+    results = topKs(accual_rank, cosine_similarity_matrix, exclude_lists,
+                    args.test_acc)
     for i in range(len(test_labels)):
-        pq = None
-        if args.test_acc:
-            pq = topK(
-                accual_rank,
-                emb,
-                test_cases[i][0],
-                exclude_lists[i],
-                is_acc=True)
-        else:
-            pq = pq = topK(
-                accual_rank,
-                emb,
-                test_cases[i][0],
-                exclude_lists[0],
-                is_acc=False)
-        logger.info("Test result for {}".format(test_cases[i][1]))
+        logger.info("Test result for {}".format(test_case_desc[i]))
+        result = results[i]
         for j in range(accual_rank):
-            pq_tmps = pq.get()
             if (j == accual_rank - 1) and (
-                    pq_tmps.id == test_labels[i]
+                    result[j][1] == test_labels[i]
             ):  # if the nearest word is what we want
                 correct_num += 1
-            logger.info("{} nearest is {}, rate is {}".format(
-                accual_rank - j, id_to_word[pq_tmps.id], pq_tmps.priority))
-    acc = correct_num / len(test_labels)
-    logger.info("Test acc is: {}, there are {} / {}}".format(acc, correct_num,
-                                                             len(test_labels)))
-
-
-class PQ_Entry(object):
-    def __init__(self, cos_similarity, id):
-        self.priority = cos_similarity
-        self.id = id
-
-    def __cmp__(self, other):
-        return cmp(self.priority, other.priority)
-
-
-def topK(k, emb, test_emb, exclude_list, is_acc=False):
-    pq = PriorityQueue(k + 1)
-    while not pq.empty():
-        try:
-            pq.get(False)
-        except Empty:
-            continue
-        pq.task_done()
-
-    if len(emb) <= k:
-        for i in range(len(emb)):
-            x = cosine_similarity([emb[i]], [test_emb])
-            pq.put(PQ_Entry(x, i))
+            logger.info("{} nearest is {}, rate is {}".format(j, id_to_word[
+                result[j][1]], result[j][0]))
+    logger.info("Test acc is: {}, there are {} / {}".format(correct_num / len(
+        test_labels), correct_num, len(test_labels)))
+
+
+def topK(k, cosine_similarity_list, exclude_list, is_acc=False):
+    if k == 1 and is_acc:  # accelerate acc calculate
+        max = cosine_similarity_list[0]
+        id = 0
+        for i in range(len(cosine_similarity_list)):
+            if cosine_similarity_list[i] >= max and (i not in exclude_list):
+                max = cosine_similarity_list[i]
+                id = i
+            else:
+                pass
+        return [[max, id]]
+    else:
+        pq = PriorityQueue(k + 1)
+        while not pq.empty():
+            try:
+                pq.get(False)
+            except Empty:
+                continue
+            pq.task_done()
+        if len(cosine_similarity_list) <= k:
+            for i in range(len(cosine_similarity_list)):
+                pq.put([cosine_similarity_list[i], i])
+            return pq
+        for i in range(len(cosine_similarity_list)):
+            if is_acc and (i in exclude_list):
+                pass
+            else:
+                if pq.full():
+                    pq.get()
+                pq.put([cosine_similarity_list[i], i])
+        pq.get()
+        return pq
 
-    for i in range(len(emb)):
-        if is_acc and (i in exclude_list):
-            pass
+
+def topKs(k, cosine_similarity_matrix, exclude_lists, is_acc=False):
+    results = list()
+    result_queues = list()
+    correct_num = 0
+    for i in range(cosine_similarity_matrix.shape[0]):
+        tmp_pq = None
+        if is_acc:
+            tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[i],
+                          is_acc)
         else:
-            x = cosine_similarity([emb[i]], [test_emb])
-            pq_e = PQ_Entry(x, i)
-            if pq.full():
-                pq.get()
-            pq.put(pq_e)
-
-    pq.get()
-    return pq
+            tmp_pq = topK(k, cosine_similarity_matrix[i], exclude_lists[0],
+                          is_acc)
+        result_queues.append(tmp_pq)
+    if is_acc and k == 1:  # accelerate acc calculate
+        return result_queues
+    else:
+        for pq in result_queues:
+            tmp_result = list()
+            for i in range(k):
+                tmp_result.append(pq.get())
+            tmp_result.reverse()
+            results.append(tmp_result)
+        return results
 
 
 def infer_during_train(args):
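The rewritten topK() takes a precomputed similarity row instead of raw embeddings. For k == 1 in accuracy mode, the common case, it skips the priority queue entirely and does one linear scan for the best non-excluded index; larger k still uses a bounded PriorityQueue of size k + 1. topKs() applies this row by row over the similarity matrix. A rough standalone sketch of the fast path, assuming a 1-D numpy score row (illustrative code, not from this repo):

import numpy as np

def best_match(scores, exclude):
    # One linear pass over a similarity row, skipping excluded word ids --
    # the same shape as the k == 1 branch of the new topK().
    best_score, best_id = float("-inf"), -1
    for i, s in enumerate(scores):
        if i not in exclude and s >= best_score:
            best_score, best_id = float(s), i
    return [[best_score, best_id]]

print(best_match(np.array([0.1, 0.9, 0.4]), {1}))  # -> [[0.4, 2]]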
@@ -222,19 +222,6 @@ def preprocess(args):
                     word_count[item] = word_count[item] + 1
                 else:
                     word_count[item] = 1
-    # with open(args.data_path + "/tmp.txt") as f:
-    #     for line in f:
-    #         print("line before strip is: {}".format(line))
-    #         line = strip_lines(line, word_count)
-    #         print("line after strip is: {}".format(line))
-    #         words = line.split()
-    #         print("words after split is: {}".format(words))
-    #         for item in words:
-    #             if item in word_count:
-    #                 word_count[item] = word_count[item] + 1
-    #             else:
-    #                 word_count[item] = 1
     item_to_remove = []
     for item in word_count:
         if word_count[item] <= args.freq: