import os
import gzip
import logging
import click

import paddle.v2 as paddle
import reader
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def save_model(trainer, save_path, parameters):
    """Serialize the trainer's parameters into a gzip-compressed tar file.

    Args:
        trainer: object providing ``save_parameter_to_tar(f)``; in this
            script it is the ``paddle.trainer.SGD`` instance.
        save_path: destination path of the ``.tar.gz`` checkpoint.
        parameters: not read by this function (the trainer serializes its
            own parameters); kept so existing call sites keep working.
    """
    archive = gzip.open(save_path, "wb")
    try:
        trainer.save_parameter_to_tar(archive)
    finally:
        # Always release the file handle, even if serialization fails.
        archive.close()


def load_initial_model(model_path, parameters):
    """Initialize ``parameters`` from a gzip-compressed tar checkpoint.

    Args:
        model_path: path of the ``.tar.gz`` file to read.
        parameters: object providing ``init_from_tar(f)``; in this script it
            is the result of ``paddle.parameters.create``.
    """
    archive = gzip.open(model_path, "rb")
    try:
        parameters.init_from_tar(archive)
    finally:
        # Close the archive whether or not initialization succeeds.
        archive.close()


def _count_dict_entries(word_dict_path):
    """Return the number of lines in the word dictionary file.

    Each line counts as one vocabulary entry. The file is streamed and
    closed deterministically instead of being read fully into memory.
    """
    with open(word_dict_path, "r") as dict_file:
        return sum(1 for _ in dict_file)


@click.command("train")
@click.option(
    "--num_passes", default=10, help="Number of passes for the training task.")
@click.option(
    "--batch_size",
    default=16,
    help="The number of training examples in one forward/backward pass.")
@click.option(
    "--use_gpu", default=False, help="Whether to use gpu to train the model.")
@click.option(
    "--trainer_count", default=1, help="The thread number used in training.")
@click.option(
    "--save_dir_path",
    default="models",
    help="The path to save the trained models.")
@click.option(
    "--encoder_depth",
    default=3,
    help="The number of stacked LSTM layers in encoder.")
@click.option(
    "--decoder_depth",
    default=3,
    help="The number of stacked LSTM layers in decoder.")
@click.option(
    "--train_data_path", required=True, help="The path of training data.")
@click.option(
    "--word_dict_path", required=True, help="The path of word dictionary.")
@click.option(
    "--init_model_path",
    default="",
    help=("The path of a trained model used to initialize all "
          "the model parameters."))
def train(num_passes,
          batch_size,
          use_gpu,
          trainer_count,
          save_dir_path,
          encoder_depth,
          decoder_depth,
          train_data_path,
          word_dict_path,
          init_model_path=""):
    """Train the encoder-decoder network.

    A checkpoint is written to the save directory every 1000 batches and at
    the end of every pass.
    """
    # Validate inputs with explicit checks (not ``assert``) so they still run
    # under ``python -O``; click reports BadParameter as a usage error. This
    # happens before any output directory is created.
    if not os.path.exists(word_dict_path):
        raise click.BadParameter(
            "The given word dictionary does not exist.",
            param_hint="--word_dict_path")
    if not os.path.exists(train_data_path):
        raise click.BadParameter(
            "The given training data does not exist.",
            param_hint="--train_data_path")
    if init_model_path and not os.path.exists(init_model_path):
        raise click.BadParameter(
            "The given initial model does not exist.",
            param_hint="--init_model_path")

    # makedirs (unlike mkdir) also creates missing parent directories.
    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)

    # initialize PaddlePaddle
    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)

    # define optimization method and the trainer instance
    optimizer = paddle.optimizer.Adam(
        learning_rate=1e-4,
        regularization=paddle.optimizer.L2Regularization(rate=1e-5),
        model_average=paddle.optimizer.ModelAverage(
            average_window=0.5, max_average_window=2500))

    # NOTE(review): bos_id/eos_id presumably are the dictionary ids of the
    # start/end markers -- confirm against reader.py / the dictionary file.
    cost = encoder_decoder_network(
        word_count=_count_dict_entries(word_dict_path),
        emb_dim=512,
        encoder_depth=encoder_depth,
        encoder_hidden_dim=512,
        decoder_depth=decoder_depth,
        decoder_hidden_dim=512,
        bos_id=0,
        eos_id=1,
        max_length=9)

    parameters = paddle.parameters.create(cost)
    if init_model_path:
        load_initial_model(init_model_path, parameters)

    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # define data reader: shuffle within a large buffer, then batch
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            reader.train_reader(train_data_path, word_dict_path),
            buf_size=1024000),
        batch_size=batch_size)

    # define the event_handler callback
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            # Mid-pass checkpoint every 1000 batches; batch 0 is skipped.
            if event.batch_id and event.batch_id % 1000 == 0:
                save_path = os.path.join(save_dir_path,
                                         "pass_%05d_batch_%05d.tar.gz" %
                                         (event.pass_id, event.batch_id))
                save_model(trainer, save_path, parameters)

            if event.batch_id % 10 == 0:
                # Lazy %-args: formatting only happens if INFO is enabled.
                logger.info("Pass %d, Batch %d, Cost %f, %s", event.pass_id,
                            event.batch_id, event.cost, event.metrics)

        if isinstance(event, paddle.event.EndPass):
            # End-of-pass checkpoint.
            save_path = os.path.join(save_dir_path,
                                     "pass_%05d.tar.gz" % event.pass_id)
            save_model(trainer, save_path, parameters)

    # start training
    trainer.train(
        reader=train_reader, event_handler=event_handler, num_passes=num_passes)


if __name__ == "__main__":
    # Script entry point: invoking the click command parses sys.argv and
    # dispatches to train() with the parsed options.
    train()