提交 bbf4aa7f 编写于 作者: J JiabinYang

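fix reader bug: read and write the dictionary, path-table/path-code and corpus files as UTF-8 text via io.open instead of manual encode/decode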

上级 41351679
......@@ -131,9 +131,25 @@ def build_small_test_case(emb):
desc5 = "old - older + deeper = deep"
label5 = word_to_id["deep"]
test_cases = [emb1, emb2, emb3, emb4, emb5]
test_case_desc = [desc1, desc2, desc3, desc4, desc5]
test_labels = [label1, label2, label3, label4, label5]
emb6 = emb[word_to_id['boy']]
desc6 = "boy"
label6 = word_to_id["boy"]
emb7 = emb[word_to_id['king']]
desc7 = "king"
label7 = word_to_id["king"]
emb8 = emb[word_to_id['sun']]
desc8 = "sun"
label8 = word_to_id["sun"]
emb9 = emb[word_to_id['key']]
desc9 = "key"
label9 = word_to_id["key"]
test_cases = [emb1, emb2, emb3, emb4, emb5, emb6, emb7, emb8, emb9]
test_case_desc = [
desc1, desc2, desc3, desc4, desc5, desc6, desc7, desc8, desc9
]
test_labels = [
label1, label2, label3, label4, label5, label6, label7, label8, label9
]
return norm(np.array(test_cases)), test_case_desc, test_labels
......@@ -229,8 +245,6 @@ def infer_during_train(args):
while True:
time.sleep(60)
current_list = os.listdir(args.model_output_dir)
# logger.info("current_list is : {}".format(current_list))
# logger.info("model_file_list is : {}".format(model_file_list))
if set(model_file_list) == set(current_list):
if solved_new:
solved_new = False
......
......@@ -3,6 +3,7 @@
import re
import six
import argparse
import io
prog = re.compile("[^a-z ]", flags=0)
word_count = dict()
......@@ -83,7 +84,6 @@ def native_to_unicode(s):
return _to_unicode(s)
except UnicodeDecodeError:
res = _to_unicode(s, ignore_errors=True)
tf.logging.info("Ignoring Unicode error, outputting: %s" % res)
return res
......@@ -199,14 +199,15 @@ def preprocess(args):
# word to count
if args.with_other_dict:
with open(args.other_dict_path, 'r') as f:
with io.open(args.other_dict_path, 'r', encoding='utf-8') as f:
for line in f:
word_count[native_to_unicode(line.strip())] = 1
if args.is_local:
for i in range(1, 100):
with open(args.data_path + "/news.en-000{:0>2d}-of-00100".format(
i)) as f:
with io.open(
args.data_path + "/news.en-000{:0>2d}-of-00100".format(i),
encoding='utf-8') as f:
for line in f:
line = strip_lines(line)
words = line.split()
......@@ -231,21 +232,17 @@ def preprocess(args):
path_table, path_code, word_code_len = build_Huffman(word_count, 40)
with open(args.dict_path, 'w+') as f:
with io.open(args.dict_path, 'w+', encoding='utf-8') as f:
for k, v in word_count.items():
f.write(k.encode("utf-8") + " " + str(v).encode("utf-8") + '\n')
f.write(k + " " + str(v) + '\n')
with open(args.dict_path + "_ptable", 'w+') as f2:
with io.open(args.dict_path + "_ptable", 'w+', encoding='utf-8') as f2:
for pk, pv in path_table.items():
f2.write(
pk.encode("utf-8") + '\t' + ' '.join((str(x).encode("utf-8")
for x in pv)) + '\n')
f2.write(pk + '\t' + ' '.join((str(x) for x in pv)) + '\n')
with open(args.dict_path + "_pcode", 'w+') as f3:
with io.open(args.dict_path + "_pcode", 'w+', encoding='utf-8') as f3:
for pck, pcv in path_code.items():
f3.write(
pck.encode("utf-8") + '\t' + ' '.join((str(x).encode("utf-8")
for x in pcv)) + '\n')
f3.write(pck + '\t' + ' '.join((str(x) for x in pcv)) + '\n')
if __name__ == "__main__":
......
......@@ -2,8 +2,8 @@
import numpy as np
import preprocess
import logging
import io
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
......@@ -42,6 +42,7 @@ class Word2VecReader(object):
self.num_non_leaf = 0
self.word_to_id_ = dict()
self.id_to_word = dict()
self.word_count = dict()
self.word_to_path = dict()
self.word_to_code = dict()
self.trainer_id = trainer_id
......@@ -51,20 +52,19 @@ class Word2VecReader(object):
word_counts = []
word_id = 0
with open(dict_path, 'r') as f:
with io.open(dict_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.decode(encoding='UTF-8')
word, count = line.split()[0], int(line.split()[1])
self.word_count[word] = count
self.word_to_id_[word] = word_id
self.id_to_word[word_id] = word #build id to word dict
word_id += 1
word_counts.append(count)
word_all_count += count
with open(dict_path + "_word_to_id_", 'w+') as f6:
with io.open(dict_path + "_word_to_id_", 'w+', encoding='utf-8') as f6:
for k, v in self.word_to_id_.items():
f6.write(
k.encode("utf-8") + " " + str(v).encode("utf-8") + '\n')
f6.write(k + " " + str(v) + '\n')
self.dict_size = len(self.word_to_id_)
self.word_frequencys = [
......@@ -73,7 +73,7 @@ class Word2VecReader(object):
print("dict_size = " + str(
self.dict_size)) + " word_all_count = " + str(word_all_count)
with open(dict_path + "_ptable", 'r') as f2:
with io.open(dict_path + "_ptable", 'r', encoding='utf-8') as f2:
for line in f2:
self.word_to_path[line.split('\t')[0]] = np.fromstring(
line.split('\t')[1], dtype=int, sep=' ')
......@@ -81,9 +81,8 @@ class Word2VecReader(object):
line.split('\t')[1], dtype=int, sep=' ')[0]
print("word_ptable dict_size = " + str(len(self.word_to_path)))
with open(dict_path + "_pcode", 'r') as f3:
with io.open(dict_path + "_pcode", 'r', encoding='utf-8') as f3:
for line in f3:
line = line.decode(encoding='UTF-8')
self.word_to_code[line.split('\t')[0]] = np.fromstring(
line.split('\t')[1], dtype=int, sep=' ')
print("word_pcode dict_size = " + str(len(self.word_to_code)))
......@@ -109,13 +108,15 @@ class Word2VecReader(object):
def train(self, with_hs):
def _reader():
for file in self.filelist:
with open(self.data_path_ + "/" + file, 'r') as f:
with io.open(
self.data_path_ + "/" + file, 'r',
encoding='utf-8') as f:
logger.info("running data in {}".format(self.data_path_ +
"/" + file))
count = 1
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.strip_lines(line)
line = preprocess.strip_lines(line, self.word_count)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
......@@ -131,13 +132,15 @@ class Word2VecReader(object):
def _reader_hs():
for file in self.filelist:
with open(self.data_path_ + "/" + file, 'r') as f:
with io.open(
self.data_path_ + "/" + file, 'r',
encoding='utf-8') as f:
logger.info("running data in {}".format(self.data_path_ +
"/" + file))
count = 1
for line in f:
if self.trainer_id == count % self.trainer_num:
line = preprocess.strip_lines(line)
line = preprocess.strip_lines(line, self.word_count)
word_ids = [
self.word_to_id_[word] for word in line.split()
if word in self.word_to_id_
......@@ -164,13 +167,20 @@ class Word2VecReader(object):
if __name__ == "__main__":
window_size = 10
window_size = 5
reader = Word2VecReader(
"./data/1-billion_dict",
"./data/1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/",
["news.en-00001-of-00100"], 0, 1)
reader = Word2VecReader("data/enwik9_dict", "data/enwik9", window_size)
i = 0
for x, y in reader.train()():
# print(reader.train(True))
for x, y, z, f in reader.train(True)():
print("x: " + str(x))
print("y: " + str(y))
print("path: " + str(z))
print("code: " + str(f))
print("\n")
if i == 10:
exit(0)
......
......@@ -135,7 +135,6 @@ def convert_python_to_tensor(batch_size, sample_reader, is_hs):
for sample in sample_reader():
for i, fea in enumerate(sample):
result[i].append(fea)
if len(result[0]) == batch_size:
tensor_result = []
for tensor in result:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册