Commit 4946f5f5 authored by J JiabinYang

Test w2v on a small dataset

Parent b8555a93
......@@ -164,7 +164,7 @@ def async_train_loop(args, train_program, dataset, loss, thread_num):
             debug=False)
         epoch_stop = time.time()
         run_time = epoch_stop - epoch_start
-        lines = len(filelist) * 1000000.0
+        lines = 109984625
         print("run epoch%d done, lines=%s, time=%d, sample/second=%s" %
               (i + 1, lines, run_time, lines / run_time))
         epoch_model = "word2vec_model/epoch" + str(i + 1)
......
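The hunk above replaces the per-file estimate (`len(filelist) * 1000000.0`, i.e. one million lines assumed per shard) with a hardcoded corpus line count of 109984625 for the throughput report. A minimal sketch of deriving that count from the shards themselves instead of hardcoding it; the helper name `count_corpus_lines` is hypothetical, and `filelist` is assumed to hold paths to plain-text shards as in `async_train_loop`:

```python
import io

def count_corpus_lines(filelist):
    # Count lines across all shards without loading any file into memory,
    # so the samples/second figure stays correct when shard sizes change.
    total = 0
    for path in filelist:
        with io.open(path, encoding='utf-8') as f:
            for _ in f:
                total += 1
    return total
```

Counting once at startup costs a single pass over the data but removes the silent assumption baked into either constant.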
......@@ -59,7 +59,7 @@ def parse_args():
     parser.add_argument(
         '--test_batch_size',
         type=int,
-        default=1000,
+        default=100,
         help="test used batch size (default: 1000)")
     return parser.parse_args()
......
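Note that the default above drops to 100 while the unchanged help string still reads "default: 1000". A sketch that keeps the two in sync by letting argparse interpolate the live default via its standard `%(default)s` placeholder (the surrounding parser setup is assumed):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--test_batch_size',
    type=int,
    default=100,
    # argparse expands %(default)s at help-rendering time,
    # so the help text can never go stale when the default changes.
    help="test used batch size (default: %(default)s)")
```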
......@@ -198,10 +198,10 @@ def preprocess(args):
"""
# word to count
if args.with_other_dict:
with io.open(args.other_dict_path, 'r', encoding='utf-8') as f:
for line in f:
word_count[native_to_unicode(line.strip())] = 1
# if args.with_other_dict:
# with io.open(args.other_dict_path, 'r', encoding='utf-8') as f:
# for line in f:
# word_count[native_to_unicode(line.strip())] = 1
# if args.is_local:
# for i in range(1, 100):
......@@ -223,31 +223,37 @@ def preprocess(args):
     #                         word_count[item] = word_count[item] + 1
     #                     else:
     #                         word_count[item] = 1
-    if args.is_local:
-        with io.open(args.data_path + "/text8", encoding='utf-8') as f:
-            for line in f:
-                if args.with_other_dict:
-                    line = strip_lines(line)
-                    words = line.split()
-                    for item in words:
-                        if item in word_count:
-                            word_count[item] = word_count[item] + 1
-                        else:
-                            word_count[native_to_unicode('<UNK>')] += 1
-                else:
-                    line = text_strip(line)
-                    words = line.split()
-                    for item in words:
-                        if item in word_count:
-                            word_count[item] = word_count[item] + 1
-                        else:
-                            word_count[item] = 1
-    item_to_remove = []
-    for item in word_count:
-        if word_count[item] <= args.freq:
-            item_to_remove.append(item)
-    for item in item_to_remove:
-        del word_count[item]
+    # if args.is_local:
+    #     with io.open(args.data_path + "/text8", encoding='utf-8') as f:
+    #         for line in f:
+    #             if args.with_other_dict:
+    #                 line = strip_lines(line)
+    #                 words = line.split()
+    #                 for item in words:
+    #                     if item in word_count:
+    #                         word_count[item] = word_count[item] + 1
+    #                     else:
+    #                         word_count[native_to_unicode('<UNK>')] += 1
+    #             else:
+    #                 line = text_strip(line)
+    #                 words = line.split()
+    #                 for item in words:
+    #                     if item in word_count:
+    #                         word_count[item] = word_count[item] + 1
+    #                     else:
+    #                         word_count[item] = 1
+    # item_to_remove = []
+    # for item in word_count:
+    #     if word_count[item] <= args.freq:
+    #         item_to_remove.append(item)
+    # for item in item_to_remove:
+    #     del word_count[item]
+    with io.open(args.dict_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            word, count = line.split()[0], int(line.split()[1])
+            word_count[word] = count
+    print(word_count)
     path_table, path_code, word_code_len = build_Huffman(word_count, 40)
......
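Taken together, the two hunks above comment out the on-the-fly word count over `text8` (including the `word_count[item] <= args.freq` pruning pass) and instead load a pre-built vocabulary from `args.dict_path`. A self-contained sketch of that loading path, assuming the dict file stores one `<word> <count>` pair per line as the added lines imply; the function wrapper is hypothetical:

```python
import io

def load_word_count(dict_path):
    """Read a pre-built '<word> <count>' vocabulary file into a dict."""
    word_count = {}
    with io.open(dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            word_count[fields[0]] = int(fields[1])
    return word_count
```

Loading a fixed dictionary makes runs on the small test corpus reproducible and skips the corpus scan entirely, which is consistent with the commit's stated purpose.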
......@@ -141,6 +141,7 @@ def convert_python_to_tensor(batch_size, sample_reader, is_hs):
         result = [[], []]
         for sample in sample_reader():
             for i, fea in enumerate(sample):
+                print(fea)
                 result[i].append(fea)
             if len(result[0]) == batch_size:
                 tensor_result = []
......
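The last hunk only inserts a `print(fea)` debug trace inside the batching loop of `convert_python_to_tensor`. For context, a stripped-down sketch of what that loop does, with names mirroring the diff; `sample_reader` is assumed to yield two-feature samples (matching the two slots in `result`), and the tensor conversion is omitted:

```python
def batch_samples(batch_size, sample_reader):
    # Accumulate the i-th feature of every sample into result[i]
    # until batch_size samples have been collected, then emit the batch.
    result = [[], []]
    for sample in sample_reader():
        for i, fea in enumerate(sample):
            result[i].append(fea)
        if len(result[0]) == batch_size:
            yield result  # the real code converts this into tensor_result
            result = [[], []]
```

A per-feature `print` at this point fires once per feature of every sample, so it is useful for spot-checking a small run but should be removed before full-corpus training.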