From 18c46eb702306db8f85045bf192fdf28e3d5dc3c Mon Sep 17 00:00:00 2001
From: JiabinYang
Date: Tue, 18 Dec 2018 13:39:00 +0000
Subject: [PATCH] add feature to use third_party vocab and add acc test

---
 fluid/PaddleRec/word2vec/README.cn.md     |   5 +
 fluid/PaddleRec/word2vec/README.md        |   5 +
 fluid/PaddleRec/word2vec/data/download.sh |   1 -
 fluid/PaddleRec/word2vec/infer.py         | 134 ++++++++++++++++++----
 fluid/PaddleRec/word2vec/preprocess.py    | 122 +++++++++++++++++---
 fluid/PaddleRec/word2vec/reader.py        |  19 +--
 fluid/PaddleRec/word2vec/train.py         |   1 -
 7 files changed, 237 insertions(+), 50 deletions(-)

diff --git a/fluid/PaddleRec/word2vec/README.cn.md b/fluid/PaddleRec/word2vec/README.cn.md
index 39a0fb17..076b3eef 100644
--- a/fluid/PaddleRec/word2vec/README.cn.md
+++ b/fluid/PaddleRec/word2vec/README.cn.md
@@ -61,6 +61,11 @@ sh cluster_train.sh
 
 您也可以在`build_test_case`方法中模仿给出的例子增加自己的测试
 
+要从测试文件运行测试用例,请将测试文件下载到“test”目录中
+我们为每个案例提供以下结构的测试:
+    `word1 word2 word3 word4`
+所以我们可以将它构建成`word1 - word2 + word3 = word4`
+
 训练中预测:
 
 ```bash
diff --git a/fluid/PaddleRec/word2vec/README.md b/fluid/PaddleRec/word2vec/README.md
index 17e56bfc..9ed321e6 100644
--- a/fluid/PaddleRec/word2vec/README.md
+++ b/fluid/PaddleRec/word2vec/README.md
@@ -65,6 +65,11 @@ For: boy - girl + aunt = uncle
 
 You can also add your own tests by mimicking the examples given in the `build_test_case` method.
 
+To run the test cases from test files, please download the test files into the 'test' directory.
+We provide a test for each case with the following structure:
+    `word1 word2 word3 word4`
+so that we can build it into `word1 - word2 + word3 = word4`
+
 Forecast in training:
 
 ```bash
diff --git a/fluid/PaddleRec/word2vec/data/download.sh b/fluid/PaddleRec/word2vec/data/download.sh
index 4ba05c63..22cde6d9 100644
--- a/fluid/PaddleRec/word2vec/data/download.sh
+++ b/fluid/PaddleRec/word2vec/data/download.sh
@@ -2,4 +2,3 @@
 
 wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
 tar -zxvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
-
diff --git a/fluid/PaddleRec/word2vec/infer.py b/fluid/PaddleRec/word2vec/infer.py
index cc72b5b2..2e0c8ea3 100644
--- a/fluid/PaddleRec/word2vec/infer.py
+++ b/fluid/PaddleRec/word2vec/infer.py
@@ -1,4 +1,3 @@
-import paddle
 import time
 import os
 import paddle.fluid as fluid
@@ -6,6 +5,7 @@ import numpy as np
 from Queue import PriorityQueue
 import logging
 import argparse
+import preprocess
 from sklearn.metrics.pairwise import cosine_similarity
 
 word_to_id = dict()
@@ -47,6 +47,22 @@ def parse_args():
         required=False,
         default=True,
         help='if using infer_during_train, (default: True)')
+    parser.add_argument(
+        '--test_acc',
+        action='store_true',
+        required=False,
+        default=True,
+        help='if using test_files, (default: True)')
+    parser.add_argument(
+        '--test_files_dir',
+        type=str,
+        default='test',
+        help="The path for test_files (default: test)")
+    parser.add_argument(
+        '--test_batch_size',
+        type=int,
+        default=1000,
+        help="test used batch size (default: 1000)")
     return parser.parse_args()
 
 
@@ -58,48 +74,119 @@ def BuildWord_IdMap(dict_path):
             id_to_word[int(line.split(' ')[1])] = line.split(' ')[0]
 
 
-def inference_prog():
+def inference_prog():  # just to create program for test
     fluid.layers.create_parameter(
         shape=[1, 1], dtype='float32', name="embeding")
 
 
-def build_test_case(emb):
+def build_test_case_from_file(args, emb):
+    logger.info("test files dir: {}".format(args.test_files_dir))
+    current_list = os.listdir(args.test_files_dir)
+    logger.info("test files list: {}".format(current_list))
{}".format(current_list)) + test_cases = list() + test_labels = list() + exclude_lists = list() + for file_dir in current_list: + with open(args.test_files_dir + "/" + file_dir, 'r') as f: + count = 0 + for line in f: + if count == 0: + pass + elif ':' in line: + logger.info("{}".format(line)) + pass + else: + line = preprocess.strip_lines(line, word_to_id) + test_case = emb[word_to_id[line.split()[0]]] - emb[ + word_to_id[line.split()[1]]] + emb[word_to_id[ + line.split()[2]]] + test_case_desc = line.split()[0] + " - " + line.split()[ + 1] + " + " + line.split()[2] + " = " + line.split()[3] + test_cases.append([test_case, test_case_desc]) + test_labels.append(word_to_id[line.split()[3]]) + exclude_lists.append([ + word_to_id[line.split()[0]], + word_to_id[line.split()[1]], word_to_id[line.split()[2]] + ]) + count += 1 + return test_cases, test_labels, exclude_lists + + +def build_small_test_case(emb): emb1 = emb[word_to_id['boy']] - emb[word_to_id['girl']] + emb[word_to_id[ 'aunt']] desc1 = "boy - girl + aunt = uncle" + label1 = word_to_id["uncle"] emb2 = emb[word_to_id['brother']] - emb[word_to_id['sister']] + emb[ word_to_id['sisters']] desc2 = "brother - sister + sisters = brothers" + label2 = word_to_id["brothers"] emb3 = emb[word_to_id['king']] - emb[word_to_id['queen']] + emb[word_to_id[ 'woman']] desc3 = "king - queen + woman = man" + label3 = word_to_id["man"] emb4 = emb[word_to_id['reluctant']] - emb[word_to_id['reluctantly']] + emb[ word_to_id['slowly']] desc4 = "reluctant - reluctantly + slowly = slow" + label4 = word_to_id["slow"] emb5 = emb[word_to_id['old']] - emb[word_to_id['older']] + emb[word_to_id[ 'deeper']] desc5 = "old - older + deeper = deep" + label5 = word_to_id["deep"] return [[emb1, desc1], [emb2, desc2], [emb3, desc3], [emb4, desc4], - [emb5, desc5]] + [emb5, desc5]], [label1, label2, label3, label4, label5] + + +def build_test_case(args, emb): + if args.test_acc: + return build_test_case_from_file(args, emb) + else: + return build_small_test_case(emb) def inference_test(scope, model_dir, args): BuildWord_IdMap(args.dict_path) logger.info("model_dir is: {}".format(model_dir + "/")) emb = np.array(scope.find_var("embeding").get_tensor()) - test_cases = build_test_case(emb) logger.info("inference result: ====================") - for case in test_cases: - pq = topK(args.rank_num, emb, case[0]) - logger.info("Test result for {}".format(case[1])) - pq_tmps = list() - for i in range(args.rank_num): - pq_tmps.append(pq.get()) - for i in range(len(pq_tmps)): - logger.info("{} nearest is {}, rate is {}".format(i, id_to_word[ - pq_tmps[len(pq_tmps) - 1 - i].id], pq_tmps[len(pq_tmps) - 1 - i] - .priority)) - del pq_tmps[:] + test_cases = list() + test_labels = list() + exclude_lists = list() + if args.test_acc: + test_cases, test_labels, exclude_lists = build_test_case(args, emb) + else: + test_cases, test_labels = build_test_case(args, emb) + exclude_lists = [[-1]] + accual_rank = 1 if args.test_acc else args.rank_num + correct_num = 0 + for i in range(len(test_labels)): + pq = None + if args.test_acc: + pq = topK( + accual_rank, + emb, + test_cases[i][0], + exclude_lists[i], + is_acc=True) + else: + pq = pq = topK( + accual_rank, + emb, + test_cases[i][0], + exclude_lists[0], + is_acc=False) + logger.info("Test result for {}".format(test_cases[i][1])) + for j in range(accual_rank): + pq_tmps = pq.get() + if (j == accual_rank - 1) and ( + pq_tmps.id == test_labels[i] + ): # if the nearest word is what we want + correct_num += 1 + logger.info("{} nearest is {}, rate is 
{}".format( + accual_rank - j, id_to_word[pq_tmps.id], pq_tmps.priority)) + acc = correct_num / len(test_labels) + logger.info("Test acc is: {}, there are {} / {}}".format(acc, correct_num, + len(test_labels))) class PQ_Entry(object): @@ -111,7 +198,7 @@ class PQ_Entry(object): return cmp(self.priority, other.priority) -def topK(k, emb, test_emb): +def topK(k, emb, test_emb, exclude_list, is_acc=False): pq = PriorityQueue(k + 1) while not pq.empty(): try: @@ -127,11 +214,14 @@ def topK(k, emb, test_emb): return pq for i in range(len(emb)): - x = cosine_similarity([emb[i]], [test_emb]) - pq_e = PQ_Entry(x, i) - if pq.full(): - pq.get() - pq.put(pq_e) + if is_acc and (i in exclude_list): + pass + else: + x = cosine_similarity([emb[i]], [test_emb]) + pq_e = PQ_Entry(x, i) + if pq.full(): + pq.get() + pq.put(pq_e) pq.get() return pq diff --git a/fluid/PaddleRec/word2vec/preprocess.py b/fluid/PaddleRec/word2vec/preprocess.py index cb8dd100..d231f776 100644 --- a/fluid/PaddleRec/word2vec/preprocess.py +++ b/fluid/PaddleRec/word2vec/preprocess.py @@ -1,8 +1,12 @@ # -*- coding: utf-8 -* import re +import six import argparse +prog = re.compile("[^a-z ]", flags=0) +word_count = dict() + def parse_args(): parser = argparse.ArgumentParser( @@ -29,11 +33,75 @@ def parse_args(): default=False, help='Local train or not, (default: False)') + parser.add_argument( + '--with_other_dict', + action='store_true', + required=False, + default=False, + help='Using third party provided dict , (default: False)') + + parser.add_argument( + '--other_dict_path', + type=str, + default='', + help='The path for third party provided dict (default: ' + ')') + return parser.parse_args() def text_strip(text): - return re.sub("[^a-z ]", "", text) + return prog.sub("", text) + + +# users can self-define their own strip rules by modifing this method +def strip_lines(line, vocab=word_count): + return _replace_oov(vocab, native_to_unicode(line)) + + +# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py +def _replace_oov(original_vocab, line): + """Replace out-of-vocab words with "". + This maintains compatibility with published results. + Args: + original_vocab: a set of strings (The standard vocabulary for the dataset) + line: a unicode string - a space-delimited sequence of words. + Returns: + a unicode string - a space-delimited sequence of words. 
+ """ + return u" ".join([ + word if word in original_vocab else u"" for word in line.split() + ]) + + +# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py +# Unicode utility functions that work with Python 2 and 3 +def native_to_unicode(s): + if _is_unicode(s): + return s + try: + return _to_unicode(s) + except UnicodeDecodeError: + res = _to_unicode(s, ignore_errors=True) + tf.logging.info("Ignoring Unicode error, outputting: %s" % res) + return res + + +def _is_unicode(s): + if six.PY2: + if isinstance(s, unicode): + return True + else: + if isinstance(s, str): + return True + return False + + +def _to_unicode(s, ignore_errors=False): + if _is_unicode(s): + return s + error_mode = "ignore" if ignore_errors else "strict" + return s.decode("utf-8", errors=error_mode) def build_Huffman(word_count, max_code_length): @@ -120,7 +188,7 @@ def build_Huffman(word_count, max_code_length): return word_point, word_code, word_code_len -def preprocess(data_path, dict_path, freq, is_local): +def preprocess(args): """ proprocess the data, generate dictionary and save into dict_path. :param data_path: the input data path. @@ -129,43 +197,61 @@ def preprocess(data_path, dict_path, freq, is_local): :return: """ # word to count - word_count = dict() - if is_local: + if args.with_other_dict: + with open(args.other_dict_path, 'r') as f: + for line in f: + word_count[native_to_unicode(line.strip())] = 1 + + if args.is_local: for i in range(1, 100): - with open(data_path + "/news.en-000{:0>2d}-of-00100".format( + with open(args.data_path + "/news.en-000{:0>2d}-of-00100".format( i)) as f: for line in f: - line = line.lower() - line = text_strip(line) + line = strip_lines(line) words = line.split() for item in words: if item in word_count: word_count[item] = word_count[item] + 1 else: - word_count[item] = 1 + word_count[native_to_unicode('')] += 1 + # with open(args.data_path + "/tmp.txt") as f: + # for line in f: + # print("line before strip is: {}".format(line)) + # line = strip_lines(line, word_count) + # print("line after strip is: {}".format(line)) + # words = line.split() + # print("words after split is: {}".format(words)) + # for item in words: + # if item in word_count: + # word_count[item] = word_count[item] + 1 + # else: + # word_count[item] = 1 item_to_remove = [] for item in word_count: - if word_count[item] <= freq: + if word_count[item] <= args.freq: item_to_remove.append(item) for item in item_to_remove: del word_count[item] path_table, path_code, word_code_len = build_Huffman(word_count, 40) - with open(dict_path, 'w+') as f: + with open(args.dict_path, 'w+') as f: for k, v in word_count.items(): - f.write(str(k) + " " + str(v) + '\n') + f.write(k.encode("utf-8") + " " + str(v).encode("utf-8") + '\n') - with open(dict_path + "_ptable", 'w+') as f2: + with open(args.dict_path + "_ptable", 'w+') as f2: for pk, pv in path_table.items(): - f2.write(str(pk) + ":" + ' '.join((str(x) for x in pv)) + '\n') + f2.write( + pk.encode("utf-8") + "\t" + ' '.join((str(x).encode("utf-8") + for x in pv)) + '\n') - with open(dict_path + "_pcode", 'w+') as f3: - for pck, pcv in path_table.items(): - f3.write(str(pck) + ":" + ' '.join((str(x) for x in pcv)) + '\n') + with open(args.dict_path + "_pcode", 'w+') as f3: + for pck, pcv in path_code.items(): + f3.write( + pck.encode("utf-8") + "\t" + ' '.join((str(x).encode("utf-8") + for x in pcv)) + '\n') if __name__ == "__main__": - args = parse_args() - preprocess(args.data_path, 
+    preprocess(parse_args())
diff --git a/fluid/PaddleRec/word2vec/reader.py b/fluid/PaddleRec/word2vec/reader.py
index 7755dc7b..ff7d79ec 100644
--- a/fluid/PaddleRec/word2vec/reader.py
+++ b/fluid/PaddleRec/word2vec/reader.py
@@ -35,6 +35,7 @@ class Word2VecReader(object):
 
         with open(dict_path, 'r') as f:
             for line in f:
+                line = line.decode(encoding='UTF-8')
                 word, count = line.split()[0], int(line.split()[1])
                 self.word_to_id_[word] = word_id
                 self.id_to_word[word_id] = word  #build id to word dict
@@ -44,7 +45,8 @@ class Word2VecReader(object):
 
         with open(dict_path + "_word_to_id_", 'w+') as f6:
             for k, v in self.word_to_id_.items():
-                f6.write(str(k) + " " + str(v) + '\n')
+                f6.write(
+                    k.encode("utf-8") + " " + str(v).encode("utf-8") + '\n')
 
         self.dict_size = len(self.word_to_id_)
         self.word_frequencys = [
@@ -55,16 +57,17 @@ class Word2VecReader(object):
 
         with open(dict_path + "_ptable", 'r') as f2:
             for line in f2:
-                self.word_to_path[line.split(":")[0]] = np.fromstring(
-                    line.split(':')[1], dtype=int, sep=' ')
+                self.word_to_path[line.split("\t")[0]] = np.fromstring(
+                    line.split('\t')[1], dtype=int, sep=' ')
                 self.num_non_leaf = np.fromstring(
-                    line.split(':')[1], dtype=int, sep=' ')[0]
+                    line.split('\t')[1], dtype=int, sep=' ')[0]
             print("word_ptable dict_size = " + str(len(self.word_to_path)))
 
         with open(dict_path + "_pcode", 'r') as f3:
             for line in f3:
-                self.word_to_code[line.split(":")[0]] = np.fromstring(
-                    line.split(':')[1], dtype=int, sep=' ')
+                line = line.decode(encoding='UTF-8')
+                self.word_to_code[line.split("\t")[0]] = np.fromstring(
+                    line.split('\t')[1], dtype=int, sep=' ')
             print("word_pcode dict_size = " + str(len(self.word_to_code)))
 
     def get_context_words(self, words, idx, window_size):
@@ -92,7 +95,7 @@ class Word2VecReader(object):
                     count = 1
                     for line in f:
                         if self.trainer_id == count % self.trainer_num:
-                            line = preprocess.text_strip(line)
+                            line = preprocess.strip_lines(line)
                             word_ids = [
                                 self.word_to_id_[word] for word in line.split()
                                 if word in self.word_to_id_
@@ -114,7 +117,7 @@ class Word2VecReader(object):
                     count = 1
                     for line in f:
                         if self.trainer_id == count % self.trainer_num:
-                            line = preprocess.text_strip(line)
+                            line = preprocess.strip_lines(line)
                             word_ids = [
                                 self.word_to_id_[word] for word in line.split()
                                 if word in self.word_to_id_
diff --git a/fluid/PaddleRec/word2vec/train.py b/fluid/PaddleRec/word2vec/train.py
index c20d02f9..58f98430 100644
--- a/fluid/PaddleRec/word2vec/train.py
+++ b/fluid/PaddleRec/word2vec/train.py
@@ -1,5 +1,4 @@
 from __future__ import print_function
-
 import argparse
 import logging
 import os
-- 
GitLab
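A minimal, self-contained sketch of the analogy-accuracy check this patch wires into `infer.py` (the `word1 - word2 + word3 = word4` test cases plus the `exclude_lists` handling in `topK`). The embedding table, vocabulary, and function name below are illustrative stand-ins, not identifiers from the patch; the real code loads the embedding from the fluid scope and the vocabulary from the generated dict file.

```python
# Sketch only: `emb`, `word_to_id`, and `analogy_accuracy` are assumed toy names.
# Each test line is "word1 word2 word3 word4" and counts as correct when the
# nearest vector to (word1 - word2 + word3), excluding the three query words,
# is word4.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

word_to_id = {"boy": 0, "girl": 1, "aunt": 2, "uncle": 3, "man": 4}
id_to_word = {v: k for k, v in word_to_id.items()}
emb = np.random.rand(len(word_to_id), 8).astype("float32")  # toy embedding table


def analogy_accuracy(test_lines, emb, word_to_id):
    correct = 0
    for line in test_lines:
        w1, w2, w3, w4 = line.split()
        query = emb[word_to_id[w1]] - emb[word_to_id[w2]] + emb[word_to_id[w3]]
        # cosine similarity of the query vector against every embedding row
        sims = cosine_similarity(emb, query.reshape(1, -1)).ravel()
        # mirror the exclude_list logic: the query words cannot be the answer
        sims[[word_to_id[w1], word_to_id[w2], word_to_id[w3]]] = -np.inf
        if int(np.argmax(sims)) == word_to_id[w4]:
            correct += 1
    return float(correct) / len(test_lines)


print(analogy_accuracy(["boy girl aunt uncle"], emb, word_to_id))
```

Running the same check over trained embeddings and a directory of such test files, rather than the random toy table above, is what the new `--test_acc` and `--test_files_dir` flags automate inside `inference_test`.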