提交 2f3f35ac 编写于 作者: Z zhaopu

code format

上级 221d2f82
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import collections import collections
import os import os
# -- function -- # -- function --
def save_vocab(word_id_dict, vocab_file_name): def save_vocab(word_id_dict, vocab_file_name):
...@@ -10,12 +11,13 @@ def save_vocab(word_id_dict, vocab_file_name): ...@@ -10,12 +11,13 @@ def save_vocab(word_id_dict, vocab_file_name):
:param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type. :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
:param vocab_file_name: vocab file name. :param vocab_file_name: vocab file name.
""" """
f = open(vocab_file_name,'w') f = open(vocab_file_name, 'w')
for(k, v) in word_id_dict.items(): for (k, v) in word_id_dict.items():
f.write(k.encode('utf-8') + '\t' + str(v) + '\n') f.write(k.encode('utf-8') + '\t' + str(v) + '\n')
print('save vocab to '+vocab_file_name) print('save vocab to ' + vocab_file_name)
f.close() f.close()
def load_vocab(vocab_file_name): def load_vocab(vocab_file_name):
""" """
load vocab from file load vocab from file
...@@ -32,6 +34,7 @@ def load_vocab(vocab_file_name): ...@@ -32,6 +34,7 @@ def load_vocab(vocab_file_name):
dict[kv[0]] = int(kv[1]) dict[kv[0]] = int(kv[1])
return dict return dict
def build_vocab(file_name, vocab_max_size): def build_vocab(file_name, vocab_max_size):
""" """
build vacab. build vacab.
...@@ -41,7 +44,7 @@ def build_vocab(file_name, vocab_max_size): ...@@ -41,7 +44,7 @@ def build_vocab(file_name, vocab_max_size):
""" """
words = [] words = []
for line in open(file_name): for line in open(file_name):
words += line.decode('utf-8','ignore').strip().split() words += line.decode('utf-8', 'ignore').strip().split()
counter = collections.Counter(words) counter = collections.Counter(words)
counter = sorted(counter.items(), key=lambda x: -x[1]) counter = sorted(counter.items(), key=lambda x: -x[1])
...@@ -53,6 +56,7 @@ def build_vocab(file_name, vocab_max_size): ...@@ -53,6 +56,7 @@ def build_vocab(file_name, vocab_max_size):
word_id_dict['<EOS>'] = 1 word_id_dict['<EOS>'] = 1
return word_id_dict return word_id_dict
def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10): def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
""" """
create reader, each sample with fixed length. create reader, each sample with fixed length.
...@@ -62,21 +66,24 @@ def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10): ...@@ -62,21 +66,24 @@ def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
:param sentence_len: each sample's length. :param sentence_len: each sample's length.
:return: data reader. :return: data reader.
""" """
def reader(): def reader():
words = [] words = []
UNK = word_id_dict['<UNK>'] UNK = word_id_dict['<UNK>']
for line in open(file_name): for line in open(file_name):
words += line.decode('utf-8','ignore').strip().split() words += line.decode('utf-8', 'ignore').strip().split()
ids = [word_id_dict.get(w, UNK) for w in words] ids = [word_id_dict.get(w, UNK) for w in words]
words_len = len(words) words_len = len(words)
sentence_num = (words_len-1) // sentence_len sentence_num = (words_len - 1) // sentence_len
count = 0 count = 0
while count < sentence_num: while count < sentence_num:
start = count * sentence_len start = count * sentence_len
count += 1 count += 1
yield ids[start:start+sentence_len], ids[start+1:start+sentence_len+1] yield ids[start:start + sentence_len], ids[start + 1:start + sentence_len + 1]
return reader return reader
def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict): def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
""" """
create reader, each line is a sample. create reader, each line is a sample.
...@@ -87,10 +94,11 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_d ...@@ -87,10 +94,11 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_d
:param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type. :param word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
:return: data reader. :return: data reader.
""" """
def reader(): def reader():
UNK = word_id_dict['<UNK>'] UNK = word_id_dict['<UNK>']
for line in open(file_name): for line in open(file_name):
words = line.decode('utf-8','ignore').strip().split() words = line.decode('utf-8', 'ignore').strip().split()
if len(words) < min_sentence_length or len(words) > max_sentence_length: if len(words) < min_sentence_length or len(words) > max_sentence_length:
continue continue
ids = [word_id_dict.get(w, UNK) for w in words] ids = [word_id_dict.get(w, UNK) for w in words]
...@@ -98,8 +106,10 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_d ...@@ -98,8 +106,10 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_d
target = ids[1:] target = ids[1:]
target.append(word_id_dict['<EOS>']) target.append(word_id_dict['<EOS>'])
yield ids[:], target[:] yield ids[:], target[:]
return reader return reader
def _reader_creator_for_NGram(file_name, N, word_id_dict): def _reader_creator_for_NGram(file_name, N, word_id_dict):
""" """
create reader for ngram. create reader for ngram.
...@@ -110,25 +120,31 @@ def _reader_creator_for_NGram(file_name, N, word_id_dict): ...@@ -110,25 +120,31 @@ def _reader_creator_for_NGram(file_name, N, word_id_dict):
:return: data reader. :return: data reader.
""" """
assert N >= 2 assert N >= 2
def reader(): def reader():
words = [] words = []
UNK = word_id_dict['<UNK>'] UNK = word_id_dict['<UNK>']
for line in open(file_name): for line in open(file_name):
words += line.decode('utf-8','ignore').strip().split() words += line.decode('utf-8', 'ignore').strip().split()
ids = [word_id_dict.get(w, UNK) for w in words] ids = [word_id_dict.get(w, UNK) for w in words]
words_len = len(words) words_len = len(words)
for i in range(words_len-N-1): for i in range(words_len - N - 1):
yield tuple(ids[i:i+N]) yield tuple(ids[i:i + N])
return reader return reader
def train_data(train_file, min_sentence_length, max_sentence_length, word_id_dict):
    """
    Create a training-data reader where each line of *train_file* is one sample.

    Lines whose word count falls outside [min_sentence_length,
    max_sentence_length] are skipped by the underlying reader.

    :param train_file: path of the training corpus, one sentence per line.
    :param min_sentence_length: minimum sentence length (in words) to keep.
    :param max_sentence_length: maximum sentence length (in words) to keep.
    :param word_id_dict: dictionary mapping word (str) to id (int).
    :return: data reader callable.
    """
    reader = _read_by_line(train_file, min_sentence_length,
                           max_sentence_length, word_id_dict)
    return reader
def test_data(test_file, min_sentence_length, max_sentence_length, word_id_dict):
    """
    Create a test-data reader where each line of *test_file* is one sample.

    Lines whose word count falls outside [min_sentence_length,
    max_sentence_length] are skipped by the underlying reader.

    :param test_file: path of the test corpus, one sentence per line.
    :param min_sentence_length: minimum sentence length (in words) to keep.
    :param max_sentence_length: maximum sentence length (in words) to keep.
    :param word_id_dict: dictionary mapping word (str) to id (int).
    :return: data reader callable.
    """
    reader = _read_by_line(test_file, min_sentence_length,
                           max_sentence_length, word_id_dict)
    return reader
def train_data_for_NGram(train_file, N, word_id_dict):
    """
    Create an N-gram training-data reader over *train_file*.

    Each yielded sample is a tuple of N consecutive word ids; the
    underlying reader asserts N >= 2.

    :param train_file: path of the training corpus.
    :param N: size of the N-gram window (must be >= 2).
    :param word_id_dict: dictionary mapping word (str) to id (int).
    :return: data reader callable.
    """
    reader = _reader_creator_for_NGram(train_file, N, word_id_dict)
    return reader
def test_data_for_NGram(test_file, N, word_id_dict):
    """
    Create an N-gram test-data reader over *test_file*.

    Each yielded sample is a tuple of N consecutive word ids; the
    underlying reader asserts N >= 2.

    :param test_file: path of the test corpus.
    :param N: size of the N-gram window (must be >= 2).
    :param word_id_dict: dictionary mapping word (str) to id (int).
    :return: data reader callable.
    """
    reader = _reader_creator_for_NGram(test_file, N, word_id_dict)
    return reader
...@@ -82,7 +82,8 @@ def train(): ...@@ -82,7 +82,8 @@ def train():
# define data reader # define data reader
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
reader.train_data_for_NGram(train_file, N, word_id_dict), buf_size=65536), reader.train_data_for_NGram(train_file, N, word_id_dict),
buf_size=65536),
batch_size=32) batch_size=32)
test_reader = paddle.batch( test_reader = paddle.batch(
......
...@@ -71,7 +71,8 @@ def train(): ...@@ -71,7 +71,8 @@ def train():
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
reader.train_data( reader.train_data(
train_file, min_sentence_length, max_sentence_length, word_id_dict), buf_size=65536), train_file, min_sentence_length,
max_sentence_length, word_id_dict), buf_size=65536),
batch_size=32) batch_size=32)
test_reader = paddle.batch( test_reader = paddle.batch(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册