Commit 24f3a9df authored by Cao Ying, committed by GitHub

Merge pull request #228 from lcy-seso/gen_chinese_poetry

Add a deep stacked LSTM seq2seq model for text generation.
# ---------------- generate.py ----------------
import os
import sys
import gzip
import logging

import numpy as np
import click

import reader
import paddle.v2 as paddle
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

logger = logging.getLogger("paddle")
logger.setLevel(logging.WARNING)
def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
    beam_result = inferer.infer(input=test_batch, field=["prob", "id"])
    # -1 acts as a separator: every generated candidate ends with it.
    gen_sen_idx = np.where(beam_result[1] == -1)[0]
    assert len(gen_sen_idx) == len(test_batch) * beam_size, ("%d vs. %d" % (
        len(gen_sen_idx), len(test_batch) * beam_size))

    start_pos, end_pos = 1, 0
    for i, sample in enumerate(test_batch):
        # Skip the start and end marks when printing the source sentence.
        fout.write("%s\n" % (" ".join([id_to_text[w] for w in sample[0][1:-1]])))
        for j in xrange(beam_size):
            end_pos = gen_sen_idx[i * beam_size + j]
            fout.write("%s\n" % ("%.4f\t%s" % (beam_result[0][i][j], " ".join(
                id_to_text[w] for w in beam_result[1][start_pos:end_pos]))))
            start_pos = end_pos + 2
        fout.write("\n")
        fout.flush()
@click.command("generate")
@click.option(
    "--model_path",
    default="",
    help="The path of the trained model for generation.")
@click.option(
    "--word_dict_path", required=True, help="The path of the word dictionary.")
@click.option(
    "--test_data_path",
    required=True,
    help="The path of the input data for generation.")
@click.option(
    "--batch_size",
    default=1,
    help="The number of testing examples in one forward pass in generation.")
@click.option(
    "--beam_size", default=5, help="The beam expansion in beam search.")
@click.option(
    "--save_file",
    required=True,
    help="The file path to save the generated results.")
@click.option(
    "--use_gpu", default=False, help="Whether to use GPU in generation.")
def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
             save_file, use_gpu):
    assert os.path.exists(model_path), "The given model does not exist."
    assert os.path.exists(test_data_path), "The given test data does not exist."

    with gzip.open(model_path, "r") as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    id_to_text = {}
    assert os.path.exists(
        word_dict_path), "The given word dictionary path does not exist."
    with open(word_dict_path, "r") as f:
        for i, line in enumerate(f):
            id_to_text[i] = line.strip().split("\t")[0]

    paddle.init(use_gpu=use_gpu, trainer_count=1)
    beam_gen = encoder_decoder_network(
        word_count=len(id_to_text),
        emb_dim=512,
        encoder_depth=3,
        encoder_hidden_dim=512,
        decoder_depth=3,
        decoder_hidden_dim=512,
        is_generating=True,
        beam_size=beam_size,
        max_length=10)

    inferer = paddle.inference.Inference(
        output_layer=beam_gen, parameters=parameters)
    test_batch = []
    with open(save_file, "w") as fout:
        for idx, item in enumerate(
                reader.gen_reader(test_data_path, word_dict_path)()):
            test_batch.append([item])
            if len(test_batch) == batch_size:
                infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout)
                test_batch = []
        if len(test_batch):
            infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout)
            test_batch = []


if __name__ == "__main__":
    generate()
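A note on the parsing done by infer_a_batch: PaddlePaddle's beam search returns all generated candidates flattened into a single id array, with -1 as a separator. The sketch below uses made-up ids; the exact per-candidate layout of the start/separator markers is inferred from the slicing arithmetic above, not from documentation.

# A minimal, self-contained sketch (hypothetical ids, <s>=0, <e>=1) of how
# infer_a_batch segments the flattened beam-search output: -1 marks the end
# of each candidate, and the slicing skips the start and separator markers.
import numpy as np

beam_ids = np.array([0, 7, 12, 4, 1, -1, 0, 9, 3, 1, -1])  # two candidates
end_positions = np.where(beam_ids == -1)[0]   # -> array([5, 10])

start = 1                                     # skip the first <s>
for end in end_positions:
    print(beam_ids[start:end])                # one candidate's word ids
    start = end + 2                           # skip the -1 and the next <s>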
# ---------------- network_conf.py ----------------
import paddle.v2 as paddle
from paddle.v2.layer import parse_network

__all__ = ["encoder_decoder_network"]
def _bidirect_lstm_encoder(input, hidden_dim, depth):
    lstm_last = []
    for dirt in ["fwd", "bwd"]:
        for i in range(depth):
            # For i > 0, input_proj and lstm on the right-hand side refer to
            # the layers built in the previous iteration of this loop.
            input_proj = paddle.layer.mixed(
                name="__in_proj_%0d_%s__" % (i, dirt),
                size=hidden_dim * 4,
                bias_attr=True,
                input=[
                    paddle.layer.full_matrix_projection(input_proj),
                    paddle.layer.full_matrix_projection(
                        lstm, param_attr=paddle.attr.Param(initial_std=5e-4)),
                ] if i else [paddle.layer.full_matrix_projection(input)])
            lstm = paddle.layer.lstmemory(
                input=input_proj,
                bias_attr=paddle.attr.Param(initial_std=0.),
                param_attr=paddle.attr.Param(initial_std=5e-4),
                # Alternate the scan direction between stacked layers.
                reverse=i % 2 if dirt == "fwd" else not i % 2)
        lstm_last.append(lstm)
    return paddle.layer.concat(input=lstm_last)
def _attended_decoder_step(word_count, enc_out, enc_out_proj,
                           decoder_hidden_dim, depth, trg_emb):
    decoder_memory = paddle.layer.memory(
        name="__decoder_0__", size=decoder_hidden_dim, boot_layer=None)

    context = paddle.networks.simple_attention(
        encoded_sequence=enc_out,
        encoded_proj=enc_out_proj,
        decoder_state=decoder_memory)

    for i in range(depth):
        # For i > 0, input_proj and lstm refer to the previous iteration's
        # layers; the first layer instead takes the attention context and
        # the target-word embedding.
        input_proj = paddle.layer.mixed(
            act=paddle.activation.Linear(),
            size=decoder_hidden_dim * 4,
            bias_attr=False,
            input=[
                paddle.layer.full_matrix_projection(input_proj),
                paddle.layer.full_matrix_projection(lstm)
            ] if i else [
                paddle.layer.full_matrix_projection(context),
                paddle.layer.full_matrix_projection(trg_emb)
            ])
        lstm = paddle.networks.lstmemory_unit(
            input=input_proj,
            input_proj_layer_attr=paddle.attr.ExtraLayerAttribute(
                error_clipping_threshold=25.),
            out_memory=decoder_memory if not i else None,
            name="__decoder_%d__" % (i),
            size=decoder_hidden_dim,
            act=paddle.activation.Tanh(),
            gate_act=paddle.activation.Sigmoid(),
            state_act=paddle.activation.Tanh())

    next_word = paddle.layer.fc(
        size=word_count,
        bias_attr=True,
        act=paddle.activation.Softmax(),
        input=lstm)
    return next_word
def encoder_decoder_network(word_count,
                            emb_dim,
                            encoder_depth,
                            encoder_hidden_dim,
                            decoder_depth,
                            decoder_hidden_dim,
                            beam_size=10,
                            max_length=15,
                            is_generating=False):
    src_emb = paddle.layer.embedding(
        input=paddle.layer.data(
            name="src_word_id",
            type=paddle.data_type.integer_value_sequence(word_count)),
        size=emb_dim,
        param_attr=paddle.attr.ParamAttr(name="__embedding__"))
    enc_out = _bidirect_lstm_encoder(
        input=src_emb, hidden_dim=encoder_hidden_dim, depth=encoder_depth)
    enc_out_proj = paddle.layer.fc(
        act=paddle.activation.Linear(),
        size=encoder_hidden_dim,
        bias_attr=False,
        input=enc_out)

    decoder_group_name = "decoder_group"
    group_inputs = [
        word_count, paddle.layer.StaticInput(input=enc_out),
        paddle.layer.StaticInput(input=enc_out_proj), decoder_hidden_dim,
        decoder_depth
    ]

    if is_generating:
        gen_trg_emb = paddle.layer.GeneratedInput(
            size=word_count,
            embedding_name="__embedding__",
            embedding_size=emb_dim)
        return paddle.layer.beam_search(
            name=decoder_group_name,
            step=_attended_decoder_step,
            input=group_inputs + [gen_trg_emb],
            bos_id=0,
            eos_id=1,
            beam_size=beam_size,
            max_length=max_length)
    else:
        trg_emb = paddle.layer.embedding(
            input=paddle.layer.data(
                name="trg_word_id",
                type=paddle.data_type.integer_value_sequence(word_count)),
            size=emb_dim,
            param_attr=paddle.attr.ParamAttr(name="__embedding__"))
        lbl = paddle.layer.data(
            name="trg_next_word",
            type=paddle.data_type.integer_value_sequence(word_count))
        next_word = paddle.layer.recurrent_group(
            name=decoder_group_name,
            step=_attended_decoder_step,
            input=group_inputs + [trg_emb])
        return paddle.layer.classification_cost(input=next_word, label=lbl)
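For inspecting the topology this function builds, parse_network (imported above but otherwise unused in this file) can serialize the constructed graph. A minimal sketch, assuming a hypothetical vocabulary size of 10000 and the hyper-parameters hard-coded in generate.py:

# A minimal sketch: build the generation network and dump its layer graph.
# word_count=10000 is a hypothetical value; the rest mirrors generate.py.
beam_gen = encoder_decoder_network(
    word_count=10000,
    emb_dim=512,
    encoder_depth=3,
    encoder_hidden_dim=512,
    decoder_depth=3,
    decoder_hidden_dim=512,
    is_generating=True,
    beam_size=5,
    max_length=10)
print(parse_network(beam_gen))  # prints the serialized network configuration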
# ---------------- reader.py ----------------
from utils import load_dict


def train_reader(data_file_path, word_dict_file):
    def reader():
        word_dict = load_dict(word_dict_file)

        unk_id = word_dict[u"<unk>"]
        bos_id = word_dict[u"<s>"]
        eos_id = word_dict[u"<e>"]

        with open(data_file_path, "r") as f:
            for line in f:
                line_split = line.strip().decode(
                    "utf8", errors="ignore").split("\t")
                if len(line_split) < 3: continue

                poetry = line_split[2].split(".")
                poetry_ids = []
                for sen in poetry:
                    if sen:
                        poetry_ids.append([bos_id] + [
                            word_dict.get(word, unk_id)
                            for word in "".join(sen.split())
                        ] + [eos_id])
                l = len(poetry_ids)
                if l < 2: continue

                # Each adjacent pair of lines forms one training example:
                # (source line, shifted decoder input, next-word labels).
                for i in range(l - 1):
                    yield (poetry_ids[i], poetry_ids[i + 1][:-1],
                           poetry_ids[i + 1][1:])

    return reader


def gen_reader(data_file_path, word_dict_file):
    def reader():
        word_dict = load_dict(word_dict_file)

        unk_id = word_dict[u"<unk>"]
        bos_id = word_dict[u"<s>"]
        eos_id = word_dict[u"<e>"]

        with open(data_file_path, "r") as f:
            for line in f:
                input_line = "".join(
                    line.strip().decode("utf8", errors="ignore").split())
                yield [bos_id] + [
                    word_dict.get(word, unk_id) for word in input_line
                ] + [eos_id]

    return reader
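To make the shifting in train_reader concrete, here is a toy example (hypothetical word ids) of the triple it yields for one adjacent pair of poem lines:

# Toy illustration (hypothetical ids, <s>=0, <e>=1) of one yielded triple.
prev_line = [0, 15, 23, 8, 1]   # source: the i-th line of the poem
next_line = [0, 42, 7, 31, 1]   # target: the (i+1)-th line

source, decoder_input, labels = prev_line, next_line[:-1], next_line[1:]
print(decoder_input)  # [0, 42, 7, 31] -- decoder sees <s> plus the words
print(labels)         # [42, 7, 31, 1] -- labels shifted by one, ending in <e>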
# ---------------- train.py ----------------
import os
import gzip
import logging

import click
import paddle.v2 as paddle

import reader
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def save_model(save_path, parameters):
    with gzip.open(save_path, "w") as f:
        parameters.to_tar(f)


def load_initial_model(model_path, parameters):
    with gzip.open(model_path, "rb") as f:
        parameters.init_from_tar(f)
@click.command("train")
@click.option(
    "--num_passes", default=10, help="The number of passes for the training task.")
@click.option(
    "--batch_size",
    default=16,
    help="The number of training examples in one forward/backward pass.")
@click.option(
    "--use_gpu", default=False, help="Whether to use the GPU to train the model.")
@click.option(
    "--trainer_count", default=1, help="The thread number used in training.")
@click.option(
    "--save_dir_path",
    default="models",
    help="The path to save the trained models.")
@click.option(
    "--encoder_depth",
    default=3,
    help="The number of stacked LSTM layers in the encoder.")
@click.option(
    "--decoder_depth",
    default=3,
    help="The number of stacked LSTM layers in the decoder.")
@click.option(
    "--train_data_path", required=True, help="The path of the training data.")
@click.option(
    "--word_dict_path", required=True, help="The path of the word dictionary.")
@click.option(
    "--init_model_path",
    default="",
    help=("The path of a trained model used to initialize all "
          "the model parameters."))
def train(num_passes,
          batch_size,
          use_gpu,
          trainer_count,
          save_dir_path,
          encoder_depth,
          decoder_depth,
          train_data_path,
          word_dict_path,
          init_model_path=""):
    if not os.path.exists(save_dir_path):
        os.mkdir(save_dir_path)
    assert os.path.exists(
        word_dict_path), "The given word dictionary does not exist."
    assert os.path.exists(
        train_data_path), "The given training data does not exist."

    # initialize PaddlePaddle
    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)

    # define the optimization method and the trainer instance
    optimizer = paddle.optimizer.AdaDelta(
        learning_rate=1e-3,
        gradient_clipping_threshold=25.0,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
        model_average=paddle.optimizer.ModelAverage(
            average_window=0.5, max_average_window=2500))

    cost = encoder_decoder_network(
        word_count=len(open(word_dict_path, "r").readlines()),
        emb_dim=512,
        encoder_depth=encoder_depth,
        encoder_hidden_dim=512,
        decoder_depth=decoder_depth,
        decoder_hidden_dim=512)

    parameters = paddle.parameters.create(cost)
    if init_model_path:
        load_initial_model(init_model_path, parameters)

    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # define the data reader
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            reader.train_reader(train_data_path, word_dict_path),
            buf_size=1024000),
        batch_size=batch_size)

    # define the event_handler callback
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            # Save a checkpoint every 1000 batches, skipping batch 0.
            if (not event.batch_id % 1000) and event.batch_id:
                save_path = os.path.join(save_dir_path,
                                         "pass_%05d_batch_%05d.tar.gz" %
                                         (event.pass_id, event.batch_id))
                save_model(save_path, parameters)

            if not event.batch_id % 5:
                logger.info("Pass %d, Batch %d, Cost %f, %s" %
                            (event.pass_id, event.batch_id, event.cost,
                             event.metrics))

        if isinstance(event, paddle.event.EndPass):
            save_path = os.path.join(save_dir_path,
                                     "pass_%05d.tar.gz" % event.pass_id)
            save_model(save_path, parameters)

    # start training
    trainer.train(
        reader=train_reader, event_handler=event_handler, num_passes=num_passes)


if __name__ == "__main__":
    train()
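The checkpoint condition in event_handler reads a little tersely; the following plain-Python check (no Paddle required) confirms it fires every 1000 batches while skipping batch 0:

# Truth table for the checkpoint condition used in event_handler above.
for batch_id in (0, 1, 999, 1000, 2000, 2500, 3000):
    saves = (not batch_id % 1000) and batch_id
    print(batch_id, bool(saves))
# Only 1000, 2000 and 3000 print True; 0, 1, 999 and 2500 print False.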
# ---------------- utils.py ----------------
import os
import sys
import re
from collections import defaultdict


def load_dict(word_dict_file):
    word_dict = {}
    with open(word_dict_file, "r") as fin:
        for i, line in enumerate(fin):
            key = line.strip().decode("utf8", errors="ignore").split("\t")[0]
            word_dict[key] = i
    return word_dict


def find_optimal_pass(log_file):
    cost_info = defaultdict(list)
    cost_pat = re.compile(r"Cost\s\d+\.\d+")
    pass_pat = re.compile(r"Pass\s\d+")
    with open(log_file, "r") as flog:
        for line in flog:
            if "Cost" not in line: continue
            pass_id = pass_pat.findall(line.strip())[0]
            cost = float(cost_pat.findall(line.strip())[0].replace("Cost ", ""))
            cost_info[pass_id].append(cost)
    # The optimal pass is the one with the lowest average cost.
    print("optimal pass : %s" % sorted(
        cost_info.iteritems(),
        key=lambda x: sum(x[1]) / (len(x[1])),
        reverse=False)[0][0])
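As a quick sanity check of the two regular expressions in find_optimal_pass, here is a self-contained run against a made-up log line matching the "Pass %d, Batch %d, Cost %f, %s" format that event_handler in train.py writes (the metrics dict content is hypothetical):

# Self-contained check (made-up log line) of the regexes in find_optimal_pass.
import re

line = "Pass 3, Batch 2000, Cost 5.143201, {'classification_error': 0.71}"
pass_pat = re.compile(r"Pass\s\d+")
cost_pat = re.compile(r"Cost\s\d+\.\d+")

print(pass_pat.findall(line)[0])                              # Pass 3
print(float(cost_pat.findall(line)[0].replace("Cost ", "")))  # 5.143201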
[Truncated page previews of the commit's remaining files appeared here, among them a preprocessing script for building the word dictionary; their bodies were elided ("......") by the page renderer and are not recoverable.]