Commit 2f3f35ac authored by zhaopu

code format

Parent 221d2f82
@@ -2,6 +2,7 @@
 import collections
 import os
 
+
 # -- function --
 def save_vocab(word_id_dict, vocab_file_name):
@@ -10,12 +11,13 @@ def save_vocab(word_id_dict, vocab_file_name):
     :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
     :param vocab_file_name: vocab file name.
     """
-    f = open(vocab_file_name,'w')
-    for(k, v) in word_id_dict.items():
+    f = open(vocab_file_name, 'w')
+    for (k, v) in word_id_dict.items():
         f.write(k.encode('utf-8') + '\t' + str(v) + '\n')
-    print('save vocab to '+vocab_file_name)
+    print('save vocab to ' + vocab_file_name)
     f.close()
 
+
 def load_vocab(vocab_file_name):
     """
     load vocab from file
@@ -32,6 +34,7 @@ def load_vocab(vocab_file_name):
         dict[kv[0]] = int(kv[1])
     return dict
 
+
 def build_vocab(file_name, vocab_max_size):
     """
     build vacab.
@@ -41,7 +44,7 @@ def build_vocab(file_name, vocab_max_size):
     """
     words = []
     for line in open(file_name):
-        words += line.decode('utf-8','ignore').strip().split()
+        words += line.decode('utf-8', 'ignore').strip().split()
     counter = collections.Counter(words)
     counter = sorted(counter.items(), key=lambda x: -x[1])
@@ -53,6 +56,7 @@ def build_vocab(file_name, vocab_max_size):
     word_id_dict['<EOS>'] = 1
     return word_id_dict
 
+
 def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
     """
     create reader, each sample with fixed length.
@@ -62,21 +66,24 @@ def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
     :param sentence_len: each sample's length.
     :return: data reader.
     """
+
     def reader():
         words = []
         UNK = word_id_dict['<UNK>']
         for line in open(file_name):
-            words += line.decode('utf-8','ignore').strip().split()
+            words += line.decode('utf-8', 'ignore').strip().split()
         ids = [word_id_dict.get(w, UNK) for w in words]
         words_len = len(words)
-        sentence_num = (words_len-1) // sentence_len
+        sentence_num = (words_len - 1) // sentence_len
         count = 0
         while count < sentence_num:
             start = count * sentence_len
             count += 1
-            yield ids[start:start+sentence_len], ids[start+1:start+sentence_len+1]
+            yield ids[start:start + sentence_len], ids[start + 1:start + sentence_len + 1]
+
     return reader
 
+
 def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
     """
     create reader, each line is a sample.
@@ -87,10 +94,11 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
     :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
     :return: data reader.
     """
+
     def reader():
         UNK = word_id_dict['<UNK>']
         for line in open(file_name):
-            words = line.decode('utf-8','ignore').strip().split()
+            words = line.decode('utf-8', 'ignore').strip().split()
             if len(words) < min_sentence_length or len(words) > max_sentence_length:
                 continue
             ids = [word_id_dict.get(w, UNK) for w in words]
@@ -98,8 +106,10 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
             target = ids[1:]
             target.append(word_id_dict['<EOS>'])
             yield ids[:], target[:]
+
     return reader
 
+
 def _reader_creator_for_NGram(file_name, N, word_id_dict):
     """
     create reader for ngram.
@@ -110,25 +120,31 @@ def _reader_creator_for_NGram(file_name, N, word_id_dict):
     :return: data reader.
     """
     assert N >= 2
+
     def reader():
         words = []
         UNK = word_id_dict['<UNK>']
         for line in open(file_name):
-            words += line.decode('utf-8','ignore').strip().split()
+            words += line.decode('utf-8', 'ignore').strip().split()
         ids = [word_id_dict.get(w, UNK) for w in words]
         words_len = len(words)
-        for i in range(words_len-N-1):
-            yield tuple(ids[i:i+N])
+        for i in range(words_len - N - 1):
+            yield tuple(ids[i:i + N])
+
     return reader
 
+
 def train_data(train_file, min_sentence_length, max_sentence_length, word_id_dict):
     return _read_by_line(train_file, min_sentence_length, max_sentence_length, word_id_dict)
 
+
 def test_data(test_file, min_sentence_length, max_sentence_length, word_id_dict):
     return _read_by_line(test_file, min_sentence_length, max_sentence_length, word_id_dict)
 
+
 def train_data_for_NGram(train_file, N, word_id_dict):
     return _reader_creator_for_NGram(train_file, N, word_id_dict)
 
+
 def test_data_for_NGram(test_file, N, word_id_dict):
     return _reader_creator_for_NGram(test_file, N, word_id_dict)
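For orientation, a minimal sketch of how the reader creators above are meant to be wired together; the file paths, vocabulary cap, and sentence-length bounds below are illustrative assumptions, not values from this commit:

```python
import reader

# Build a capped vocabulary from raw training text; build_vocab keeps the
# most frequent words and reserves the special tokens '<UNK>' and '<EOS>'.
word_id_dict = reader.build_vocab('data/train.txt', vocab_max_size=10000)
reader.save_vocab(word_id_dict, 'data/vocab.txt')

# Line-oriented reader: each sample is (input ids, target ids), where the
# target is the input shifted by one word and terminated with '<EOS>'.
train_reader = reader.train_data('data/train.txt', 3, 100, word_id_dict)
for src_ids, target_ids in train_reader():
    print(src_ids)
    print(target_ids)
    break
```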
@@ -82,7 +82,8 @@ def train():
     # define data reader
     train_reader = paddle.batch(
         paddle.reader.shuffle(
-            reader.train_data_for_NGram(train_file, N, word_id_dict), buf_size=65536),
+            reader.train_data_for_NGram(train_file, N, word_id_dict),
+            buf_size=65536),
         batch_size=32)
     test_reader = paddle.batch(
......
@@ -71,7 +71,8 @@ def train():
     train_reader = paddle.batch(
         paddle.reader.shuffle(
             reader.train_data(
-                train_file, min_sentence_length, max_sentence_length, word_id_dict), buf_size=65536),
+                train_file, min_sentence_length,
+                max_sentence_length, word_id_dict), buf_size=65536),
         batch_size=32)
     test_reader = paddle.batch(
......
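The N-gram path shown in the first train() hunk is analogous; a minimal sketch of what its reader yields, again with an assumed file name and an assumed N of 5:

```python
# Each sample is a tuple of N consecutive word ids; for a 5-gram model the
# first 4 ids are the context and the last is the word to predict.
ngram_reader = reader.train_data_for_NGram('data/train.txt', 5, word_id_dict)
for sample in ngram_reader():
    print(sample)
    break
```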