提交 5fea977f 编写于 作者: Z zhaopu7 提交者: GitHub

Add files via upload

上级 b0d82a76
......@@ -3,7 +3,8 @@ import sys
import paddle.v2 as paddle
import data_util as reader
import gzip
import generate_text as generator
import os
import numpy as np
def lm(vocab_size, emb_dim, rnn_type, hidden_size, num_layer):
"""
......@@ -60,19 +61,22 @@ def train():
:return: none, but this function will save the training model each epoch.
"""
# load word dictionary
print('load dictionary...')
word_id_dict = reader.build_vocab()
# prepare word dictionary
print('prepare vocab...')
word_id_dict = reader.build_vocab(train_file, vocab_max_size) # build vocab
reader.save_vocab(word_id_dict, vocab_file) # save vocab
# define data reader
train_reader = paddle.batch(
paddle.reader.shuffle(
reader.train_data(), buf_size=65536),
reader.train_data(
train_file, min_sentence_length, max_sentence_length, word_id_dict), buf_size=65536),
batch_size=32)
test_reader = paddle.batch(
paddle.reader.shuffle(
reader.test_data(), buf_size=65536),
reader.test_data(
test_file, min_sentence_length, max_sentence_length, word_id_dict), buf_size=65536),
batch_size=8)
# network config
......@@ -119,10 +123,95 @@ def train():
print("Training finished.")
def _generate_with_beamSearch(inferer, word_id_dict, input, num_words, beam_size):
"""
Demo: generate 'num_words' words using "beam search" algorithm.
:param inferer: paddle's inferer
:type inferer: paddle.inference.Inference
:param word_id_dict: vocab.
:type word_id_dict: dictionary with content of '{word, id}', 'word' is string type , 'id' is int type.
:param input: prefix text.
:type input: string.
:param num_words: the number of the words to generate.
:type num_words: int
:param beam_size: beam with.
:type beam_size: int
:return: text with generated words. dictionary with content of '{text, probability}'
"""
assert beam_size > 0 and num_words > 0
# load word dictionary
id_word_dict = dict([(v, k) for k, v in word_id_dict.items()]) # {id : word}
# tools
def str2ids(str):
return [[[word_id_dict.get(w, word_id_dict['<UNK>']) for w in str.split()]]]
def ids2str(ids):
return [[[id_word_dict.get(id, ' ') for id in ids]]]
# generate
texts = {} # type: {text : prob}
texts[input] = 1
for _ in range(num_words):
texts_new = {}
for (text, prob) in texts.items():
# next word's prob distubution
predictions = inferer.infer(input=str2ids(text))
predictions[-1][word_id_dict['<UNK>']] = -1 # filter <UNK>
# find next beam_size words
for _ in range(beam_size):
cur_maxProb_index = np.argmax(predictions[-1]) # next word's id
text_new = text + ' ' + id_word_dict[cur_maxProb_index] # text append nextWord
texts_new[text_new] = texts[text] * predictions[-1][cur_maxProb_index]
predictions[-1][cur_maxProb_index] = -1
texts.clear()
if len(texts_new) <= beam_size:
texts = texts_new
else: # cutting
texts = dict(sorted(texts_new.items(), key=lambda d: d[1], reverse=True)[:beam_size])
return texts
def predict():
"""
demo: use model to do prediction.
:return: print result to console.
"""
# prepare and cache vocab
if os.path.isfile(vocab_file):
word_id_dict = reader.load_vocab(vocab_file) # load word dictionary
else:
word_id_dict = reader.build_vocab(train_file, vocab_max_size) # build vocab
reader.save_vocab(word_id_dict, vocab_file) # save vocab
# prepare and cache model
_, output = lm(len(word_id_dict), emb_dim, rnn_type, hidden_size, num_layer) # network config
model_file_name = model_file_name_prefix + str(num_passs - 1) + '.tar.gz'
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_file_name)) # load parameters
inferer = paddle.inference.Inference(output_layer=output, parameters=parameters)
# generate text
while True:
input_str = raw_input('input:')
input_str_uft8 = input_str.decode('utf-8')
generate_sentences = _generate_with_beamSearch(
inferer=inferer, word_id_dict=word_id_dict, input=input_str_uft8, num_words=5, beam_size=5)
# print result
for (sentence, prob) in generate_sentences.items():
print(sentence.encode('utf-8', 'replace'))
print('prob: ', prob)
print('-------')
if __name__ == '__main__':
# -- config --
paddle.init(use_gpu=False, trainer_count=1)
# -- config : model --
rnn_type = 'gru' # or 'lstm'
emb_dim = 200
hidden_size = 200
......@@ -130,21 +219,17 @@ if __name__ == '__main__':
num_layer = 2
model_file_name_prefix = 'lm_' + rnn_type + '_params_pass_'
# -- config : data --
train_file = 'data/ptb.train.txt'
test_file = 'data/ptb.test.txt'
vocab_file = 'data/vocab_ptb.txt' # the file to save vocab
vocab_max_size = 3000
min_sentence_length = 3
max_sentence_length = 60
# -- train --
paddle.init(use_gpu=False, trainer_count=1)
train()
# -- predict --
# prepare model
word_id_dict = reader.build_vocab() # load word dictionary
_, output = lm(len(word_id_dict), emb_dim, rnn_type, hidden_size, num_layer) # network config
model_file_name = model_file_name_prefix + str(num_passs - 1) + '.tar.gz'
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_file_name)) # load parameters
# generate
text = 'the end of'
generate_sentences = generator.generate_with_beamSearch(output, parameters, word_id_dict, text, 5, 5)
# print result
for (sentence, prob) in generate_sentences.items():
print(sentence.encode('utf-8', 'replace'))
print('prob: ', prob)
print('-------')
predict()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册