diff --git a/generate_chinese_poetry/README.md b/generate_chinese_poetry/README.md
index f6a09ed22d42d6de3b0d6fd14d826cb87de822f5..1f6bef0da8145098f70fd02030f6cf4f7284dd3e 100644
--- a/generate_chinese_poetry/README.md
+++ b/generate_chinese_poetry/README.md
@@ -1 +1,111 @@
-[TBD]
+# Generating Classical Chinese Poetry
+
+## Introduction
+This example trains a sentence-to-sentence (sequence to sequence) encoder-decoder neural network model on the Complete Tang Poems: given one line of a poem, the model generates the next line.
+
+Both the encoder and the decoder in the model are stacked bi-directional LSTMs, 3 layers each by default, with an attention mechanism.
+
+Below is the directory layout of this example:
+
+```text
+.
+├── data                 # training data and dictionary
+│   ├── download.sh      # downloads the raw data
+├── README.md            # this document
+├── index.html           # this document in HTML format
+├── preprocess.py        # preprocesses the raw data
+├── generate.py          # generation script
+├── network_conf.py      # model definition
+├── reader.py            # data reading interface
+├── train.py             # training script
+└── utils.py             # utility functions
+```
+
+## Data Processing
+### Source of the raw data
+This example uses the Complete Tang Poems collected in the [chinese-poetry database](https://github.com/chinese-poetry/chinese-poetry) as training data, about 54,000 Tang poems in total.
+
+### Downloading the raw data
+```bash
+cd data && ./download.sh && cd ..
+```
+### Preprocessing
+```bash
+python preprocess.py --datadir data/raw --outfile data/poems.txt --dictfile data/dict.txt
+```
+
+The script above produces the processed training data `poems.txt` and the dictionary `dict.txt`. The dictionary is built at the character level, keeping only characters that occur at least 10 times.
+
+Each line of `poems.txt` holds one poem in three tab-separated columns: title, author, and body. Within the body, the lines of the poem are separated by `.`.
+
+Sample training data:
+```text
+登鸛雀樓	王之渙	白日依山盡.黃河入海流.欲窮千里目.更上一層樓
+觀獵	李白	太守耀清威.乘閑弄晚暉.江沙橫獵騎.山火遶行圍.箭逐雲鴻落.鷹隨月兔飛.不知白日暮.歡賞夜方歸
+晦日重宴	陳嘉言	高門引冠蓋.下客抱支離.綺席珍羞滿.文場翰藻摛.蓂華彫上月.柳色藹春池.日斜歸戚里.連騎勒金羈
+```
+
+During training, each poem line is used as the model input and the following line as the prediction target.
+
+
+## Model Training
+The command-line arguments of the training script [train.py](./train.py) can be listed with `python train.py --help`. The main arguments are:
+- `num_passes`: number of training passes
+- `batch_size`: batch size
+- `use_gpu`: whether to use the GPU
+- `trainer_count`: number of trainers, 1 by default
+- `save_dir_path`: directory for saving models, the `models` directory under the current directory by default
+- `encoder_depth`: depth of the encoder LSTM, 3 by default
+- `decoder_depth`: depth of the decoder LSTM, 3 by default
+- `train_data_path`: path of the training data
+- `word_dict_path`: path of the dictionary
+- `init_model_path`: path of the initial model; not needed when training from scratch
+
+### Running the training
+```bash
+python train.py \
+    --num_passes 50 \
+    --batch_size 256 \
+    --use_gpu True \
+    --trainer_count 1 \
+    --save_dir_path models \
+    --train_data_path data/poems.txt \
+    --word_dict_path data/dict.txt \
+    2>&1 | tee train.log
+```
+After each pass finishes, the model parameters are saved under the `models` directory, and the training log is written to `train.log`.
+
+### Picking the best model parameters
+Find the pass with the lowest cost and use the parameters saved for that pass in the generation step.
+```bash
+python -c 'import utils; utils.find_optiaml_pass("./train.log")'
+```
+
+## Generating Poem Lines
+Use the [generate.py](./generate.py) script to generate the next line for a given input line; run `python generate.py --help` for its command-line arguments.
+The main arguments are:
+- `model_path`: file holding the trained model parameters
+- `word_dict_path`: path of the dictionary
+- `test_data_path`: path of the input data
+- `batch_size`: batch size, 1 by default
+- `beam_size`: beam width in beam search, 5 by default
+- `save_file`: path for saving the output
+- `use_gpu`: whether to use the GPU
+
+### Running the generation
+For example, save the line `孤帆遠影碧空盡` in a file `input.txt` as the input for predicting the next line, then run:
+```bash
+python generate.py \
+    --model_path models/pass_00049.tar.gz \
+    --word_dict_path data/dict.txt \
+    --test_data_path input.txt \
+    --save_file output.txt
+```
+The generated results are saved in `output.txt`. For the sample input above, the generated lines are:
+```text
+-9.6987	萬 壑 清 風 黃 葉 多
+-10.0737	萬 里 遠 山 紅 葉 深
+-10.4233	萬 壑 清 波 紅 一 流
+-10.4802	萬 壑 清 風 黃 葉 深
+-10.9060	萬 壑 清 風 紅 葉 多
+```
diff --git a/generate_chinese_poetry/data/download.sh b/generate_chinese_poetry/data/download.sh
new file mode 100755
index 0000000000000000000000000000000000000000..988c09c0f27c81854d2e090913d2972cb0ffbb51
--- /dev/null
+++ b/generate_chinese_poetry/data/download.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+git clone https://github.com/chinese-poetry/chinese-poetry.git
+
+if [ ! -d raw ]
+then
+    mkdir raw
+fi
+
+mv chinese-poetry/json/poet.tang.* raw/
+rm -rf chinese-poetry
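Note: `reader.py` is referenced in the README above but is not part of this diff, so the actual reader interface is not visible here. As a minimal sketch of the pairing rule the README describes (each poem line as input, the following line as target), with a hypothetical helper name rather than the real `reader.py` API:

```python
# -*- coding: utf-8 -*-
import io


def poem_line_pairs(poems_path):
    """Yield (input_line, target_line) pairs from poems.txt.

    Hypothetical sketch; assumes the tab-separated format written by
    preprocess.py: title <TAB> author <TAB> line1.line2.line3...
    """
    with io.open(poems_path, "r", encoding="utf8") as f:
        for record in f:
            fields = record.strip().split("\t")
            if len(fields) != 3:
                continue
            lines = fields[2].split(".")
            # Every poem line is paired with the line that follows it.
            for src, trg in zip(lines[:-1], lines[1:]):
                yield src, trg


if __name__ == "__main__":
    for src, trg in poem_line_pairs("data/poems.txt"):
        print(u"%s -> %s" % (src, trg))
```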
diff --git a/generate_chinese_poetry/generate.py b/generate_chinese_poetry/generate.py
index b2d909171fa713bb2e01d711b63060cb8850c768..952de15fbcfdb1193c30c6828b73d3a6a825b473 100755
--- a/generate_chinese_poetry/generate.py
+++ b/generate_chinese_poetry/generate.py
@@ -28,7 +28,7 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
         for j in xrange(beam_size):
             end_pos = gen_sen_idx[i * beam_size + j]
             fout.write("%s\n" % ("%.4f\t%s" % (beam_result[0][i][j], " ".join(
-                id_to_text[w] for w in beam_result[1][start_pos:end_pos]))))
+                id_to_text[w] for w in beam_result[1][start_pos:end_pos - 1]))))
             start_pos = end_pos + 2
         fout.write("\n")
-        fout.flush
+        fout.flush()
@@ -80,9 +80,11 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
         encoder_hidden_dim=512,
         decoder_depth=3,
         decoder_hidden_dim=512,
-        is_generating=True,
+        bos_id=0,
+        eos_id=1,
+        max_length=9,
         beam_size=beam_size,
-        max_length=10)
+        is_generating=True)
 
     inferer = paddle.inference.Inference(
         output_layer=beam_gen, parameters=parameters)
diff --git a/generate_chinese_poetry/network_conf.py b/generate_chinese_poetry/network_conf.py
index 5aec3c06b1b2cb9918d3489379df3d2083b39658..1aee1aa249014a70fe09c3fb5ea5da5d81db59a3 100755
--- a/generate_chinese_poetry/network_conf.py
+++ b/generate_chinese_poetry/network_conf.py
@@ -73,8 +73,10 @@ def encoder_decoder_network(word_count,
                             encoder_hidden_dim,
                             decoder_depth,
                             decoder_hidden_dim,
+                            bos_id,
+                            eos_id,
+                            max_length,
                             beam_size=10,
-                            max_length=15,
                             is_generating=False):
     src_emb = paddle.layer.embedding(
         input=paddle.layer.data(
@@ -106,8 +108,8 @@
             name=decoder_group_name,
             step=_attended_decoder_step,
             input=group_inputs + [gen_trg_emb],
-            bos_id=0,
-            eos_id=1,
+            bos_id=bos_id,
+            eos_id=eos_id,
             beam_size=beam_size,
             max_length=max_length)
 
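Note on the `end_pos - 1` change in `generate.py` above: the slicing arithmetic implies that `beam_result[1]` is a flat id array in which each beam candidate is laid out as `[<s>, tokens..., <e>, -1]`, so slicing up to `end_pos` leaked the end mark `<e>` into every generated line. A self-contained illustration with mocked ids (this layout is inferred from the code, not taken from PaddlePaddle documentation):

```python
import numpy as np

# Mocked flat beam-search output: <s>=0, <e>=1, candidates back to back,
# each stored as [<s>, tokens..., <e>, -1].
beam_size = 2
flat_ids = np.array([0, 7, 8, 9, 1, -1,   # candidate 1: tokens 7 8 9
                     0, 7, 5, 1, -1])     # candidate 2: tokens 7 5

gen_sen_idx = np.where(flat_ids == -1)[0]  # indices of the -1 separators

start_pos = 1  # skip the first <s>
for j in range(beam_size):
    end_pos = gen_sen_idx[j]
    # end_pos - 1 drops the trailing <e>; slicing to end_pos would keep it.
    print(flat_ids[start_pos:end_pos - 1].tolist())  # [7, 8, 9] then [7, 5]
    start_pos = end_pos + 2  # jump over the -1 and the next candidate's <s>
```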
"".join(data['paragraphs']) + p = "".join(p.split()) + p = note_pattern1.sub(u"", p) + p = note_pattern2.sub(u"", p) + p = note_pattern3.sub(u"", p) + p = note_pattern4.sub(u"", p) + p = note_pattern5.sub(u"。", p) + p = note_pattern6.sub(u"。", p) + p = note_pattern7.sub(u"", p) + if (p == u"" or u"{" in p or u"}" in p or u"{" in p or + u"}" in p or u"、" in p or u":" in p or u";" in p or + u"!" in p or u"?" in p or u"●" in p or u"□" in p or + u"囗" in p or u")" in p): + continue + paragraphs = re.split(u"。|,", p) + paragraphs = filter(lambda x: len(x), paragraphs) + if len(paragraphs) > 1: + dataset.append((title, author, paragraphs)) + + print("Construct vocabularies...") + vocab = build_vocabulary(dataset, cutoff=10) + with io.open(dictfile, "w", encoding="utf8") as f: + for v in vocab: + f.write(v + "\n") + + print("Write processed data...") + with io.open(outfile, "w", encoding="utf8") as f: + for data in dataset: + title = data[0] + author = data[1] + paragraphs = ".".join(data[2]) + f.write("\t".join((title, author, paragraphs)) + "\n") + + +if __name__ == "__main__": + preprocess() diff --git a/generate_chinese_poetry/train.py b/generate_chinese_poetry/train.py index c6eb737b01da4a89d3d04d87a708fd41c17255c3..18c9d79b316fa6b0c39048212d919194dbe75436 100755 --- a/generate_chinese_poetry/train.py +++ b/generate_chinese_poetry/train.py @@ -44,7 +44,7 @@ def load_initial_model(model_path, parameters): @click.option( "--decoder_depth", default=3, - help="The number of stacked LSTM layers in encoder.") + help="The number of stacked LSTM layers in decoder.") @click.option( "--train_data_path", required=True, help="The path of trainning data.") @click.option( @@ -75,10 +75,9 @@ def train(num_passes, paddle.init(use_gpu=use_gpu, trainer_count=trainer_count) # define optimization method and the trainer instance - optimizer = paddle.optimizer.AdaDelta( - learning_rate=1e-3, - gradient_clipping_threshold=25.0, - regularization=paddle.optimizer.L2Regularization(rate=8e-4), + optimizer = paddle.optimizer.Adam( + learning_rate=1e-4, + regularization=paddle.optimizer.L2Regularization(rate=1e-5), model_average=paddle.optimizer.ModelAverage( average_window=0.5, max_average_window=2500)) @@ -88,7 +87,10 @@ def train(num_passes, encoder_depth=encoder_depth, encoder_hidden_dim=512, decoder_depth=decoder_depth, - decoder_hidden_dim=512) + decoder_hidden_dim=512, + bos_id=0, + eos_id=1, + max_length=9) parameters = paddle.parameters.create(cost) if init_model_path: @@ -113,7 +115,7 @@ def train(num_passes, (event.pass_id, event.batch_id)) save_model(trainer, save_path, parameters) - if not event.batch_id % 5: + if not event.batch_id % 10: logger.info("Pass %d, Batch %d, Cost %f, %s" % ( event.pass_id, event.batch_id, event.cost, event.metrics))