train.py 4.2 KB
Newer Older
1
import os
C
caoying03 已提交
2
import gzip
3
import logging
C
caoying03 已提交
4
import click
5 6 7 8 9 10 11 12 13 14

import paddle.v2 as paddle
import reader
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


15
def save_model(trainer, save_path, parameters):
    """Serialize the trainer's current parameters into a gzip-compressed tar.

    Args:
        trainer: object exposing ``save_parameter_to_tar(f)``; in this script
            it is the ``paddle.trainer.SGD`` instance built in ``train``.
        save_path: destination file path (``.tar.gz``).
        parameters: kept for call-site compatibility only; it is not read
            here, because serialization goes through ``trainer``.
    """
    gz_file = gzip.open(save_path, "w")
    with gz_file:
        trainer.save_parameter_to_tar(gz_file)
18 19 20 21 22 23 24


def load_initial_model(model_path, parameters):
    """Populate ``parameters`` from a gzip-compressed tar archive.

    Args:
        model_path: path to a ``.tar.gz`` file, e.g. a checkpoint written by
            ``save_model``.
        parameters: object exposing ``init_from_tar(f)``; in this script it
            is the result of ``paddle.parameters.create``.
    """
    archive = gzip.open(model_path, "rb")
    with archive:
        parameters.init_from_tar(archive)


C
caoying03 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
@click.command("train")
@click.option(
    "--num_passes", default=10, help="Number of passes for the training task.")
@click.option(
    "--batch_size",
    default=16,
    help="The number of training examples in one forward/backward pass.")
@click.option(
    "--use_gpu", default=False, help="Whether to use gpu to train the model.")
@click.option(
    "--trainer_count", default=1, help="The thread number used in training.")
@click.option(
    "--save_dir_path",
    default="models",
    help="The path to save the trained models.")
@click.option(
    "--encoder_depth",
    default=3,
    help="The number of stacked LSTM layers in encoder.")
@click.option(
    "--decoder_depth",
    default=3,
    help="The number of stacked LSTM layers in decoder.")
@click.option(
    "--train_data_path", required=True, help="The path of training data.")
@click.option(
    "--word_dict_path", required=True, help="The path of word dictionary.")
@click.option(
    "--init_model_path",
    default="",
    help=("The path of a trained model used to initialize all "
          "the model parameters."))
def train(num_passes,
          batch_size,
          use_gpu,
          trainer_count,
          save_dir_path,
          encoder_depth,
          decoder_depth,
          train_data_path,
          word_dict_path,
          init_model_path=""):
    """Train the encoder-decoder model, checkpointing into --save_dir_path."""
    # NOTE: click shows this function's docstring as the command's --help
    # text, so detailed notes live in comments rather than the docstring.
    #
    # Checkpoints are written every 1000 batches within a pass and at the end
    # of every pass. Raises click.BadParameter when --word_dict_path or
    # --train_data_path does not exist.

    # Validate inputs with explicit errors: ``assert`` is stripped under
    # ``python -O``. Checking first also avoids creating the output directory
    # for a run that cannot start.
    if not os.path.exists(word_dict_path):
        raise click.BadParameter(
            "The given word dictionary does not exist.",
            param_hint="--word_dict_path")
    if not os.path.exists(train_data_path):
        raise click.BadParameter(
            "The given training data does not exist.",
            param_hint="--train_data_path")

    # os.makedirs (unlike os.mkdir) also creates missing parent directories,
    # e.g. --save_dir_path=runs/exp1.
    if not os.path.isdir(save_dir_path):
        os.makedirs(save_dir_path)

    # initialize PaddlePaddle
    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)

    # define optimization method and the trainer instance
    optimizer = paddle.optimizer.AdaDelta(
        learning_rate=1e-3,
        gradient_clipping_threshold=25.0,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
        model_average=paddle.optimizer.ModelAverage(
            average_window=0.5, max_average_window=2500))

    # The vocabulary size is the number of lines in the dictionary file.
    # Count it inside ``with`` so the file handle is closed promptly.
    with open(word_dict_path, "r") as dict_file:
        word_count = sum(1 for _ in dict_file)

    cost = encoder_decoder_network(
        word_count=word_count,
        emb_dim=512,
        encoder_depth=encoder_depth,
        encoder_hidden_dim=512,
        decoder_depth=decoder_depth,
        decoder_hidden_dim=512)

    parameters = paddle.parameters.create(cost)
    if init_model_path:
        load_initial_model(init_model_path, parameters)

    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # define data reader
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            reader.train_reader(train_data_path, word_dict_path),
            buf_size=1024000),
        batch_size=batch_size)

    # define the event_handler callback
    def event_handler(event):
        """Log progress and write checkpoints in response to trainer events."""
        if isinstance(event, paddle.event.EndIteration):
            # Mid-pass checkpoint every 1000 batches; batch 0 is skipped.
            if event.batch_id and event.batch_id % 1000 == 0:
                save_path = os.path.join(
                    save_dir_path, "pass_%05d_batch_%05d.tar.gz" %
                    (event.pass_id, event.batch_id))
                save_model(trainer, save_path, parameters)

            if event.batch_id % 5 == 0:
                # Lazy %-args: the message is only formatted if emitted.
                logger.info("Pass %d, Batch %d, Cost %f, %s", event.pass_id,
                            event.batch_id, event.cost, event.metrics)

        if isinstance(event, paddle.event.EndPass):
            # End-of-pass checkpoint.
            save_path = os.path.join(save_dir_path,
                                     "pass_%05d.tar.gz" % event.pass_id)
            save_model(trainer, save_path, parameters)

    # start training
    trainer.train(
        reader=train_reader, event_handler=event_handler, num_passes=num_passes)


C
caoying03 已提交
130 131
if __name__ == "__main__":
    # Hand control to click: it parses sys.argv, validates the options and
    # then invokes the decorated ``train`` function.
    train.main()