ctc_train.py 7.7 KB
Newer Older
W
wanghaoshuang 已提交
1
"""Trainer for OCR CTC model."""
W
wanghaoshuang 已提交
2
import paddle.fluid as fluid
3
import paddle.fluid.profiler as profiler
4 5
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
W
wanghaoshuang 已提交
6
import ctc_reader
W
wanghaoshuang 已提交
7 8 9
import argparse
import functools
import sys
W
wanghaoshuang 已提交
10
import time
11 12
import os
import numpy as np
W
wanghaoshuang 已提交
13 14 15 16

# Command-line interface. ``add_arg`` binds the project helper
# ``utility.add_arguments`` to this parser so each option is declared on a
# single aligned line; judging from the calls below the positional order is
# presumably (name, type, default, help) -- confirm in utility.add_arguments.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   32,         "Minibatch size.")
add_arg('total_step',        int,   720000,    "The number of iterations. Zero or less means whole training set. More than 0 means the training set might be looped until # of iterations is reached.")
add_arg('log_period',        int,   1000,       "Log period.")
add_arg('save_model_period', int,   15000,      "Save model period. '-1' means never saving the model.")
add_arg('eval_period',       int,   15000,      "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir',    str,   "./models", "The directory the model to be saved to.")
add_arg('init_model',        str,   None,       "The init model file or directory.")
add_arg('use_gpu',           bool,  True,      "Whether use GPU to train.")
add_arg('min_average_window',int,   10000,     "Min average window.")
add_arg('max_average_window',int,   12500,     "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window',    float, 0.15,      "Average window.")
add_arg('parallel',          bool,  False,     "Whether use parallel training.")
add_arg('profile',           bool,  False,      "Whether to use profiling.")
add_arg('skip_batch_num',    int,   0,          "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test',         bool,  False,      "Whether to skip test phase.")
# yapf: enable


def train(args, data_reader=ctc_reader):
    """OCR CTC training.

    Builds the CRNN-CTC network via ``ctc_train_net``, optionally restores
    parameters from ``args.init_model``, then runs the training loop with
    periodic logging, test-set evaluation and checkpointing. Per-batch
    wall-clock times are recorded and summarized (latency / fps) at the end.

    Args:
        args: Parsed command-line arguments from the module-level parser.
        data_reader: Reader module exposing ``num_classes()``,
            ``data_shape()``, ``train(...)`` and ``test(...)``. Defaults to
            ``ctc_reader``.

    Stopping: with ``args.total_step > 0`` the training reader is cycled and
    training stops once ``total_step`` iterations beyond the
    ``skip_batch_num`` warm-up batches have run. Otherwise a single pass over
    the training set is made.

    Periodic actions (``log_period``, ``eval_period``, ``save_model_period``)
    are disabled by a non-positive period.
    """
    # Dataset locations are not configurable from the command line here; None
    # is passed through to the reader (NOTE(review): presumably the reader
    # then falls back to its built-in defaults -- confirm in ctc_reader).
    train_images = None
    train_list = None
    test_images = None
    test_list = None

    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()

    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)
    sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(
        images, label, args, num_classes)

    # data reader: the training set is cycled only when a step budget is set.
    train_reader = data_reader.train(
        args.batch_size,
        train_images_dir=train_images,
        train_list_file=train_list,
        cycle=args.total_step > 0)
    test_reader = data_reader.test(
        test_images_dir=test_images, test_list_file=test_list)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # load init model: a directory is passed as ``dirname``; a file path is
    # split into its parent directory and the file name (``filename``).
    if args.init_model is not None:
        model_dir = args.init_model
        model_file_name = None
        if not os.path.isdir(args.init_model):
            model_dir = os.path.dirname(args.init_model)
            model_file_name = os.path.basename(args.init_model)
        fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
        print("Init model from: %s." % args.init_model)

    error_evaluator.reset(exe)
    train_exe = exe
    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=bool(args.use_gpu), loss_name=sum_cost.name)

    # Fetched values are [loss] + evaluator metrics; the loop below reads
    # index 0 (loss) and index 2 (sequence error).
    fetch_vars = [sum_cost] + error_evaluator.metrics
    fetch_names = [var.name for var in fetch_vars]

    def train_one_batch(data):
        """Run one optimization step and return the fetched scalar values."""
        feed = get_feeder_data(data, place)
        if args.parallel:
            results = train_exe.run(fetch_names, feed=feed)
            # NOTE(review): presumably one value per device; summed here.
            return [np.array(result).sum() for result in results]
        results = train_exe.run(feed=feed, fetch_list=fetch_vars)
        return [result[0] for result in results]

    def test(iter_num):
        """Evaluate on the whole test set and print the sequence error."""
        error_evaluator.reset(exe)
        for data in test_reader():
            exe.run(inference_program, feed=get_feeder_data(data, place))
        _, test_seq_error = error_evaluator.eval(exe)
        print("\nTime: %s; Iter[%d]; Test seq error: %s.\n" % (
            time.time(), iter_num, str(test_seq_error[0])))

    def save_model(args, exe, iter_num):
        """Save parameters to ``save_model_dir/model_<iter_num, 5 digits>``."""
        filename = "model_%05d" % iter_num
        fluid.io.save_params(
            exe, dirname=args.save_model_dir, filename=filename)
        print("Saved model to: %s/%s." % (args.save_model_dir, filename))

    def with_model_average(fn, *fn_args):
        """Call ``fn`` with averaged parameters applied when averaging is on."""
        if model_average:
            with model_average.apply(exe):
                fn(*fn_args)
        else:
            fn(*fn_args)

    iter_num = 0
    stop = False
    while not stop:
        total_loss = 0.0
        total_seq_error = 0.0
        batch_times = []

        # train a pass
        for data in train_reader():
            if (args.total_step > 0 and
                    iter_num == args.total_step + args.skip_batch_num):
                stop = True
                break
            if iter_num < args.skip_batch_num:
                print("Warm-up iteration")
            if iter_num == args.skip_batch_num:
                # Profiling statistics start at the first post-warm-up batch.
                profiler.reset_profiler()
            start = time.time()
            results = train_one_batch(data)
            batch_times.append(time.time() - start)

            total_loss += results[0]
            total_seq_error += results[2]
            iter_num += 1

            # training log
            if _period_reached(iter_num, args.log_period):
                avg_denom = args.log_period * args.batch_size
                print("\nTime: %s; Iter[%d]; Avg Warp-CTC loss: %.3f; Avg seq err: %.3f" % (
                    time.time(), iter_num,
                    total_loss / avg_denom,
                    total_seq_error / avg_denom))
                sys.stdout.flush()
                total_loss = 0.0
                total_seq_error = 0.0

            # evaluate
            if not args.skip_test and _period_reached(iter_num,
                                                      args.eval_period):
                with_model_average(test, iter_num)

            # save model
            if _period_reached(iter_num, args.save_model_period):
                with_model_average(save_model, args, exe, iter_num)

        # Without a step budget the reader is not cycled, so one pass covers
        # the whole training set; without this, ``while`` would restart passes
        # indefinitely. A pass that yields no batches would likewise spin.
        if args.total_step <= 0 or not batch_times:
            stop = True

        # Postprocess benchmark data
        _report_benchmark(batch_times, args.skip_batch_num, args.batch_size,
                          iter_num)


def _period_reached(iter_num, period):
    """Return True when ``period`` is positive and divides ``iter_num``.

    A non-positive period disables the periodic action. Without this guard,
    ``iter_num % -1 == 0`` is always True and ``iter_num % 0`` raises
    ZeroDivisionError.
    """
    return period > 0 and iter_num % period == 0


def _report_benchmark(batch_times, skip_batch_num, batch_size, iter_num):
    """Print throughput/latency statistics for the timed training batches.

    The first ``skip_batch_num`` entries of ``batch_times`` are warm-up and
    are excluded. "fps for 99pc latency" is the 1st percentile of per-batch
    fps, i.e. approximately the throughput at the 99th-percentile latency.
    """
    # Benchmark output
    print('\nTotal examples (incl. warm-up): %d' % (iter_num * batch_size))
    latencies = batch_times[skip_batch_num:]
    if not latencies:
        # Percentiles of an empty sample are undefined; report and bail out.
        print('No batches after warm-up were timed; skipping latency stats.')
        return
    latency_avg = np.average(latencies)
    latency_pc99 = np.percentile(latencies, 99)
    fpses = np.divide(batch_size, latencies)
    fps_avg = np.average(fpses)
    fps_pc99 = np.percentile(fpses, 1)
    print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
                                                             latency_pc99))
    print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg,
                                                             fps_pc99))


def main():
    """Parse the command line and run training, optionally under a profiler.

    With ``--profile``, ``train`` runs inside ``profiler.cuda_profiler``
    (configured with "cuda_profiler.txt" and 'csv') when ``--use_gpu`` is set,
    and inside ``profiler.profiler("CPU", sorted_key='total')`` otherwise.
    Without ``--profile`` it runs unprofiled.
    """
    args = parser.parse_args()
    print_arguments(args)

    if not args.profile:
        train(args, data_reader=ctc_reader)
        return

    # Pick the profiler context first so training is launched from one place.
    if args.use_gpu:
        profile_ctx = profiler.cuda_profiler("cuda_profiler.txt", 'csv')
    else:
        profile_ctx = profiler.profiler("CPU", sorted_key='total')
    with profile_ctx:
        train(args, data_reader=ctc_reader)


if __name__ == "__main__":
    main()