提交 662d2994 编写于 作者: C caoying03

add command line parser.

上级 65b355fe
#!/usr/bin/env python
#coding=utf-8
import os
import sys
import gzip
import logging
import numpy as np
import click
import reader
import paddle.v2 as paddle
......@@ -36,19 +34,45 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
fout.flush
@click.command("generate")
@click.option(
"--model_path",
default="",
help="The path of the trained model for generation.")
@click.option(
"--word_dict_path", required=True, help="The path of word dictionary.")
@click.option(
"--test_data_path",
required=True,
help="The path of input data for generation.")
@click.option(
"--batch_size",
default=1,
help="The number of testing examples in one forward pass in generation.")
@click.option(
"--beam_size", default=5, help="The beam expansion in beam search.")
@click.option(
"--save_file",
required=True,
help="The file path to save the generated results.")
@click.option(
"--use_gpu", default=False, help="Whether to use GPU in generation.")
def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
save_file, use_gpu):
assert os.path.exists(model_path), "trained model does not exist."
assert os.path.exists(model_path), "The given model does not exist."
assert os.path.exists(test_data_path), "The given test data does not exist."
paddle.init(use_gpu=use_gpu, trainer_count=1)
with gzip.open(model_path, "r") as f:
parameters = paddle.parameters.Parameters.from_tar(f)
id_to_text = {}
assert os.path.exists(
word_dict_path), "The given word dictionary path does not exist."
with open(word_dict_path, "r") as f:
for i, line in enumerate(f):
id_to_text[i] = line.strip().split("\t")[0]
paddle.init(use_gpu=use_gpu, trainer_count=1)
beam_gen = encoder_decoder_network(
word_count=len(id_to_text),
emb_dim=512,
......@@ -78,11 +102,4 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
if __name__ == "__main__":
generate(
model_path="models/pass_00025.tar.gz",
word_dict_path="data/word_dict.txt",
test_data_path="data/input.txt",
save_file="gen_result.txt",
batch_size=4,
beam_size=5,
use_gpu=True)
generate()
#!/usr/bin/env python
#coding=utf-8
import paddle.v2 as paddle
from paddle.v2.layer import parse_network
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from utils import load_dict
......
#!/usr/bin/env python
#coding=utf-8
import gzip
import os
import gzip
import logging
import click
import paddle.v2 as paddle
import reader
......@@ -24,24 +22,59 @@ def load_initial_model(model_path, parameters):
parameters.init_from_tar(f)
def main(num_passes,
batch_size,
use_gpu,
trainer_count,
save_dir_path,
encoder_depth,
decoder_depth,
word_dict_path,
train_data_path,
init_model_path=""):
@click.command("train")
@click.option(
"--num_passes", default=10, help="Number of passes for the training task.")
@click.option(
"--batch_size",
default=16,
help="The number of training examples in one forward/backward pass.")
@click.option(
"--use_gpu", default=False, help="Whether to use gpu to train the model.")
@click.option(
"--trainer_count", default=1, help="The thread number used in training.")
@click.option(
"--save_dir_path",
default="models",
help="The path to saved the trained models.")
@click.option(
"--encoder_depth",
default=3,
help="The number of stacked LSTM layers in encoder.")
@click.option(
"--decoder_depth",
default=3,
help="The number of stacked LSTM layers in encoder.")
@click.option(
"--train_data_path", required=True, help="The path of trainning data.")
@click.option(
"--word_dict_path", required=True, help="The path of word dictionary.")
@click.option(
"--init_model_path",
default="",
help=("The path of a trained model used to initialized all "
"the model parameters."))
def train(num_passes,
batch_size,
use_gpu,
trainer_count,
save_dir_path,
encoder_depth,
decoder_depth,
train_data_path,
word_dict_path,
init_model_path=""):
if not os.path.exists(save_dir_path):
os.mkdir(save_dir_path)
assert os.path.exists(
word_dict_path), "The given word dictionary does not exist."
assert os.path.exists(
train_data_path), "The given training data does not exist."
# initialize PaddlePaddle
paddle.init(use_gpu=use_gpu, trainer_count=trainer_count, parallel_nn=1)
paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
# define optimization method and the trainer instance
# optimizer = paddle.optimizer.Adam(
optimizer = paddle.optimizer.AdaDelta(
learning_rate=1e-3,
gradient_clipping_threshold=25.0,
......@@ -74,7 +107,7 @@ def main(num_passes,
# define the event_handler callback
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if (not event.batch_id % 2000) and event.batch_id:
if (not event.batch_id % 1000) and event.batch_id:
save_path = os.path.join(save_dir_path,
"pass_%05d_batch_%05d.tar.gz" %
(event.pass_id, event.batch_id))
......@@ -94,15 +127,5 @@ def main(num_passes,
reader=train_reader, event_handler=event_handler, num_passes=num_passes)
if __name__ == '__main__':
main(
num_passes=500,
batch_size=4 * 500,
use_gpu=True,
trainer_count=4,
encoder_depth=3,
decoder_depth=3,
save_dir_path="models",
word_dict_path="data/word_dict.txt",
train_data_path="data/song.poet.txt",
init_model_path="")
if __name__ == "__main__":
train()
#!/usr/bin/env python
#coding=utf-8
import os
import sys
import re
......@@ -30,7 +27,3 @@ def find_optiaml_pass(log_file):
cost_info.iteritems(),
key=lambda x: sum(x[1]) / (len(x[1])),
reverse=False)[0][0])
if __name__ == '__main__':
find_optiaml_pass('trained_models/models_first_round/train.log')
#!/usr/bin/env python
# coding=utf-8
import os
import math
import numpy as np
......
# coding=utf-8
import os
import gzip
import numpy as np
......
# coding=utf-8
import paddle.v2 as paddle
......
# coding=utf-8
import collections
import os
......
#!/usr/bin/env python
# coding=utf-8
import os
import sys
import gzip
......
#!/usr/bin/env python
# coding=utf-8
import os
import logging
from collections import defaultdict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册