ctc_train.py 5.7 KB
Newer Older
W
wanghaoshuang 已提交
1
"""Trainer for OCR CTC model."""
W
wanghaoshuang 已提交
2
import paddle.fluid as fluid
3 4
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
W
wanghaoshuang 已提交
5
import ctc_reader
W
wanghaoshuang 已提交
6 7 8
import argparse
import functools
import sys
W
wanghaoshuang 已提交
9
import time
10 11
import os
import numpy as np
W
wanghaoshuang 已提交
12 13 14 15

parser = argparse.ArgumentParser(description=__doc__)
# add_arg(name, type, default, help) registers one option on `parser` through
# utility.add_arguments (the exact flag spelling is defined there — presumably
# "--name"; confirm in utility.py).
add_arg = functools.partial(add_arguments, argparser=parser)
# Periods (log/eval/save) are counted in training batches within a pass.
# yapf: disable
add_arg('batch_size',        int,   32,         "Minibatch size.")
add_arg('pass_num',          int,   100,        "Number of training epochs.")
add_arg('log_period',        int,   1000,       "Log period.")
add_arg('save_model_period', int,   15000,      "Save model period. '-1' means never saving the model.")
add_arg('eval_period',       int,   15000,      "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir',    str,   "./models", "The directory the model to be saved to.")
add_arg('init_model',        str,   None,       "The init model file or directory.")
# NOTE(review): with plain argparse, type=bool turns any non-empty string
# (including "False") into True. This assumes utility.add_arguments swaps in a
# string-to-bool converter for bool options — confirm there.
add_arg('use_gpu',           bool,  True,      "Whether use GPU to train.")
add_arg('min_average_window',int,   10000,     "Min average window.")
add_arg('max_average_window',int,   15625,     "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window',    float, 0.15,      "Average window.")
add_arg('parallel',          bool,  False,     "Whether use parallel training.")
# yapf: enable


def train(args, data_reader=ctc_reader):
    """OCR CTC training.

    Builds the CTC training network and optionally restores parameters from
    ``args.init_model``. It then runs ``args.pass_num`` passes over the
    training reader, periodically logging, evaluating on the test reader and
    saving parameters.

    Args:
        args: parsed command-line arguments (see the module-level ``add_arg``
            calls).
        data_reader: object exposing ``num_classes()``, ``data_shape()``,
            ``train(batch_size, train_images_dir=..., train_list_file=...)``
            and ``test(test_images_dir=..., test_list_file=...)``.
    """
    # Forwarded to the data reader as-is. None presumably makes the reader use
    # its own default locations — confirm in ctc_reader.
    train_images = None
    train_list = None
    test_images = None
    test_list = None

    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()

    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)
    sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(
        images, label, args, num_classes)

    # data reader
    train_reader = data_reader.train(
        args.batch_size,
        train_images_dir=train_images,
        train_list_file=train_list)
    test_reader = data_reader.test(
        test_images_dir=test_images, test_list_file=test_list)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load init model
    if args.init_model is not None:
        model_dir = args.init_model
        model_file_name = None
        if not os.path.isdir(args.init_model):
            # A non-directory path is treated as one combined parameter file,
            # matching how save_model() below writes parameters.
            model_dir = os.path.dirname(args.init_model)
            model_file_name = os.path.basename(args.init_model)
        fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
        print("Init model from: %s." % args.init_model)

    train_exe = exe
    error_evaluator.reset(exe)
    if args.parallel:
        # NOTE(review): use_cuda is hard-coded to True, so --parallel appears
        # to require a GPU regardless of --use_gpu — confirm intent.
        train_exe = fluid.ParallelExecutor(
            use_cuda=True, loss_name=sum_cost.name)

    # Fetch order: [loss] + evaluator metrics; indices are used below.
    fetch_vars = [sum_cost] + error_evaluator.metrics
    fetch_names = [var.name for var in fetch_vars]

    def train_one_batch(data):
        """Run one optimisation step and return the fetched values as scalars."""
        if args.parallel:
            results = train_exe.run(fetch_names,
                                    feed=get_feeder_data(data, place))
            # Presumably one value per device; sum each into a scalar.
            results = [np.array(result).sum() for result in results]
        else:
            results = exe.run(feed=get_feeder_data(data, place),
                              fetch_list=fetch_vars)
            results = [result[0] for result in results]
        return results

    def test(pass_id, batch_id):
        """Evaluate sequence error over the whole test reader and print it."""
        error_evaluator.reset(exe)
        for data in test_reader():
            exe.run(inference_program, feed=get_feeder_data(data, place))
        _, test_seq_error = error_evaluator.eval(exe)
        print("\nTime: %s; Pass[%d]-batch[%d]; Test seq error: %s.\n" % (
            time.time(), pass_id, batch_id, str(test_seq_error[0])))

    def save_model(args, exe, pass_id, batch_id):
        """Save all parameters into one file named after pass and batch."""
        filename = "model_%05d_%d" % (pass_id, batch_id)
        fluid.io.save_params(
            exe, dirname=args.save_model_dir, filename=filename)
        print("Saved model to: %s/%s." % (args.save_model_dir, filename))

    def run_with_average(func, *func_args):
        """Call ``func`` with ModelAverage weights applied when averaging is on."""
        if model_average:
            with model_average.apply(exe):
                func(*func_args)
        else:
            func(*func_args)

    def is_due(batch_id, period):
        """Return True when ``period`` is positive and divides ``batch_id``.

        A non-positive period disables the action. Without this guard,
        ``batch_id % -1`` is always 0, so the documented "'-1' means never"
        would fire on every batch, and a period of 0 would raise
        ZeroDivisionError.
        """
        return period > 0 and batch_id % period == 0

    for pass_id in range(args.pass_num):
        total_loss = 0.0
        total_seq_error = 0.0
        # train a pass; batch ids are 1-based within each pass
        for batch_id, data in enumerate(train_reader(), 1):
            results = train_one_batch(data)
            total_loss += results[0]
            # results[2] is the second evaluator metric (see fetch_vars).
            total_seq_error += results[2]
            # training log
            if batch_id % args.log_period == 0:
                print("\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq err: %s" % (
                    time.time(), pass_id, batch_id,
                    total_loss / (batch_id * args.batch_size),
                    total_seq_error / (batch_id * args.batch_size)))
                sys.stdout.flush()

            # evaluate
            if is_due(batch_id, args.eval_period):
                run_with_average(test, pass_id, batch_id)

            # save model
            if is_due(batch_id, args.save_model_period):
                run_with_average(save_model, args, exe, pass_id, batch_id)


def main():
    """Entry point: parse CLI flags, echo them, then start CTC training."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    train(cli_args, data_reader=ctc_reader)


if __name__ == "__main__":
    main()