From 16c4afee51d410dd3a1118b11e4a5ca3539608bc Mon Sep 17 00:00:00 2001
From: wangmeng28
Date: Fri, 10 Nov 2017 20:25:37 +0800
Subject: [PATCH] Refine parameters for training chinese poetry generation

---
 generate_chinese_poetry/README.md       | 19 +++++++++++++++++++
 generate_chinese_poetry/data/dict.txt   |  6 ++++++
 generate_chinese_poetry/generate.py     |  8 +++++---
 generate_chinese_poetry/index.html      | 19 +++++++++++++++++++
 generate_chinese_poetry/network_conf.py |  8 +++++---
 generate_chinese_poetry/preprocess.py   |  4 ++++
 generate_chinese_poetry/train.py        | 12 +++++++-----
 7 files changed, 65 insertions(+), 11 deletions(-)

diff --git a/generate_chinese_poetry/README.md b/generate_chinese_poetry/README.md
index 7a914a71..2cacb81b 100644
--- a/generate_chinese_poetry/README.md
+++ b/generate_chinese_poetry/README.md
@@ -78,7 +78,11 @@ Options:
 ### 训练执行
 ```bash
 python train.py \
+<<<<<<< HEAD
+    --num_passes 20 \
+=======
     --num_passes 10 \
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
     --batch_size 256 \
     --use_gpu True \
     --trainer_count 1 \
@@ -126,9 +130,24 @@ Options:
 例如将诗句 `白日依山盡,黃河入海流` 保存在文件 `input.txt` 中作为预测下句诗的输入,执行命令:
 ```bash
 python generate.py \
+<<<<<<< HEAD
+    --model_path models/pass_00014.tar.gz \
+=======
     --model_path models/pass_00100.tar.gz \
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
     --word_dict_path data/dict.txt \
     --test_data_path input.txt \
     --save_file output.txt
 ```
+<<<<<<< HEAD
+生成结果将保存在文件 `output.txt` 中。对于上述示例输入,生成的诗句如下:
+```text
+-21.2048 不 知 身 外 事 , 何 處 是 閑 遊
+-21.3982 不 知 身 外 事 , 何 處 是 何 由
+-21.6564 不 知 身 外 事 , 何 處 是 何 求
+-21.7312 不 知 身 外 事 , 何 事 是 何 求
+-22.1956 不 知 身 外 事 , 何 處 是 人 愁
+```
+=======
 生成结果将保存在文件 `output.txt` 中。
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
diff --git a/generate_chinese_poetry/data/dict.txt b/generate_chinese_poetry/data/dict.txt
index 7eef9785..d8328157 100644
--- a/generate_chinese_poetry/data/dict.txt
+++ b/generate_chinese_poetry/data/dict.txt
@@ -1,6 +1,12 @@
+<<<<<<< HEAD
+
+
+
+=======
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
 ,
 不
 人
diff --git a/generate_chinese_poetry/generate.py b/generate_chinese_poetry/generate.py
index b2d90917..ebecc586 100755
--- a/generate_chinese_poetry/generate.py
+++ b/generate_chinese_poetry/generate.py
@@ -28,7 +28,7 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
         for j in xrange(beam_size):
             end_pos = gen_sen_idx[i * beam_size + j]
             fout.write("%s\n" % ("%.4f\t%s" % (beam_result[0][i][j], " ".join(
-                id_to_text[w] for w in beam_result[1][start_pos:end_pos]))))
+                id_to_text[w] for w in beam_result[1][start_pos:end_pos - 1]))))
             start_pos = end_pos + 2
         fout.write("\n")
     fout.flush
@@ -80,9 +80,11 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
         encoder_hidden_dim=512,
         decoder_depth=3,
         decoder_hidden_dim=512,
-        is_generating=True,
+        bos_id=0,
+        eos_id=1,
+        max_length=17,
         beam_size=beam_size,
-        max_length=10)
+        is_generating=True)
 
     inferer = paddle.inference.Inference(
         output_layer=beam_gen, parameters=parameters)
diff --git a/generate_chinese_poetry/index.html b/generate_chinese_poetry/index.html
index 28c31c4e..859091ef 100644
--- a/generate_chinese_poetry/index.html
+++ b/generate_chinese_poetry/index.html
@@ -120,7 +120,11 @@ Options:
 ### 训练执行
 ```bash
 python train.py \
+<<<<<<< HEAD
+    --num_passes 20 \
+=======
     --num_passes 10 \
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
     --batch_size 256 \
     --use_gpu True \
     --trainer_count 1 \
@@ -168,12 +172,27 @@ Options:
 例如将诗句 `白日依山盡,黃河入海流` 保存在文件 `input.txt` 中作为预测下句诗的输入,执行命令:
 ```bash
 python generate.py \
+<<<<<<< HEAD
+    --model_path models/pass_00014.tar.gz \
+=======
     --model_path models/pass_00100.tar.gz \
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
     --word_dict_path data/dict.txt \
     --test_data_path input.txt \
     --save_file output.txt
 ```
+<<<<<<< HEAD
+生成结果将保存在文件 `output.txt` 中。对于上述示例输入,生成的诗句如下:
+```text
+-21.2048 不 知 身 外 事 , 何 處 是 閑 遊
+-21.3982 不 知 身 外 事 , 何 處 是 何 由
+-21.6564 不 知 身 外 事 , 何 處 是 何 求
+-21.7312 不 知 身 外 事 , 何 事 是 何 求
+-22.1956 不 知 身 外 事 , 何 處 是 人 愁
+```
+=======
 生成结果将保存在文件 `output.txt` 中。
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
diff --git a/generate_chinese_poetry/network_conf.py b/generate_chinese_poetry/network_conf.py
index 5aec3c06..1aee1aa2 100755
--- a/generate_chinese_poetry/network_conf.py
+++ b/generate_chinese_poetry/network_conf.py
@@ -73,8 +73,10 @@ def encoder_decoder_network(word_count,
                             encoder_hidden_dim,
                             decoder_depth,
                             decoder_hidden_dim,
+                            bos_id,
+                            eos_id,
+                            max_length,
                             beam_size=10,
-                            max_length=15,
                             is_generating=False):
     src_emb = paddle.layer.embedding(
         input=paddle.layer.data(
@@ -106,8 +108,8 @@
             name=decoder_group_name,
             step=_attended_decoder_step,
             input=group_inputs + [gen_trg_emb],
-            bos_id=0,
-            eos_id=1,
+            bos_id=bos_id,
+            eos_id=eos_id,
             beam_size=beam_size,
             max_length=max_length)
 
diff --git a/generate_chinese_poetry/preprocess.py b/generate_chinese_poetry/preprocess.py
index d24b5d1b..79e78de5 100755
--- a/generate_chinese_poetry/preprocess.py
+++ b/generate_chinese_poetry/preprocess.py
@@ -16,7 +16,11 @@ def build_vocabulary(dataset, cutoff=0):
     dictionary = filter(lambda x: x[1] >= cutoff, dictionary.items())
     dictionary = sorted(dictionary, key=lambda x: (-x[1], x[0]))
     vocab, _ = list(zip(*dictionary))
+<<<<<<< HEAD
+    return (u"", u"", u"") + vocab
+=======
     return (u"", u"", u"") + vocab
+>>>>>>> 7943732ab34254df801d72b0b5e04f6f320e4127
 
 
 @click.command("preprocess")
diff --git a/generate_chinese_poetry/train.py b/generate_chinese_poetry/train.py
index a9ef5646..70cafbb5 100755
--- a/generate_chinese_poetry/train.py
+++ b/generate_chinese_poetry/train.py
@@ -75,10 +75,9 @@ def train(num_passes,
     paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
 
     # define optimization method and the trainer instance
-    optimizer = paddle.optimizer.AdaDelta(
+    optimizer = paddle.optimizer.Adam(
         learning_rate=1e-3,
-        gradient_clipping_threshold=25.0,
-        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
+        regularization=paddle.optimizer.L2Regularization(rate=1e-5),
         model_average=paddle.optimizer.ModelAverage(
             average_window=0.5, max_average_window=2500))
 
@@ -88,7 +87,10 @@
         encoder_depth=encoder_depth,
         encoder_hidden_dim=512,
         decoder_depth=decoder_depth,
-        decoder_hidden_dim=512)
+        decoder_hidden_dim=512,
+        bos_id=0,
+        eos_id=1,
+        max_length=17)
 
     parameters = paddle.parameters.create(cost)
     if init_model_path:
@@ -113,7 +115,7 @@
                     (event.pass_id, event.batch_id))
                 save_model(trainer, save_path, parameters)
 
-            if not event.batch_id % 5:
+            if not event.batch_id % 10:
                 logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                     event.pass_id, event.batch_id, event.cost, event.metrics))
 
-- 
GitLab
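
A note on the generate.py change above: the slice now stops at `end_pos - 1` instead of `end_pos`. The sketch below is a minimal, self-contained illustration of that indexing; the `id_to_text` mapping and the flat `ids` array are made-up sample data laid out the way `infer_a_batch` appears to expect beam-search results (begin token, words, end token, then a `-1` separator per candidate) — that layout is an assumption stated here, not something spelled out elsewhere in the patch.

```python
# -*- coding: utf-8 -*-
# Illustrative sketch only: hypothetical ids laid out as
# <s> w1 ... wn <e> -1 per beam candidate, matching the arithmetic
# used in infer_a_batch (start_pos starts at 1, end_pos points at -1).
import numpy as np

id_to_text = {0: u"<s>", 1: u"<e>", 2: u"不", 3: u"知", 4: u"身"}
ids = np.array([0, 2, 3, 4, 1, -1,   # candidate 1: <s> 不 知 身 <e> -1
                0, 2, 4, 3, 1, -1])  # candidate 2: <s> 不 身 知 <e> -1
sep_idx = np.where(ids == -1)[0]     # positions of the -1 separators

start_pos = 1                        # skip the leading <s>
for end_pos in sep_idx:
    # end_pos - 1 drops the trailing <e>, which is what the patched
    # slice in generate.py does; the old end_pos bound kept it.
    words = [id_to_text[w] for w in ids[start_pos:end_pos - 1]]
    print(" ".join(words))
    start_pos = end_pos + 2          # jump over -1 and the next <s>
```

With the old bound each printed candidate would end in the end-of-sequence token; shortening the slice by one removes it without touching the score printed alongside each candidate.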
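
The patch also turns `bos_id`, `eos_id`, and `max_length` into explicit arguments of `encoder_decoder_network`, so train.py and generate.py pass the same values instead of relying on defaults hard-coded in network_conf.py. Below is a rough sketch of the updated call, assuming `network_conf.py` from this example is importable; `word_count` and `emb_dim` are placeholders for the example's usual settings and are not taken from this patch.

```python
# Sketch of the new encoder_decoder_network keyword arguments.
# Assumes network_conf.py from this example is on the import path;
# word_count and emb_dim are placeholder values, not part of the patch.
import paddle.v2 as paddle
from network_conf import encoder_decoder_network

paddle.init(use_gpu=False, trainer_count=1)

beam_gen = encoder_decoder_network(
    word_count=5000,      # placeholder: should be len(word_dict)
    emb_dim=512,          # placeholder embedding size
    encoder_depth=3,
    encoder_hidden_dim=512,
    decoder_depth=3,
    decoder_hidden_dim=512,
    bos_id=0,             # id of the begin-of-sequence token
    eos_id=1,             # id of the end-of-sequence token
    max_length=17,        # longest sequence the decoder may emit
    beam_size=10,
    is_generating=True)   # build the beam-search decoder for inference
```

In the patch both train.py and generate.py pass `bos_id=0`, `eos_id=1`, and `max_length=17`, so the decoder is configured identically at training and decoding time.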
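
Finally, the train.py hunk swaps the optimizer: AdaDelta with gradient clipping and an L2 rate of 8e-4 is replaced by Adam with a much smaller L2 rate of 1e-5, while model averaging is kept. The snippet below is a minimal sketch of just that configuration under the paddle.v2 API used throughout this example; `use_gpu` and `trainer_count` are placeholder values for a local run, not taken from the patch.

```python
# Minimal sketch of the optimizer configuration after this patch.
# use_gpu/trainer_count are placeholders for a local run.
import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

optimizer = paddle.optimizer.Adam(
    learning_rate=1e-3,
    # L2 penalty lowered from 8e-4 (old AdaDelta setup) to 1e-5
    regularization=paddle.optimizer.L2Regularization(rate=1e-5),
    # model averaging is unchanged by the patch
    model_average=paddle.optimizer.ModelAverage(
        average_window=0.5, max_average_window=2500))
```

The `gradient_clipping_threshold=25.0` argument from the old setup is simply dropped rather than carried over to the new optimizer.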