Commit 4c0d936b authored by Z zhaopu

code format

Parent 2f3f35ac
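
The change is purely mechanical re-formatting: statements that ran past the 80-column limit are re-wrapped across continuation lines, as the hunks below show. What follows is a minimal sketch of reproducing that kind of re-wrap on one of the affected lines; the use of yapf with a pep8-based style is an assumption, since the commit does not name the formatter that was used.

    # Sketch only: re-wrap one over-long statement the way the hunks below do.
    # Assumption: yapf with a pep8-based style and an 80-column limit; the
    # commit does not record which tool produced the re-formatting.
    from yapf.yapflib.yapf_api import FormatCode

    src = ("parameters = paddle.parameters.Parameters.from_tar("
           "gzip.open(model_file_name))  # load parameters\n")
    result = FormatCode(src, style_config='pep8')
    # FormatCode returns the re-wrapped source; older yapf releases return a
    # (source, changed) tuple, so handle both forms defensively.
    formatted = result[0] if isinstance(result, tuple) else result
    print(formatted)
    # Roughly the shape seen in the hunk below:
    # parameters = paddle.parameters.Parameters.from_tar(
    #     gzip.open(model_file_name))  # load parameters
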
@@ -2,9 +2,9 @@
import collections
import os
# -- function --
def save_vocab(word_id_dict, vocab_file_name):
"""
save vocab.
@@ -79,12 +79,14 @@ def _read_by_fixed_length(file_name, word_id_dict, sentence_len=10):
while count < sentence_num:
start = count * sentence_len
count += 1
yield ids[start:start + sentence_len], ids[start + 1:start + sentence_len + 1]
yield ids[start:start + sentence_len], ids[start + 1:start +
sentence_len + 1]
return reader
def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_dict):
def _read_by_line(file_name, min_sentence_length, max_sentence_length,
word_id_dict):
"""
create reader, each line is a sample.
@@ -99,7 +101,8 @@ def _read_by_line(file_name, min_sentence_length, max_sentence_length, word_id_d
UNK = word_id_dict['<UNK>']
for line in open(file_name):
words = line.decode('utf-8', 'ignore').strip().split()
if len(words) < min_sentence_length or len(words) > max_sentence_length:
if len(words) < min_sentence_length or len(
words) > max_sentence_length:
continue
ids = [word_id_dict.get(w, UNK) for w in words]
ids.append(word_id_dict['<EOS>'])
@@ -134,12 +137,16 @@ def _reader_creator_for_NGram(file_name, N, word_id_dict):
return reader
def train_data(train_file, min_sentence_length, max_sentence_length, word_id_dict):
return _read_by_line(train_file, min_sentence_length, max_sentence_length, word_id_dict)
def train_data(train_file, min_sentence_length, max_sentence_length,
word_id_dict):
return _read_by_line(train_file, min_sentence_length, max_sentence_length,
word_id_dict)
def test_data(test_file, min_sentence_length, max_sentence_length, word_id_dict):
return _read_by_line(test_file, min_sentence_length, max_sentence_length, word_id_dict)
def test_data(test_file, min_sentence_length, max_sentence_length,
word_id_dict):
return _read_by_line(test_file, min_sentence_length, max_sentence_length,
word_id_dict)
def train_data_for_NGram(train_file, N, word_id_dict):
@@ -88,7 +88,8 @@ def train():
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.test_data_for_NGram(train_file, N, word_id_dict), buf_size=65536),
reader.test_data_for_NGram(train_file, N, word_id_dict),
buf_size=65536),
batch_size=8)
# network config
@@ -113,8 +114,7 @@ def train():
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print("\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost,
event.metrics))
event.pass_id, event.batch_id, event.cost, event.metrics))
else:
sys.stdout.write('.')
sys.stdout.flush()
@@ -123,8 +123,9 @@ def train():
if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=test_reader)
print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
with gzip.open(model_file_name_prefix + str(event.pass_id) + '.tar.gz',
'w') as f:
with gzip.open(
model_file_name_prefix + str(event.pass_id) + '.tar.gz',
'w') as f:
parameters.to_tar(f)
# start to train
@@ -163,9 +164,11 @@ if __name__ == '__main__':
# prepare model
word_id_dict = reader.load_vocab(vocab_file) # load word dictionary
_, output_layer = lm(len(word_id_dict), emb_dim, hidden_size, num_layer) # network config
_, output_layer = lm(len(word_id_dict), emb_dim, hidden_size,
num_layer) # network config
model_file_name = model_file_name_prefix + str(num_passs - 1) + '.tar.gz'
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_file_name)) # load parameters
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(model_file_name)) # load parameters
# generate
input = [[word_id_dict.get(w, word_id_dict['<UNK>']) for w in text.split()]]
predictions = paddle.infer(
@@ -173,7 +176,9 @@ if __name__ == '__main__':
parameters=parameters,
input=input,
field=['value'])
id_word_dict = dict([(v, k) for k, v in word_id_dict.items()]) # dictionary with type {id : word}
id_word_dict = dict(
[(v, k)
for k, v in word_id_dict.items()]) # dictionary with type {id : word}
predictions[-1][word_id_dict['<UNK>']] = -1 # filter <UNK>
next_word = id_word_dict[np.argmax(predictions[-1])]
print(next_word.encode('utf-8'))
@@ -24,21 +24,20 @@ def lm(vocab_size, emb_dim, rnn_type, hidden_size, num_layer):
# input layers
data = paddle.layer.data(
name="word", type=paddle.data_type.integer_value_sequence(vocab_size))
target = paddle.layer.data("label", paddle.data_type.integer_value_sequence(vocab_size))
target = paddle.layer.data(
"label", paddle.data_type.integer_value_sequence(vocab_size))
# embedding layer
emb = paddle.layer.embedding(input=data, size=emb_dim)
# rnn layer
if rnn_type == 'lstm':
rnn_cell = paddle.networks.simple_lstm(
input=emb, size=hidden_size)
rnn_cell = paddle.networks.simple_lstm(input=emb, size=hidden_size)
for _ in range(num_layer - 1):
rnn_cell = paddle.networks.simple_lstm(
input=rnn_cell, size=hidden_size)
elif rnn_type == 'gru':
rnn_cell = paddle.networks.simple_gru(
input=emb, size=hidden_size)
rnn_cell = paddle.networks.simple_gru(input=emb, size=hidden_size)
for _ in range(num_layer - 1):
rnn_cell = paddle.networks.simple_gru(
input=rnn_cell, size=hidden_size)
@@ -70,15 +69,16 @@ def train():
# define data reader
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.train_data(
train_file, min_sentence_length,
max_sentence_length, word_id_dict), buf_size=65536),
reader.train_data(train_file, min_sentence_length,
max_sentence_length, word_id_dict),
buf_size=65536),
batch_size=32)
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.test_data(
test_file, min_sentence_length, max_sentence_length, word_id_dict), buf_size=65536),
reader.test_data(test_file, min_sentence_length,
max_sentence_length, word_id_dict),
buf_size=65536),
batch_size=8)
# network config
@@ -103,8 +103,7 @@ def train():
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print("\nPass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost,
event.metrics))
event.pass_id, event.batch_id, event.cost, event.metrics))
else:
sys.stdout.write('.')
sys.stdout.flush()
@@ -113,8 +112,9 @@ def train():
if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=test_reader)
print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))
with gzip.open(model_file_name_prefix + str(event.pass_id) + '.tar.gz',
'w') as f:
with gzip.open(
model_file_name_prefix + str(event.pass_id) + '.tar.gz',
'w') as f:
parameters.to_tar(f)
# start to train
@@ -126,7 +126,8 @@ def train():
print("Training finished.")
def _generate_with_beamSearch(inferer, word_id_dict, input, num_words, beam_size):
def _generate_with_beamSearch(inferer, word_id_dict, input, num_words,
beam_size):
"""
Demo: generate 'num_words' words using "beam search" algorithm.
@@ -146,11 +147,14 @@ def _generate_with_beamSearch(inferer, word_id_dict, input, num_words, beam_size
assert beam_size > 0 and num_words > 0
# load word dictionary
id_word_dict = dict([(v, k) for k, v in word_id_dict.items()]) # {id : word}
id_word_dict = dict(
[(v, k) for k, v in word_id_dict.items()]) # {id : word}
# tools
def str2ids(str):
return [[[word_id_dict.get(w, word_id_dict['<UNK>']) for w in str.split()]]]
return [[[
word_id_dict.get(w, word_id_dict['<UNK>']) for w in str.split()
]]]
def ids2str(ids):
return [[[id_word_dict.get(id, ' ') for id in ids]]]
@@ -167,14 +171,18 @@ def _generate_with_beamSearch(inferer, word_id_dict, input, num_words, beam_size
# find next beam_size words
for _ in range(beam_size):
cur_maxProb_index = np.argmax(predictions[-1]) # next word's id
text_new = text + ' ' + id_word_dict[cur_maxProb_index] # text append nextWord
texts_new[text_new] = texts[text] * predictions[-1][cur_maxProb_index]
text_new = text + ' ' + id_word_dict[
cur_maxProb_index] # text append nextWord
texts_new[text_new] = texts[text] * predictions[-1][
cur_maxProb_index]
predictions[-1][cur_maxProb_index] = -1
texts.clear()
if len(texts_new) <= beam_size:
texts = texts_new
else: # cutting
texts = dict(sorted(texts_new.items(), key=lambda d: d[1], reverse=True)[:beam_size])
texts = dict(
sorted(texts_new.items(), key=lambda d: d[1], reverse=True)
[:beam_size])
return texts
@@ -190,21 +198,30 @@ def predict():
if os.path.isfile(vocab_file):
word_id_dict = reader.load_vocab(vocab_file) # load word dictionary
else:
word_id_dict = reader.build_vocab(train_file, vocab_max_size) # build vocab
word_id_dict = reader.build_vocab(train_file,
vocab_max_size) # build vocab
reader.save_vocab(word_id_dict, vocab_file) # save vocab
# prepare and cache model
_, output = lm(len(word_id_dict), emb_dim, rnn_type, hidden_size, num_layer) # network config
_, output = lm(
len(word_id_dict), emb_dim, rnn_type, hidden_size,
num_layer) # network config
model_file_name = model_file_name_prefix + str(num_passs - 1) + '.tar.gz'
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_file_name)) # load parameters
inferer = paddle.inference.Inference(output_layer=output, parameters=parameters)
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(model_file_name)) # load parameters
inferer = paddle.inference.Inference(
output_layer=output, parameters=parameters)
# generate text
while True:
input_str = raw_input('input:')
input_str_uft8 = input_str.decode('utf-8')
generate_sentences = _generate_with_beamSearch(
inferer=inferer, word_id_dict=word_id_dict, input=input_str_uft8, num_words=5, beam_size=5)
inferer=inferer,
word_id_dict=word_id_dict,
input=input_str_uft8,
num_words=5,
beam_size=5)
# print result
for (sentence, prob) in generate_sentences.items():
print(sentence.encode('utf-8', 'replace'))