diff --git a/generate_chinese_poetry/generate.py b/generate_chinese_poetry/generate.py
index b412f7f75b167f2ea87d99af24c5d77ea8db261b..b2d909171fa713bb2e01d711b63060cb8850c768 100755
--- a/generate_chinese_poetry/generate.py
+++ b/generate_chinese_poetry/generate.py
@@ -1,11 +1,9 @@
-#!/usr/bin/env python
-#coding=utf-8
-
 import os
 import sys
 import gzip
 import logging
 import numpy as np
+import click
 
 import reader
 import paddle.v2 as paddle
@@ -36,19 +34,45 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
     fout.flush
 
 
+@click.command("generate")
+@click.option(
+    "--model_path",
+    default="",
+    help="The path of the trained model for generation.")
+@click.option(
+    "--word_dict_path", required=True, help="The path of word dictionary.")
+@click.option(
+    "--test_data_path",
+    required=True,
+    help="The path of input data for generation.")
+@click.option(
+    "--batch_size",
+    default=1,
+    help="The number of testing examples in one forward pass in generation.")
+@click.option(
+    "--beam_size", default=5, help="The beam size used in beam search.")
+@click.option(
+    "--save_file",
+    required=True,
+    help="The file path to save the generated results.")
+@click.option(
+    "--use_gpu", default=False, help="Whether to use GPU in generation.")
 def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
              save_file, use_gpu):
-    assert os.path.exists(model_path), "trained model does not exist."
+    assert os.path.exists(model_path), "The given model does not exist."
+    assert os.path.exists(test_data_path), "The given test data does not exist."
 
-    paddle.init(use_gpu=use_gpu, trainer_count=1)
     with gzip.open(model_path, "r") as f:
         parameters = paddle.parameters.Parameters.from_tar(f)
 
     id_to_text = {}
+    assert os.path.exists(
+        word_dict_path), "The given word dictionary path does not exist."
     with open(word_dict_path, "r") as f:
         for i, line in enumerate(f):
             id_to_text[i] = line.strip().split("\t")[0]
 
+    paddle.init(use_gpu=use_gpu, trainer_count=1)
     beam_gen = encoder_decoder_network(
         word_count=len(id_to_text),
         emb_dim=512,
@@ -78,11 +102,4 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
 
 
 if __name__ == "__main__":
-    generate(
-        model_path="models/pass_00025.tar.gz",
-        word_dict_path="data/word_dict.txt",
-        test_data_path="data/input.txt",
-        save_file="gen_result.txt",
-        batch_size=4,
-        beam_size=5,
-        use_gpu=True)
+    generate()
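Note: with the click entry point above, generation is driven from the command
line instead of the hardcoded arguments removed from __main__. A minimal
invocation, reusing the illustrative paths from the deleted call, might look
like:

    python generate.py \
        --model_path models/pass_00025.tar.gz \
        --word_dict_path data/word_dict.txt \
        --test_data_path data/input.txt \
        --save_file gen_result.txt \
        --batch_size 4 \
        --beam_size 5 \
        --use_gpu false

Because --use_gpu is declared with a boolean default rather than is_flag=True,
click infers a BOOL-typed value option: it expects an explicit true/false
argument instead of acting as a bare flag.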
diff --git a/generate_chinese_poetry/network_conf.py b/generate_chinese_poetry/network_conf.py
index 3ee6f64a5ac75bb5f9e5d5793d93ff77298780b6..5aec3c06b1b2cb9918d3489379df3d2083b39658 100755
--- a/generate_chinese_poetry/network_conf.py
+++ b/generate_chinese_poetry/network_conf.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-#coding=utf-8
-
 import paddle.v2 as paddle
 from paddle.v2.layer import parse_network
 
diff --git a/generate_chinese_poetry/reader.py b/generate_chinese_poetry/reader.py
index b1f4f54d8d8ae1a0ff7f1d2b22173418c112852c..4ecdb041164e917f72a7a21965bc7368f53bc34c 100755
--- a/generate_chinese_poetry/reader.py
+++ b/generate_chinese_poetry/reader.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-
 from utils import load_dict
 
 
diff --git a/generate_chinese_poetry/train.py b/generate_chinese_poetry/train.py
index 28a9ec28da94e26ae4c5e95d2223fcf535c75491..a641efa59a82d9a8e2b3e37c393e347063568685 100755
--- a/generate_chinese_poetry/train.py
+++ b/generate_chinese_poetry/train.py
@@ -1,9 +1,7 @@
-#!/usr/bin/env python
-#coding=utf-8
-
-import gzip
 import os
+import gzip
 import logging
+import click
 
 import paddle.v2 as paddle
 import reader
@@ -24,24 +22,59 @@ def load_initial_model(model_path, parameters):
         parameters.init_from_tar(f)
 
 
-def main(num_passes,
-         batch_size,
-         use_gpu,
-         trainer_count,
-         save_dir_path,
-         encoder_depth,
-         decoder_depth,
-         word_dict_path,
-         train_data_path,
-         init_model_path=""):
+@click.command("train")
+@click.option(
+    "--num_passes", default=10, help="Number of passes for the training task.")
+@click.option(
+    "--batch_size",
+    default=16,
+    help="The number of training examples in one forward/backward pass.")
+@click.option(
+    "--use_gpu", default=False, help="Whether to use GPU to train the model.")
+@click.option(
+    "--trainer_count", default=1, help="The number of threads used in training.")
+@click.option(
+    "--save_dir_path",
+    default="models",
+    help="The path to save the trained models.")
+@click.option(
+    "--encoder_depth",
+    default=3,
+    help="The number of stacked LSTM layers in encoder.")
+@click.option(
+    "--decoder_depth",
+    default=3,
+    help="The number of stacked LSTM layers in decoder.")
+@click.option(
+    "--train_data_path", required=True, help="The path of training data.")
+@click.option(
+    "--word_dict_path", required=True, help="The path of word dictionary.")
+@click.option(
+    "--init_model_path",
+    default="",
+    help=("The path of a trained model used to initialize all "
+          "the model parameters."))
+def train(num_passes,
+          batch_size,
+          use_gpu,
+          trainer_count,
+          save_dir_path,
+          encoder_depth,
+          decoder_depth,
+          train_data_path,
+          word_dict_path,
+          init_model_path=""):
     if not os.path.exists(save_dir_path):
         os.mkdir(save_dir_path)
+    assert os.path.exists(
+        word_dict_path), "The given word dictionary does not exist."
+    assert os.path.exists(
+        train_data_path), "The given training data does not exist."
 
     # initialize PaddlePaddle
-    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count, parallel_nn=1)
+    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
 
     # define optimization method and the trainer instance
-    # optimizer = paddle.optimizer.Adam(
     optimizer = paddle.optimizer.AdaDelta(
         learning_rate=1e-3,
         gradient_clipping_threshold=25.0,
@@ -74,7 +107,7 @@ def main(num_passes,
     # define the event_handler callback
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
-            if (not event.batch_id % 2000) and event.batch_id:
+            if (not event.batch_id % 1000) and event.batch_id:
                 save_path = os.path.join(save_dir_path,
                                          "pass_%05d_batch_%05d.tar.gz" %
                                          (event.pass_id, event.batch_id))
@@ -94,15 +127,5 @@ def main(num_passes,
         reader=train_reader,
         event_handler=event_handler,
         num_passes=num_passes)
-if __name__ == '__main__':
-    main(
-        num_passes=500,
-        batch_size=4 * 500,
-        use_gpu=True,
-        trainer_count=4,
-        encoder_depth=3,
-        decoder_depth=3,
-        save_dir_path="models",
-        word_dict_path="data/word_dict.txt",
-        train_data_path="data/song.poet.txt",
-        init_model_path="")
+if __name__ == "__main__":
+    train()
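Similarly, training is now launched through the click command. A sketch of an
invocation that mirrors the deleted hardcoded main() call (batch_size was
4 * 500, i.e. 2000; --init_model_path is omitted because it defaults to the
empty string):

    python train.py \
        --num_passes 500 \
        --batch_size 2000 \
        --use_gpu true \
        --trainer_count 4 \
        --encoder_depth 3 \
        --decoder_depth 3 \
        --save_dir_path models \
        --word_dict_path data/word_dict.txt \
        --train_data_path data/song.poet.txt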
diff --git a/generate_chinese_poetry/utils.py b/generate_chinese_poetry/utils.py
index 628fa841458663c816a5054cf5c7d04ac2a3c9c7..f8a20bf4203bc091c8002953b3b3d7df12be25ef 100755
--- a/generate_chinese_poetry/utils.py
+++ b/generate_chinese_poetry/utils.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-#coding=utf-8
-
 import os
 import sys
 import re
@@ -30,7 +27,3 @@ def find_optiaml_pass(log_file):
         cost_info.iteritems(),
         key=lambda x: sum(x[1]) / (len(x[1])),
         reverse=False)[0][0])
-
-
-if __name__ == '__main__':
-    find_optiaml_pass('trained_models/models_first_round/train.log')
diff --git a/generate_sequence_by_rnn_lm/beam_search.py b/generate_sequence_by_rnn_lm/beam_search.py
index d71223cd58899427b551eb639afd15d0ab0d2d9a..f6d1d3646cbd6b42a28b8531d2f11fe2856a5c4d 100644
--- a/generate_sequence_by_rnn_lm/beam_search.py
+++ b/generate_sequence_by_rnn_lm/beam_search.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding=utf-8
-
 import os
 import math
 import numpy as np
diff --git a/generate_sequence_by_rnn_lm/generate.py b/generate_sequence_by_rnn_lm/generate.py
index 71c66e86db1171ca3e618f8a855b7b6fa3f0198d..07729e7404bce598790b78a46b073e01e4e46484 100644
--- a/generate_sequence_by_rnn_lm/generate.py
+++ b/generate_sequence_by_rnn_lm/generate.py
@@ -1,5 +1,3 @@
-# coding=utf-8
-
 import os
 import gzip
 import numpy as np
diff --git a/generate_sequence_by_rnn_lm/network_conf.py b/generate_sequence_by_rnn_lm/network_conf.py
index 7306337bf7515ddfd4df137c3ee81f8aa4fa7b90..f1aceb0b7d70c6a2aec601cf935d2f34500d20fc 100644
--- a/generate_sequence_by_rnn_lm/network_conf.py
+++ b/generate_sequence_by_rnn_lm/network_conf.py
@@ -1,5 +1,3 @@
-# coding=utf-8
-
 import paddle.v2 as paddle
 
 
diff --git a/generate_sequence_by_rnn_lm/reader.py b/generate_sequence_by_rnn_lm/reader.py
index d0b95fb861c4748b722d3e2518fbbd6554a9a3b9..1c6bc7a8a83dbd028b11351cb55ace5b529b0268 100644
--- a/generate_sequence_by_rnn_lm/reader.py
+++ b/generate_sequence_by_rnn_lm/reader.py
@@ -1,5 +1,3 @@
-# coding=utf-8
-
 import collections
 import os
 
diff --git a/generate_sequence_by_rnn_lm/train.py b/generate_sequence_by_rnn_lm/train.py
index f85da2bd5945c751ede77021b5039cdb4acbb2f7..00e8ffa742ffd1966ee6f39b2fb14394ac870fcf 100644
--- a/generate_sequence_by_rnn_lm/train.py
+++ b/generate_sequence_by_rnn_lm/train.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding=utf-8
-
 import os
 import sys
 import gzip
diff --git a/generate_sequence_by_rnn_lm/utils.py b/generate_sequence_by_rnn_lm/utils.py
index 4490a9de5c3addf624ca65629c82c47a86e45eb9..57edf3758e995a5c104b7a7cbf37edf1ba95dfbc 100644
--- a/generate_sequence_by_rnn_lm/utils.py
+++ b/generate_sequence_by_rnn_lm/utils.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# coding=utf-8
-
 import os
 import logging
 from collections import defaultdict