train.py 3.7 KB
Newer Older
P
peterzhang2029 已提交
1
import gzip
P
peterzhang2029 已提交
2 3
import os
import click
4 5

import paddle.v2 as paddle
P
peterzhang2029 已提交
6
from config import TrainerConfig as conf
P
peterzhang2029 已提交
7
from network_conf import Model
P
peterzhang2029 已提交
8
from reader import DataGenerator
P
peterzhang2029 已提交
9
from utils import get_file_list, build_label_dict, load_dict
P
peterzhang2029 已提交
10

P
peterzhang2029 已提交
11 12 13 14

@click.command('train')
@click.option(
    "--train_file_list_path",
    type=str,
    required=True,
    help=("The path of the file which contains "
          "path list of train image files."))
@click.option(
    "--test_file_list_path",
    type=str,
    required=True,
    help=("The path of the file which contains "
          "path list of test image files."))
@click.option(
    "--label_dict_path",
    type=str,
    required=True,
    help=("The path of label dictionary. "
          "If the file does not exist, "
          "the label dictionary will be built from "
          "the training data automatically."))
@click.option(
    "--model_save_dir",
    type=str,
    default="models",
    help="The path to save the trained models (default: 'models').")
def train(train_file_list_path, test_file_list_path, label_dict_path,
          model_save_dir):
    """Train the network on the listed images, testing and saving
    parameters after every pass."""

    # makedirs (rather than mkdir) also creates missing parent directories,
    # e.g. for --model_save_dir=output/models. The explicit existence check
    # keeps this Python 2 compatible, where makedirs has no exist_ok.
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)

    # Each list file contains the paths of the image files to use.
    train_file_list = get_file_list(train_file_list_path)
    test_file_list = get_file_list(test_file_list_path)

    # Build the label dictionary from the training data only when no
    # dictionary file exists yet; otherwise the existing file is reused.
    if not os.path.exists(label_dict_path):
        print(("Label dictionary is not given, the dictionary "
               "is automatically built from the training data."))
        build_label_dict(train_file_list, label_dict_path)

    char_dict = load_dict(label_dict_path)
    dict_size = len(char_dict)
    data_generator = DataGenerator(
        char_dict=char_dict, image_shape=conf.image_shape)

    paddle.init(use_gpu=conf.use_gpu, trainer_count=conf.trainer_count)
    # Create optimizer.
    optimizer = paddle.optimizer.Momentum(momentum=conf.momentum)
    # Define network topology.
    model = Model(dict_size, conf.image_shape, is_infer=False)
    # Create all the trainable parameters.
    params = paddle.parameters.create(model.cost)

    trainer = paddle.trainer.SGD(
        cost=model.cost,
        parameters=params,
        update_equation=optimizer,
        extra_layers=model.eval)
    # Maps input names to their index in each reader sample (image first,
    # then label) -- assumed to match DataGenerator's output order.
    feeding = {'image': 0, 'label': 1}

    def event_handler(event):
        """Log training progress; at the end of each pass, evaluate on the
        test list and save the parameters."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % conf.log_period == 0:
                print("Pass %d, batch %d, Samples %d, Cost %f, Eval %s" %
                      (event.pass_id, event.batch_id, event.batch_id *
                       conf.batch_size, event.cost, event.metrics))

        if isinstance(event, paddle.event.EndPass):
            # Training and testing data share the same format, so the test
            # list is read through data_generator.train_reader as well.
            result = trainer.test(
                reader=paddle.batch(
                    data_generator.train_reader(test_file_list),
                    batch_size=conf.batch_size),
                feeding=feeding)
            print("Test %d, Cost %f, Eval %s" %
                  (event.pass_id, result.cost, result.metrics))
            # One gzip-compressed parameter archive per pass, named like
            # params_pass_00003.tar.gz inside model_save_dir.
            with gzip.open(
                    os.path.join(model_save_dir, "params_pass_%05d.tar.gz" %
                                 event.pass_id), "w") as f:
                trainer.save_parameter_to_tar(f)

    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                data_generator.train_reader(train_file_list),
                buf_size=conf.buf_size),
            batch_size=conf.batch_size),
        feeding=feeding,
        event_handler=event_handler,
        num_passes=conf.num_passes)
104 105 106


if __name__ == "__main__":
    # train is a click command: its decorators fill in the arguments from
    # the command-line options, so it is invoked with none here.
    train()