提交 18c46eb7 编写于 作者: J JiabinYang

add feature to use third_party vocab and add acc test

上级 d480a0d1
......@@ -61,6 +61,11 @@ sh cluster_train.sh
您也可以在`build_test_case`方法中模仿给出的例子增加自己的测试
要从测试文件运行测试用例,请将测试文件下载到“test”目录中
我们为每个案例提供以下结构的测试:
`word1 word2 word3 word4`
所以我们可以将它构建成`word1 - word2 + word3 = word4`
训练中预测:
```bash
......
......@@ -65,6 +65,11 @@ For: boy - girl + aunt = uncle
You can also add your own tests by mimicking the examples given in the `build_test_case` method.
To running test case from test files, please download the test files into 'test' directory
we provide test for each case with the following structure:
`word1 word2 word3 word4`
so we can build it into `word1 - word2 + word3 = word4`
Forecast in training:
```bash
......
......@@ -2,4 +2,3 @@
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar -zxvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
import paddle
import time
import os
import paddle.fluid as fluid
......@@ -6,6 +5,7 @@ import numpy as np
from Queue import PriorityQueue
import logging
import argparse
import preprocess
from sklearn.metrics.pairwise import cosine_similarity
word_to_id = dict()
......@@ -47,6 +47,22 @@ def parse_args():
required=False,
default=True,
help='if using infer_during_train, (default: True)')
parser.add_argument(
'--test_acc',
action='store_true',
required=False,
default=True,
help='if using test_files , (default: True)')
parser.add_argument(
'--test_files_dir',
type=str,
default='test',
help="The path for test_files) (default: test)")
parser.add_argument(
'--test_batch_size',
type=int,
default=1000,
help="test used batch size (default: 1000)")
return parser.parse_args()
......@@ -58,48 +74,119 @@ def BuildWord_IdMap(dict_path):
id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
def inference_prog():
def inference_prog(): # just to create program for test
fluid.layers.create_parameter(
shape=[1, 1], dtype='float32', name="embeding")
def build_test_case(emb):
def build_test_case_from_file(args, emb):
logger.info("test files dir: {}".format(args.test_files_dir))
current_list = os.listdir(args.test_files_dir)
logger.info("test files list: {}".format(current_list))
test_cases = list()
test_labels = list()
exclude_lists = list()
for file_dir in current_list:
with open(args.test_files_dir + "/" + file_dir, 'r') as f:
count = 0
for line in f:
if count == 0:
pass
elif ':' in line:
logger.info("{}".format(line))
pass
else:
line = preprocess.strip_lines(line, word_to_id)
test_case = emb[word_to_id[line.split()[0]]] - emb[
word_to_id[line.split()[1]]] + emb[word_to_id[
line.split()[2]]]
test_case_desc = line.split()[0] + " - " + line.split()[
1] + " + " + line.split()[2] + " = " + line.split()[3]
test_cases.append([test_case, test_case_desc])
test_labels.append(word_to_id[line.split()[3]])
exclude_lists.append([
word_to_id[line.split()[0]],
word_to_id[line.split()[1]], word_to_id[line.split()[2]]
])
count += 1
return test_cases, test_labels, exclude_lists
def build_small_test_case(emb):
emb1 = emb[word_to_id['boy']] - emb[word_to_id['girl']] + emb[word_to_id[
'aunt']]
desc1 = "boy - girl + aunt = uncle"
label1 = word_to_id["uncle"]
emb2 = emb[word_to_id['brother']] - emb[word_to_id['sister']] + emb[
word_to_id['sisters']]
desc2 = "brother - sister + sisters = brothers"
label2 = word_to_id["brothers"]
emb3 = emb[word_to_id['king']] - emb[word_to_id['queen']] + emb[word_to_id[
'woman']]
desc3 = "king - queen + woman = man"
label3 = word_to_id["man"]
emb4 = emb[word_to_id['reluctant']] - emb[word_to_id['reluctantly']] + emb[
word_to_id['slowly']]
desc4 = "reluctant - reluctantly + slowly = slow"
label4 = word_to_id["slow"]
emb5 = emb[word_to_id['old']] - emb[word_to_id['older']] + emb[word_to_id[
'deeper']]
desc5 = "old - older + deeper = deep"
label5 = word_to_id["deep"]
return [[emb1, desc1], [emb2, desc2], [emb3, desc3], [emb4, desc4],
[emb5, desc5]]
[emb5, desc5]], [label1, label2, label3, label4, label5]
def build_test_case(args, emb):
if args.test_acc:
return build_test_case_from_file(args, emb)
else:
return build_small_test_case(emb)
def inference_test(scope, model_dir, args):
BuildWord_IdMap(args.dict_path)
logger.info("model_dir is: {}".format(model_dir + "/"))
emb = np.array(scope.find_var("embeding").get_tensor())
test_cases = build_test_case(emb)
logger.info("inference result: ====================")
for case in test_cases:
pq = topK(args.rank_num, emb, case[0])
logger.info("Test result for {}".format(case[1]))
pq_tmps = list()
for i in range(args.rank_num):
pq_tmps.append(pq.get())
for i in range(len(pq_tmps)):
logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[
pq_tmps[len(pq_tmps) - 1 - i].id], pq_tmps[len(pq_tmps) - 1 - i]
.priority))
del pq_tmps[:]
test_cases = list()
test_labels = list()
exclude_lists = list()
if args.test_acc:
test_cases, test_labels, exclude_lists = build_test_case(args, emb)
else:
test_cases, test_labels = build_test_case(args, emb)
exclude_lists = [[-1]]
accual_rank = 1 if args.test_acc else args.rank_num
correct_num = 0
for i in range(len(test_labels)):
pq = None
if args.test_acc:
pq = topK(
accual_rank,
emb,
test_cases[i][0],
exclude_lists[i],
is_acc=True)
else:
pq = pq = topK(
accual_rank,
emb,
test_cases[i][0],
exclude_lists[0],
is_acc=False)
logger.info("Test result for {}".format(test_cases[i][1]))
for j in range(accual_rank):
pq_tmps = pq.get()
if (j == accual_rank - 1) and (
pq_tmps.id == test_labels[i]
): # if the nearest word is what we want
correct_num += 1
logger.info("{} nearest is {}, rate is {}".format(
accual_rank - j, id_to_word[pq_tmps.id], pq_tmps.priority))
acc = correct_num / len(test_labels)
logger.info("Test acc is: {}, there are {} / {}}".format(acc, correct_num,
len(test_labels)))
class PQ_Entry(object):
......@@ -111,7 +198,7 @@ class PQ_Entry(object):
return cmp(self.priority, other.priority)
def topK(k, emb, test_emb):
def topK(k, emb, test_emb, exclude_list, is_acc=False):
pq = PriorityQueue(k + 1)
while not pq.empty():
try:
......@@ -127,11 +214,14 @@ def topK(k, emb, test_emb):
return pq
for i in range(len(emb)):
x = cosine_similarity([emb[i]], [test_emb])
pq_e = PQ_Entry(x, i)
if pq.full():
pq.get()
pq.put(pq_e)
if is_acc and (i in exclude_list):
pass
else:
x = cosine_similarity([emb[i]], [test_emb])
pq_e = PQ_Entry(x, i)
if pq.full():
pq.get()
pq.put(pq_e)
pq.get()
return pq
......
# -*- coding: utf-8 -*
import re
import six
import argparse
prog = re.compile("[^a-z ]", flags=0)
word_count = dict()
def parse_args():
parser = argparse.ArgumentParser(
......@@ -29,11 +33,75 @@ def parse_args():
default=False,
help='Local train or not, (default: False)')
parser.add_argument(
'--with_other_dict',
action='store_true',
required=False,
default=False,
help='Using third party provided dict , (default: False)')
parser.add_argument(
'--other_dict_path',
type=str,
default='',
help='The path for third party provided dict (default: '
')')
return parser.parse_args()
def text_strip(text):
return re.sub("[^a-z ]", "", text)
return prog.sub("", text)
# users can self-define their own strip rules by modifing this method
def strip_lines(line, vocab=word_count):
return _replace_oov(vocab, native_to_unicode(line))
# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py
def _replace_oov(original_vocab, line):
"""Replace out-of-vocab words with "<UNK>".
This maintains compatibility with published results.
Args:
original_vocab: a set of strings (The standard vocabulary for the dataset)
line: a unicode string - a space-delimited sequence of words.
Returns:
a unicode string - a space-delimited sequence of words.
"""
return u" ".join([
word if word in original_vocab else u"<UNK>" for word in line.split()
])
# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py
# Unicode utility functions that work with Python 2 and 3
def native_to_unicode(s):
if _is_unicode(s):
return s
try:
return _to_unicode(s)
except UnicodeDecodeError:
res = _to_unicode(s, ignore_errors=True)
tf.logging.info("Ignoring Unicode error, outputting: %s" % res)
return res
def _is_unicode(s):
if six.PY2:
if isinstance(s, unicode):
return True
else:
if isinstance(s, str):
return True
return False
def _to_unicode(s, ignore_errors=False):
if _is_unicode(s):
return s
error_mode = "ignore" if ignore_errors else "strict"
return s.decode("utf-8", errors=error_mode)
def build_Huffman(word_count, max_code_length):
......@@ -120,7 +188,7 @@ def build_Huffman(word_count, max_code_length):
return word_point, word_code, word_code_len
def preprocess(data_path, dict_path, freq, is_local):
def preprocess(args):
"""
proprocess the data, generate dictionary and save into dict_path.
:param data_path: the input data path.
......@@ -129,43 +197,61 @@ def preprocess(data_path, dict_path, freq, is_local):
:return:
"""
# word to count
word_count = dict()
if is_local:
if args.with_other_dict:
with open(args.other_dict_path, 'r') as f:
for line in f:
word_count[native_to_unicode(line.strip())] = 1
if args.is_local:
for i in range(1, 100):
with open(data_path + "/news.en-000{:0>2d}-of-00100".format(
with open(args.data_path + "/news.en-000{:0>2d}-of-00100".format(
i)) as f:
for line in f:
line = line.lower()
line = text_strip(line)
line = strip_lines(line)
words = line.split()
for item in words:
if item in word_count:
word_count[item] = word_count[item] + 1
else:
word_count[item] = 1
word_count[native_to_unicode('<UNK>')] += 1
# with open(args.data_path + "/tmp.txt") as f:
# for line in f:
# print("line before strip is: {}".format(line))
# line = strip_lines(line, word_count)
# print("line after strip is: {}".format(line))
# words = line.split()
# print("words after split is: {}".format(words))
# for item in words:
# if item in word_count:
# word_count[item] = word_count[item] + 1
# else:
# word_count[item] = 1
item_to_remove = []
for item in word_count:
if word_count[item] <= freq:
if word_count[item] <= args.freq:
item_to_remove.append(item)
for item in item_to_remove:
del word_count[item]
path_table, path_code, word_code_len = build_Huffman(word_count, 40)
with open(dict_path, 'w+') as f:
with open(args.dict_path, 'w+') as f:
for k, v in word_count.items():
f.write(str(k) + " " + str(v) + '\n')
f.write(k.encode("utf-8") + " " + str(v).encode("utf-8") + '\n')
with open(dict_path + "_ptable", 'w+') as f2:
with open(args.dict_path + "_ptable", 'w+') as f2:
for pk, pv in path_table.items():
f2.write(str(pk) + ":" + ' '.join((str(x) for x in pv)) + '\n')
f2.write(
pk.encode("utf-8") + "\t" + ' '.join((str(x).encode("utf-8")
for x in pv)) + '\n')
with open(dict_path + "_pcode", 'w+') as f3:
for pck, pcv in path_table.items():
f3.write(str(pck) + ":" + ' '.join((str(x) for x in pcv)) + '\n')
with open(args.dict_path + "_pcode", 'w+') as f3:
for pck, pcv in path_code.items():
f3.write(
pck.encode("utf-8") + "\t" + ' '.join((str(x).encode("utf-8")
for x in pcv)) + '\n')
if __name__ == "__main__":
args = parse_args()
preprocess(args.data_path, args.dict_path, args.freq, args.is_local)
preprocess(parse_args())
......@@ -35,6 +35,7 @@ class Word2VecReader(object):
with open(dict_path, 'r') as f:
for line in f:
line = line.decode(encoding='UTF-8')
word, count = line.split()[0], int(line.split()[1])
self.word_to_id_[word] = word_id
self.id_to_word[word_id] = word #build id to word dict
......@@ -44,7 +45,8 @@ class Word2VecReader(object):
with open(dict_path + "_word_to_id_", 'w+') as f6:
for k, v in self.word_to_id_.items():
f6.write(str(k) + " " + str(v) + '\n')
f6.write(
k.encode("utf-8") + " " + str(v).encode("utf-8") + '\n')
self.dict_size = len(self.word_to_id_)
self.word_frequencys = [
......@@ -55,16 +57,17 @@ class Word2VecReader(object):
with open(dict_path + "_ptable", 'r') as f2:
for line in f2:
self.word_to_path[line.split(":")[0]] = np.fromstring(
line.split(':')[1], dtype=int, sep=' ')
self.word_to_path[line.split("\t")[0]] = np.fromstring(
line.split('\t')[1], dtype=int, sep=' ')
self.num_non_leaf = np.fromstring(
line.split(':')[1], dtype=int, sep=' ')[0]
line.split('\t')[1], dtype=int, sep=' ')[0]
print("word_ptable dict_size = " + str(len(self.word_to_path)))
with open(dict_path + "_pcode", 'r') as f3:
for line in f3:
self.word_to_code[line.split(":")[0]] = np.fromstring(
line.split(':')[1], dtype=int, sep=' ')
line = line.decode(encoding='UTF-8')
self.word_to_code[line.split("\t")[0]] = np.fromstring(
line.split('\t')[1], dtype=int, sep=' ')
print("word_pcode dict_size = " + str(len(self.word_to_code)))
def get_context_words(self, words, idx, window_size):
......@@ -92,7 +95,7 @@ class Word2VecReader(object):
count = 1
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.text_strip(line)
line = preprocess.strip_lines(line)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
......@@ -114,7 +117,7 @@ class Word2VecReader(object):
count = 1
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.text_strip(line)
line = preprocess.strip_lines(line)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
......
from __future__ import print_function
import argparse
import logging
import os
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册