# train.py (last committed by peterzhang2029)
import gzip
import os
import click

import paddle.v2 as paddle
from config import TrainerConfig as conf
from model import Model
from reader import DataGenerator
from utils import get_file_list, AsciiDic

@click.command('train')
@click.option(
    "--train_file_list_path",
    type=str,
    required=True,
    help=("The path of the file which contains "
          "path list of train image files."))
@click.option(
    "--test_file_list_path",
    type=str,
    required=True,
    help=("The path of the file which contains "
          "path list of test image files."))
@click.option(
    "--model_save_dir",
    type=str,
    default="models",
    help="The path to save the trained models (default: 'models').")
def train(train_file_list_path, test_file_list_path, model_save_dir):
    """Train the model and save its parameters after every pass.

    After each pass the model is evaluated on the images listed in the test
    file list, and the parameters are written to
    params_pass_NNNNN.tar.gz under the directory given by --model_save_dir.
    """
    # makedirs (unlike mkdir) also creates missing parent directories, e.g.
    # --model_save_dir=output/models. The explicit existence check is used
    # instead of exist_ok=True so the script also runs under Python 2.
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)

    train_file_list = get_file_list(train_file_list_path)
    test_file_list = get_file_list(test_file_list_path)
    char_dict = AsciiDic()
    dict_size = char_dict.size()
    data_generator = DataGenerator(
        char_dict=char_dict, image_shape=conf.image_shape)

    paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)
    # Create optimizer.
    optimizer = paddle.optimizer.Momentum(momentum=conf.momentum)
    # Define network topology.
    model = Model(dict_size, conf.image_shape, is_infer=False)
    # Create all the trainable parameters.
    params = paddle.parameters.create(model.cost)

    trainer = paddle.trainer.SGD(
        cost=model.cost,
        parameters=params,
        update_equation=optimizer,
        extra_layers=model.eval)

    # Feeding dictionary: maps each input name to the index of the matching
    # field in every sample yielded by the reader.
    feeding = {'image': 0, 'label': 1}

    def event_handler(event):
        """Log training progress; test and checkpoint at the end of a pass."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % conf.log_period == 0:
                print("Pass %d, batch %d, Samples %d, Cost %f, Eval %s" %
                      (event.pass_id, event.batch_id, event.batch_id *
                       conf.batch_size, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            # Training and testing data share the same format, so the test
            # data is also read through data_generator.train_reader.
            result = trainer.test(
                reader=paddle.batch(
                    data_generator.train_reader(test_file_list),
                    batch_size=conf.batch_size),
                feeding=feeding)
            print("Test %d, Cost %f, Eval %s" %
                  (event.pass_id, result.cost, result.metrics))
            # One parameter archive per pass, zero-padded so the files sort
            # in pass order.
            param_path = os.path.join(
                model_save_dir, "params_pass_%05d.tar.gz" % event.pass_id)
            with gzip.open(param_path, "w") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                data_generator.train_reader(train_file_list),
                buf_size=conf.buf_size),
            batch_size=conf.batch_size),
        feeding=feeding,
        event_handler=event_handler,
        num_passes=conf.num_passes)


if __name__ == "__main__":
    # Script entry point: click parses the command-line options and calls
    # train() with them.
    train()