ctc_train.py 5.7 KB
Newer Older
W
wanghaoshuang 已提交
1
"""Trainer for OCR CTC model."""
W
wanghaoshuang 已提交
2
import paddle.fluid as fluid
3 4
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
W
wanghaoshuang 已提交
5
import ctc_reader
W
wanghaoshuang 已提交
6 7 8
import argparse
import functools
import sys
W
wanghaoshuang 已提交
9
import time
10 11
import os
import numpy as np
W
wanghaoshuang 已提交
12 13 14 15

# Command-line flags for the CTC trainer. The *_period flags are compared
# against the training-iteration counter inside train().
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   32,         "Minibatch size.")
add_arg('total_step',        int,   720000,    "Number of training iterations.")
add_arg('log_period',        int,   1000,       "Log period.")
add_arg('save_model_period', int,   15000,      "Save model period. '-1' means never saving the model.")
add_arg('eval_period',       int,   15000,      "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir',    str,   "./models", "The directory the model to be saved to.")
add_arg('init_model',        str,   None,       "The init model file or directory.")
add_arg('use_gpu',           bool,  True,      "Whether use GPU to train.")
add_arg('min_average_window',int,   10000,     "Min average window.")
add_arg('max_average_window',int,   12500,     "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window',    float, 0.15,      "Average window.")
add_arg('parallel',          bool,  False,     "Whether use parallel training.")
# yapf: enable


def _is_period_step(iter_num, period):
    """Return True if a periodic action should run at ``iter_num``.

    A non-positive ``period`` (e.g. the ``-1`` documented for
    ``--eval_period`` / ``--save_model_period``) disables the action. The
    explicit ``period > 0`` check is required: ``n % -1 == 0`` holds for
    every integer ``n``, so a bare modulo test would fire on every
    iteration, and ``n % 0`` raises ZeroDivisionError.
    """
    return period > 0 and iter_num % period == 0


def train(args, data_reader=ctc_reader):
    """OCR CTC training.

    Builds the CTC network with ``ctc_train_net`` and optionally restores
    parameters from ``args.init_model``. It then runs training steps until
    ``args.total_step`` iterations have been performed. Periodic actions:

    - every ``args.log_period`` iterations: print the average loss and
      sequence error;
    - every ``args.eval_period`` iterations: evaluate on the test reader;
    - every ``args.save_model_period`` iterations: save the parameters.

    A non-positive period disables the corresponding action. When
    ``ctc_train_net`` returns a truthy model-average object, evaluation and
    saving run with the averaged parameters applied.

    Args:
        args: parsed command-line arguments.
        data_reader: provider of ``num_classes()``, ``data_shape()``,
            ``train(...)`` and ``test(...)``; defaults to ``ctc_reader``.

    Raises:
        RuntimeError: if a full pass over the training reader yields no
            batches, which would otherwise make the loop spin forever.
    """
    # None is passed through to the reader for the image dirs / list files;
    # presumably the reader then uses its own defaults (confirm in reader).
    train_images = None
    train_list = None
    test_images = None
    test_list = None

    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()

    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)
    sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(
        images, label, args, num_classes)

    # data reader
    train_reader = data_reader.train(
        args.batch_size,
        train_images_dir=train_images,
        train_list_file=train_list)
    test_reader = data_reader.test(
        test_images_dir=test_images, test_list_file=test_list)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load init model
    if args.init_model is not None:
        model_dir = args.init_model
        model_file_name = None
        if not os.path.isdir(args.init_model):
            # A file path: load parameters from that single file.
            model_dir = os.path.dirname(args.init_model)
            model_file_name = os.path.basename(args.init_model)
        fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
        print("Init model from: %s." % args.init_model)

    train_exe = exe
    error_evaluator.reset(exe)
    if args.parallel:
        # Follow --use_gpu instead of forcing CUDA, so the parallel executor
        # runs on the same kind of device the parameters were set up on.
        train_exe = fluid.ParallelExecutor(
            use_cuda=args.use_gpu, loss_name=sum_cost.name)

    fetch_vars = [sum_cost] + error_evaluator.metrics
    fetch_var_names = [var.name for var in fetch_vars]

    def train_one_batch(data):
        """Run one training step; return the fetched values as scalars."""
        feed = get_feeder_data(data, place)
        if args.parallel:
            results = train_exe.run(fetch_var_names, feed=feed)
            # The parallel executor's outputs are reduced with a sum
            # (presumably one entry per device - confirm).
            return [np.array(result).sum() for result in results]
        results = exe.run(feed=feed, fetch_list=fetch_vars)
        return [result[0] for result in results]

    def test(iter_num):
        """Evaluate on the whole test reader and print the sequence error."""
        error_evaluator.reset(exe)
        for data in test_reader():
            exe.run(inference_program, feed=get_feeder_data(data, place))
        _, test_seq_error = error_evaluator.eval(exe)
        print("\nTime: %s; Iter[%d]; Test seq error: %s.\n" % (
            time.time(), iter_num, str(test_seq_error[0])))

    def save_model(args, exe, iter_num):
        """Save parameters to <save_model_dir>/model_<iter_num>."""
        filename = "model_%05d" % iter_num
        fluid.io.save_params(
            exe, dirname=args.save_model_dir, filename=filename)
        print("Saved model to: %s/%s." % (args.save_model_dir, filename))

    def with_average(action, *action_args):
        """Run ``action`` with averaged parameters applied, if enabled."""
        if model_average:
            with model_average.apply(exe):
                action(*action_args)
        else:
            action(*action_args)

    iter_num = 0
    # Accumulate across passes. Each log then covers exactly
    # args.log_period batches, matching the divisor used below.
    total_loss = 0.0
    total_seq_error = 0.0
    while True:
        batches_in_pass = 0
        # train a pass
        for data in train_reader():
            batches_in_pass += 1
            iter_num += 1
            if iter_num > args.total_step:
                return
            results = train_one_batch(data)
            total_loss += results[0]
            # results[2] is the second evaluator metric. The log below reports
            # it as the sequence error (assumed error count - TODO confirm).
            total_seq_error += results[2]

            # training log
            if _is_period_step(iter_num, args.log_period):
                num_samples = args.log_period * args.batch_size
                print("\nTime: %s; Iter[%d]; Avg Warp-CTC loss: %.3f; Avg seq err: %.3f" % (
                    time.time(), iter_num,
                    total_loss / num_samples,
                    total_seq_error / num_samples))
                sys.stdout.flush()
                total_loss = 0.0
                total_seq_error = 0.0

            # evaluate
            if _is_period_step(iter_num, args.eval_period):
                with_average(test, iter_num)

            # save model
            if _is_period_step(iter_num, args.save_model_period):
                with_average(save_model, args, exe, iter_num)

        if batches_in_pass == 0:
            raise RuntimeError(
                "Training reader yielded no batches; cannot reach total_step.")

def main():
    """Script entry point: parse the flags, echo them, then run training."""
    parsed_args = parser.parse_args()
    print_arguments(parsed_args)
    train(parsed_args, data_reader=ctc_reader)


if __name__ == "__main__":
    main()