From 662d2994579364fac4a7fbf9ccde622e276868f3 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Wed, 27 Sep 2017 16:48:12 +0800 Subject: [PATCH] add command line parser. --- generate_chinese_poetry/generate.py | 43 +++++++---- generate_chinese_poetry/network_conf.py | 3 - generate_chinese_poetry/reader.py | 3 - generate_chinese_poetry/train.py | 81 +++++++++++++-------- generate_chinese_poetry/utils.py | 7 -- generate_sequence_by_rnn_lm/beam_search.py | 3 - generate_sequence_by_rnn_lm/generate.py | 2 - generate_sequence_by_rnn_lm/network_conf.py | 2 - generate_sequence_by_rnn_lm/reader.py | 2 - generate_sequence_by_rnn_lm/train.py | 3 - generate_sequence_by_rnn_lm/utils.py | 3 - 11 files changed, 82 insertions(+), 70 deletions(-) diff --git a/generate_chinese_poetry/generate.py b/generate_chinese_poetry/generate.py index b412f7f7..b2d90917 100755 --- a/generate_chinese_poetry/generate.py +++ b/generate_chinese_poetry/generate.py @@ -1,11 +1,9 @@ -#!/usr/bin/env python -#coding=utf-8 - import os import sys import gzip import logging import numpy as np +import click import reader import paddle.v2 as paddle @@ -36,19 +34,45 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout): fout.flush +@click.command("generate") +@click.option( + "--model_path", + default="", + help="The path of the trained model for generation.") +@click.option( + "--word_dict_path", required=True, help="The path of word dictionary.") +@click.option( + "--test_data_path", + required=True, + help="The path of input data for generation.") +@click.option( + "--batch_size", + default=1, + help="The number of testing examples in one forward pass in generation.") +@click.option( + "--beam_size", default=5, help="The beam expansion in beam search.") +@click.option( + "--save_file", + required=True, + help="The file path to save the generated results.") +@click.option( + "--use_gpu", default=False, help="Whether to use GPU in generation.") def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size, save_file, use_gpu): - assert os.path.exists(model_path), "trained model does not exist." + assert os.path.exists(model_path), "The given model does not exist." + assert os.path.exists(test_data_path), "The given test data does not exist." - paddle.init(use_gpu=use_gpu, trainer_count=1) with gzip.open(model_path, "r") as f: parameters = paddle.parameters.Parameters.from_tar(f) id_to_text = {} + assert os.path.exists( + word_dict_path), "The given word dictionary path does not exist." with open(word_dict_path, "r") as f: for i, line in enumerate(f): id_to_text[i] = line.strip().split("\t")[0] + paddle.init(use_gpu=use_gpu, trainer_count=1) beam_gen = encoder_decoder_network( word_count=len(id_to_text), emb_dim=512, @@ -78,11 +102,4 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size, if __name__ == "__main__": - generate( - model_path="models/pass_00025.tar.gz", - word_dict_path="data/word_dict.txt", - test_data_path="data/input.txt", - save_file="gen_result.txt", - batch_size=4, - beam_size=5, - use_gpu=True) + generate() diff --git a/generate_chinese_poetry/network_conf.py b/generate_chinese_poetry/network_conf.py index 3ee6f64a..5aec3c06 100755 --- a/generate_chinese_poetry/network_conf.py +++ b/generate_chinese_poetry/network_conf.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -#coding=utf-8 - import paddle.v2 as paddle from paddle.v2.layer import parse_network diff --git a/generate_chinese_poetry/reader.py b/generate_chinese_poetry/reader.py index b1f4f54d..4ecdb041 100755 --- a/generate_chinese_poetry/reader.py +++ b/generate_chinese_poetry/reader.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - from utils import load_dict diff --git a/generate_chinese_poetry/train.py b/generate_chinese_poetry/train.py index 28a9ec28..a641efa5 100755 --- a/generate_chinese_poetry/train.py +++ b/generate_chinese_poetry/train.py @@ -1,9 +1,7 @@ -#!/usr/bin/env python -#coding=utf-8 - -import gzip import os +import gzip import logging +import click import paddle.v2 as paddle import reader @@ -24,24 +22,59 @@ def load_initial_model(model_path, parameters): parameters.init_from_tar(f) -def main(num_passes, - batch_size, - use_gpu, - trainer_count, - save_dir_path, - encoder_depth, - decoder_depth, - word_dict_path, - train_data_path, - init_model_path=""): +@click.command("train") +@click.option( + "--num_passes", default=10, help="Number of passes for the training task.") +@click.option( + "--batch_size", + default=16, + help="The number of training examples in one forward/backward pass.") +@click.option( + "--use_gpu", default=False, help="Whether to use gpu to train the model.") +@click.option( + "--trainer_count", default=1, help="The thread number used in training.") +@click.option( + "--save_dir_path", + default="models", + help="The path to saved the trained models.") +@click.option( + "--encoder_depth", + default=3, + help="The number of stacked LSTM layers in encoder.") +@click.option( + "--decoder_depth", + default=3, + help="The number of stacked LSTM layers in encoder.") +@click.option( + "--train_data_path", required=True, help="The path of trainning data.") +@click.option( + "--word_dict_path", required=True, help="The path of word dictionary.") +@click.option( + "--init_model_path", + default="", + help=("The path of a trained model used to initialized all " + "the model parameters.")) +def train(num_passes, + batch_size, + use_gpu, + trainer_count, + save_dir_path, + encoder_depth, + decoder_depth, + train_data_path, + word_dict_path, + init_model_path=""): if not os.path.exists(save_dir_path): os.mkdir(save_dir_path) + assert os.path.exists( + word_dict_path), "The given word dictionary does not exist." + assert os.path.exists( + train_data_path), "The given training data does not exist." # initialize PaddlePaddle - paddle.init(use_gpu=use_gpu, trainer_count=trainer_count, parallel_nn=1) + paddle.init(use_gpu=use_gpu, trainer_count=trainer_count) # define optimization method and the trainer instance - # optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.AdaDelta( learning_rate=1e-3, gradient_clipping_threshold=25.0, @@ -74,7 +107,7 @@ def main(num_passes, # define the event_handler callback def event_handler(event): if isinstance(event, paddle.event.EndIteration): - if (not event.batch_id % 2000) and event.batch_id: + if (not event.batch_id % 1000) and event.batch_id: save_path = os.path.join(save_dir_path, "pass_%05d_batch_%05d.tar.gz" % (event.pass_id, event.batch_id)) @@ -94,15 +127,5 @@ def main(num_passes, reader=train_reader, event_handler=event_handler, num_passes=num_passes) -if __name__ == '__main__': - main( - num_passes=500, - batch_size=4 * 500, - use_gpu=True, - trainer_count=4, - encoder_depth=3, - decoder_depth=3, - save_dir_path="models", - word_dict_path="data/word_dict.txt", - train_data_path="data/song.poet.txt", - init_model_path="") +if __name__ == "__main__": + train() diff --git a/generate_chinese_poetry/utils.py b/generate_chinese_poetry/utils.py index 628fa841..f8a20bf4 100755 --- a/generate_chinese_poetry/utils.py +++ b/generate_chinese_poetry/utils.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -#coding=utf-8 - import os import sys import re @@ -30,7 +27,3 @@ def find_optiaml_pass(log_file): cost_info.iteritems(), key=lambda x: sum(x[1]) / (len(x[1])), reverse=False)[0][0]) - - -if __name__ == '__main__': - find_optiaml_pass('trained_models/models_first_round/train.log') diff --git a/generate_sequence_by_rnn_lm/beam_search.py b/generate_sequence_by_rnn_lm/beam_search.py index d71223cd..f6d1d364 100644 --- a/generate_sequence_by_rnn_lm/beam_search.py +++ b/generate_sequence_by_rnn_lm/beam_search.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 - import os import math import numpy as np diff --git a/generate_sequence_by_rnn_lm/generate.py b/generate_sequence_by_rnn_lm/generate.py index 71c66e86..07729e74 100644 --- a/generate_sequence_by_rnn_lm/generate.py +++ b/generate_sequence_by_rnn_lm/generate.py @@ -1,5 +1,3 @@ -# coding=utf-8 - import os import gzip import numpy as np diff --git a/generate_sequence_by_rnn_lm/network_conf.py b/generate_sequence_by_rnn_lm/network_conf.py index 7306337b..f1aceb0b 100644 --- a/generate_sequence_by_rnn_lm/network_conf.py +++ b/generate_sequence_by_rnn_lm/network_conf.py @@ -1,5 +1,3 @@ -# coding=utf-8 - import paddle.v2 as paddle diff --git a/generate_sequence_by_rnn_lm/reader.py b/generate_sequence_by_rnn_lm/reader.py index d0b95fb8..1c6bc7a8 100644 --- a/generate_sequence_by_rnn_lm/reader.py +++ b/generate_sequence_by_rnn_lm/reader.py @@ -1,5 +1,3 @@ -# coding=utf-8 - import collections import os diff --git a/generate_sequence_by_rnn_lm/train.py b/generate_sequence_by_rnn_lm/train.py index f85da2bd..00e8ffa7 100644 --- a/generate_sequence_by_rnn_lm/train.py +++ b/generate_sequence_by_rnn_lm/train.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 - import os import sys import gzip diff --git a/generate_sequence_by_rnn_lm/utils.py b/generate_sequence_by_rnn_lm/utils.py index 4490a9de..57edf375 100644 --- a/generate_sequence_by_rnn_lm/utils.py +++ b/generate_sequence_by_rnn_lm/utils.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 - import os import logging from collections import defaultdict -- GitLab