Commit 4c0d936b authored by zhaopu

code format

Parent 2f3f35ac
@@ -2,9 +2,9 @@
 import collections
 import os
 # -- function --
 def save_vocab(word_id_dict, vocab_file_name):
     """
     save vocab.
@@ -79,12 +79,14 @@ def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
         while count < sentence_num:
             start = count * sentence_len
             count += 1
-            yield ids[start:start + sentence_len], ids[start + 1:start + sentence_len + 1]
+            yield ids[start:start + sentence_len], ids[start + 1:start +
+                                                       sentence_len + 1]
     return reader
-def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
+def _read_by_line(file_name, min_sentence_length, max_sentence_length,
+                  word_id_dict):
     """
     create reader, each line is a sample.
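
As an aside, the fixed-length reader in the hunk above slices one long id sequence into input windows of `sentence_len` ids, with the target window shifted right by one position. A minimal standalone illustration of that slicing, using a toy id list instead of real vocabulary ids:

```python
# Toy id sequence; in the real reader these come from the word dictionary.
ids = list(range(12))
sentence_len = 5
sentence_num = len(ids) // sentence_len

count = 0
while count < sentence_num:
    start = count * sentence_len
    count += 1
    x = ids[start:start + sentence_len]          # input window
    y = ids[start + 1:start + sentence_len + 1]  # target window, shifted by one
    print(x, y)
# [0, 1, 2, 3, 4] [1, 2, 3, 4, 5]
# [5, 6, 7, 8, 9] [6, 7, 8, 9, 10]
```
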
@@ -99,7 +101,8 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
         UNK = word_id_dict['<UNK>']
         for line in open(file_name):
             words = line.decode('utf-8', 'ignore').strip().split()
-            if len(words) < min_sentence_length or len(words) > max_sentence_length:
+            if len(words) < min_sentence_length or len(
+                    words) > max_sentence_length:
                 continue
             ids = [word_id_dict.get(w, UNK) for w in words]
             ids.append(word_id_dict['<EOS>'])
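
For readers unfamiliar with this pattern: each line is length-filtered, mapped to ids with an `<UNK>` fallback, and terminated with `<EOS>`. Below is a small Python 3 sketch of the same idea, assuming a one-token shift between input and target; the exact pair yielded by the original reader lies outside this hunk, and the helper name is invented for illustration.

```python
def make_line_reader(file_name, min_len, max_len, word_id_dict):
    """Yield (input_ids, target_ids) pairs, one per qualifying line."""
    def reader():
        unk = word_id_dict['<UNK>']
        with open(file_name, encoding='utf-8', errors='ignore') as f:
            for line in f:
                words = line.strip().split()
                if len(words) < min_len or len(words) > max_len:
                    continue
                ids = [word_id_dict.get(w, unk) for w in words]
                ids.append(word_id_dict['<EOS>'])
                # assumed pairing: predict each next token from its prefix
                yield ids[:-1], ids[1:]
    return reader
```

Calling `make_line_reader(...)()` returns a plain generator, which matches how `train_data`/`test_data` below are wrapped by `paddle.reader.shuffle` and `paddle.batch`.
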
@@ -134,12 +137,16 @@ def _reader_creator_for_NGram(file_name, N, word_id_dict):
     return reader
-def train_data(train_file, min_sentence_length, max_sentence_length, word_id_dict):
-    return _read_by_line(train_file, min_sentence_length, max_sentence_length, word_id_dict)
+def train_data(train_file, min_sentence_length, max_sentence_length,
+               word_id_dict):
+    return _read_by_line(train_file, min_sentence_length, max_sentence_length,
+                         word_id_dict)
-def test_data(test_file, min_sentence_length, max_sentence_length, word_id_dict):
-    return _read_by_line(test_file, min_sentence_length, max_sentence_length, word_id_dict)
+def test_data(test_file, min_sentence_length, max_sentence_length,
+              word_id_dict):
+    return _read_by_line(test_file, min_sentence_length, max_sentence_length,
+                         word_id_dict)
 def train_data_for_NGram(train_file, N, word_id_dict):
...
@@ -88,7 +88,8 @@ def train():
     test_reader = paddle.batch(
         paddle.reader.shuffle(
-            reader.test_data_for_NGram(train_file, N, word_id_dict), buf_size=65536),
+            reader.test_data_for_NGram(train_file, N, word_id_dict),
+            buf_size=65536),
         batch_size=8)
     # network config
@@ -113,8 +114,7 @@ def train():
         if isinstance(event, paddle.event.EndIteration):
             if event.batch_id % 100 == 0:
                 print("\nPass %d, Batch %d, Cost %f, %s" % (
-                    event.pass_id, event.batch_id, event.cost,
-                    event.metrics))
+                    event.pass_id, event.batch_id, event.cost, event.metrics))
             else:
                 sys.stdout.write('.')
                 sys.stdout.flush()
@@ -123,8 +123,9 @@ def train():
         if isinstance(event, paddle.event.EndPass):
             result = trainer.test(reader=test_reader)
             print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
-            with gzip.open(model_file_name_prefix + str(event.pass_id) + '.tar.gz',
-                           'w') as f:
+            with gzip.open(
+                    model_file_name_prefix + str(event.pass_id) + '.tar.gz',
+                    'w') as f:
                 parameters.to_tar(f)
     # start to train
@@ -163,9 +164,11 @@ if __name__ == '__main__':
     # prepare model
     word_id_dict = reader.load_vocab(vocab_file)  # load word dictionary
-    _, output_layer = lm(len(word_id_dict), emb_dim, hidden_size, num_layer)  # network config
+    _, output_layer = lm(len(word_id_dict), emb_dim, hidden_size,
+                         num_layer)  # network config
     model_file_name = model_file_name_prefix + str(num_passs - 1) + '.tar.gz'
-    parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_file_name))  # load parameters
+    parameters = paddle.parameters.Parameters.from_tar(
+        gzip.open(model_file_name))  # load parameters
     # generate
     input = [[word_id_dict.get(w, word_id_dict['<UNK>']) for w in text.split()]]
     predictions = paddle.infer(
@@ -173,7 +176,9 @@ if __name__ == '__main__':
         parameters=parameters,
         input=input,
         field=['value'])
-    id_word_dict = dict([(v, k) for k, v in word_id_dict.items()])  # dictionary with type {id : word}
+    id_word_dict = dict(
+        [(v, k)
+         for k, v in word_id_dict.items()])  # dictionary with type {id : word}
    predictions[-1][word_id_dict['<UNK>']] = -1  # filter <UNK>
    next_word = id_word_dict[np.argmax(predictions[-1])]
    print(next_word.encode('utf-8'))
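
The last three lines of this hunk invert the vocabulary, suppress `<UNK>` by overwriting its score with -1, and take the argmax as the predicted next word. A self-contained toy version of that selection step, with an invented vocabulary and made-up probabilities:

```python
import numpy as np

word_id_dict = {'<UNK>': 0, 'the': 1, 'cat': 2, 'sat': 3}
id_word_dict = dict([(v, k) for k, v in word_id_dict.items()])  # {id : word}

# Pretend the model assigned the highest score to <UNK>.
predictions = np.array([0.40, 0.10, 0.35, 0.15])
predictions[word_id_dict['<UNK>']] = -1          # filter <UNK>
next_word = id_word_dict[int(np.argmax(predictions))]
print(next_word)  # -> cat
```
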
@@ -24,21 +24,20 @@ def lm(vocab_size, emb_dim, rnn_type, hidden_size, num_layer):
     # input layers
     data = paddle.layer.data(
         name="word", type=paddle.data_type.integer_value_sequence(vocab_size))
-    target = paddle.layer.data("label", paddle.data_type.integer_value_sequence(vocab_size))
+    target = paddle.layer.data(
+        "label", paddle.data_type.integer_value_sequence(vocab_size))
     # embedding layer
     emb = paddle.layer.embedding(input=data, size=emb_dim)
     # rnn layer
     if rnn_type == 'lstm':
-        rnn_cell = paddle.networks.simple_lstm(
-            input=emb, size=hidden_size)
+        rnn_cell = paddle.networks.simple_lstm(input=emb, size=hidden_size)
         for _ in range(num_layer - 1):
             rnn_cell = paddle.networks.simple_lstm(
                 input=rnn_cell, size=hidden_size)
     elif rnn_type == 'gru':
-        rnn_cell = paddle.networks.simple_gru(
-            input=emb, size=hidden_size)
+        rnn_cell = paddle.networks.simple_gru(input=emb, size=hidden_size)
         for _ in range(num_layer - 1):
             rnn_cell = paddle.networks.simple_gru(
                 input=rnn_cell, size=hidden_size)
@@ -70,15 +69,16 @@ def train():
     # define data reader
     train_reader = paddle.batch(
         paddle.reader.shuffle(
-            reader.train_data(
-                train_file, min_sentence_length,
-                max_sentence_length, word_id_dict), buf_size=65536),
+            reader.train_data(train_file, min_sentence_length,
+                              max_sentence_length, word_id_dict),
+            buf_size=65536),
         batch_size=32)
     test_reader = paddle.batch(
         paddle.reader.shuffle(
-            reader.test_data(
-                test_file, min_sentence_length, max_sentence_length, word_id_dict), buf_size=65536),
+            reader.test_data(test_file, min_sentence_length,
+                             max_sentence_length, word_id_dict),
+            buf_size=65536),
         batch_size=8)
     # network config
@@ -103,8 +103,7 @@ def train():
         if isinstance(event, paddle.event.EndIteration):
             if event.batch_id % 100 == 0:
                 print("\nPass %d, Batch %d, Cost %f, %s" % (
-                    event.pass_id, event.batch_id, event.cost,
-                    event.metrics))
+                    event.pass_id, event.batch_id, event.cost, event.metrics))
             else:
                 sys.stdout.write('.')
                 sys.stdout.flush()
@@ -113,8 +112,9 @@ def train():
         if isinstance(event, paddle.event.EndPass):
             result = trainer.test(reader=test_reader)
             print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
-            with gzip.open(model_file_name_prefix + str(event.pass_id) + '.tar.gz',
-                           'w') as f:
+            with gzip.open(
+                    model_file_name_prefix + str(event.pass_id) + '.tar.gz',
+                    'w') as f:
                 parameters.to_tar(f)
     # start to train
@@ -126,7 +126,8 @@ def train():
     print("Training finished.")
-def _generate_with_beamSearch(inferer, word_id_dict, input, num_words, beam_size):
+def _generate_with_beamSearch(inferer, word_id_dict, input, num_words,
+                              beam_size):
     """
     Demo: generate 'num_words' words using "beam search" algorithm.
@@ -146,11 +147,14 @@ def _generate_with_beamSearch(inferer, word_id_dict, input, num_words, beam_size):
     assert beam_size > 0 and num_words > 0
     # load word dictionary
-    id_word_dict = dict([(v, k) for k, v in word_id_dict.items()])  # {id : word}
+    id_word_dict = dict(
+        [(v, k) for k, v in word_id_dict.items()])  # {id : word}
     # tools
     def str2ids(str):
-        return [[[word_id_dict.get(w, word_id_dict['<UNK>']) for w in str.split()]]]
+        return [[[
+            word_id_dict.get(w, word_id_dict['<UNK>']) for w in str.split()
+        ]]]
     def ids2str(ids):
         return [[[id_word_dict.get(id, ' ') for id in ids]]]
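
The two helpers reformatted here convert between a space-separated string and the triple-nested `[[[...]]]` id layout that the rest of the script feeds to the inferer. A toy round trip, using an invented three-word vocabulary:

```python
word_id_dict = {'<UNK>': 0, 'hello': 1, 'world': 2}
id_word_dict = dict([(v, k) for k, v in word_id_dict.items()])  # {id : word}

def str2ids(text):
    return [[[word_id_dict.get(w, word_id_dict['<UNK>']) for w in text.split()]]]

def ids2str(ids):
    return [[[id_word_dict.get(i, ' ') for i in ids]]]

print(str2ids('hello strange world'))  # [[[1, 0, 2]]]
print(ids2str([1, 0, 2]))              # [[['hello', '<UNK>', 'world']]]
```
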
@@ -167,14 +171,18 @@ def _generate_with_beamSearch(inferer, word_id_dict, input, num_words, beam_size):
         # find next beam_size words
         for _ in range(beam_size):
             cur_maxProb_index = np.argmax(predictions[-1])  # next word's id
-            text_new = text + ' ' + id_word_dict[cur_maxProb_index]  # text append nextWord
-            texts_new[text_new] = texts[text] * predictions[-1][cur_maxProb_index]
+            text_new = text + ' ' + id_word_dict[
+                cur_maxProb_index]  # text append nextWord
+            texts_new[text_new] = texts[text] * predictions[-1][
+                cur_maxProb_index]
             predictions[-1][cur_maxProb_index] = -1
         texts.clear()
         if len(texts_new) <= beam_size:
             texts = texts_new
         else:  # cutting
-            texts = dict(sorted(texts_new.items(), key=lambda d: d[1], reverse=True)[:beam_size])
+            texts = dict(
+                sorted(texts_new.items(), key=lambda d: d[1], reverse=True)
+                [:beam_size])
     return texts
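
The `else` branch above is the beam-pruning step: after the candidates have been expanded, only the `beam_size` sentences with the highest accumulated probability are kept. The same sort-and-truncate in isolation, with made-up candidates and scores:

```python
# Invented candidate sentences with accumulated probabilities.
texts_new = {
    'the cat sat': 0.12,
    'the cat ran': 0.30,
    'the dog sat': 0.05,
    'the dog ran': 0.22,
}
beam_size = 2

# Keep the beam_size highest-scoring candidates, as in the reformatted line.
texts = dict(
    sorted(texts_new.items(), key=lambda d: d[1], reverse=True)[:beam_size])
print(texts)  # {'the cat ran': 0.3, 'the dog ran': 0.22}
```
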
@@ -190,21 +198,30 @@ def predict():
     if os.path.isfile(vocab_file):
         word_id_dict = reader.load_vocab(vocab_file)  # load word dictionary
     else:
-        word_id_dict = reader.build_vocab(train_file, vocab_max_size)  # build vocab
+        word_id_dict = reader.build_vocab(train_file,
+                                          vocab_max_size)  # build vocab
         reader.save_vocab(word_id_dict, vocab_file)  # save vocab
     # prepare and cache model
-    _, output = lm(len(word_id_dict), emb_dim, rnn_type, hidden_size, num_layer)  # network config
+    _, output = lm(
+        len(word_id_dict), emb_dim, rnn_type, hidden_size,
+        num_layer)  # network config
     model_file_name = model_file_name_prefix + str(num_passs - 1) + '.tar.gz'
-    parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_file_name))  # load parameters
-    inferer = paddle.inference.Inference(output_layer=output, parameters=parameters)
+    parameters = paddle.parameters.Parameters.from_tar(
+        gzip.open(model_file_name))  # load parameters
+    inferer = paddle.inference.Inference(
+        output_layer=output, parameters=parameters)
     # generate text
     while True:
         input_str = raw_input('input:')
         input_str_uft8 = input_str.decode('utf-8')
         generate_sentences = _generate_with_beamSearch(
-            inferer=inferer, word_id_dict=word_id_dict, input=input_str_uft8, num_words=5, beam_size=5)
+            inferer=inferer,
+            word_id_dict=word_id_dict,
+            input=input_str_uft8,
+            num_words=5,
+            beam_size=5)
         # print result
         for (sentence, prob) in generate_sentences.items():
             print(sentence.encode('utf-8', 'replace'))
...