train.py 9.1 KB
Newer Older
1
"""Trainer for OCR CTC or attention model."""
2 3 4
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wanghaoshuang 已提交
5
import paddle.fluid as fluid
6
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_data
W
whs 已提交
7
from utility import check_gpu
8
import paddle.fluid.profiler as profiler
9
from crnn_ctc_model import ctc_train_net
10 11
from attention_model import attention_train_net
import data_reader
W
wanghaoshuang 已提交
12 13 14
import argparse
import functools
import sys
W
wanghaoshuang 已提交
15
import time
16 17
import os
import numpy as np
W
wanghaoshuang 已提交
18 19 20 21

# Command-line interface. Each add_arg call registers one option as
# (name, type, default, help text) through utility.add_arguments.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   32,         "Minibatch size.")
add_arg('total_step',        int,   720000,    "The number of iterations. Zero or less means whole training set. More than 0 means the training set might be looped until # of iterations is reached.")
add_arg('log_period',        int,   1000,       "Log period.")
add_arg('save_model_period', int,   15000,      "Save model period. '-1' means never saving the model.")
add_arg('eval_period',       int,   15000,      "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir',    str,   "./models", "The directory the model to be saved to.")
add_arg('train_images',      str,   None,       "The directory of images to be used for training.")
add_arg('train_list',        str,   None,       "The list file of images to be used for training.")
add_arg('test_images',       str,   None,       "The directory of images to be used for test.")
# Help text fixed: this list file feeds the test reader, not training.
add_arg('test_list',         str,   None,       "The list file of images to be used for test.")
add_arg('model',    str,   "crnn_ctc",           "Which type of network to be used. 'crnn_ctc' or 'attention'")
# Help text fixed: train() accepts either a parameter file or a directory.
add_arg('init_model',        str,   None,       "The init model file or directory.")
add_arg('use_gpu',           bool,  True,      "Whether use GPU to train.")
add_arg('min_average_window',int,   10000,     "Min average window.")
add_arg('max_average_window',int,   12500,     "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window',    float, 0.15,      "Average window.")
add_arg('parallel',          bool,  False,     "Whether use parallel training.")
add_arg('profile',           bool,  False,      "Whether to use profiling.")
add_arg('skip_batch_num',    int,   0,          "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test',         bool,  False,      "Whether to skip test phase.")
# yapf: enable
43 44


45 46 47 48 49 50 51 52 53 54
def train(args):
    """Train an OCR model, periodically logging, evaluating and saving it.

    The network is chosen by ``args.model``: ``"crnn_ctc"`` selects the CTC
    network, any other value selects the attention network.

    Args:
        args: Parsed command-line namespace produced by the module-level
            ``parser``. The fields read here are ``model``, ``batch_size``,
            ``total_step``, ``log_period``, ``eval_period``,
            ``save_model_period``, ``save_model_dir``, ``train_images``,
            ``train_list``, ``test_images``, ``test_list``, ``init_model``,
            ``use_gpu``, ``parallel``, ``skip_batch_num`` and ``skip_test``.

    Notes:
        * A period (log / eval / save) that is zero or negative disables the
          corresponding action, matching the "'-1' means never" help text.
        * With ``total_step <= 0`` exactly one pass over the (non-cycling)
          training reader is performed.
    """

    def is_due(step, period):
        # A plain `step % period == 0` is always true for period == -1 and
        # raises ZeroDivisionError for period == 0, so require period > 0.
        return period > 0 and step % period == 0

    if args.model == "crnn_ctc":
        train_net = ctc_train_net
        get_feeder_data = get_ctc_feeder_data
    else:
        train_net = attention_train_net
        get_feeder_data = get_attention_feeder_data

    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()
    # define network
    sum_cost, error_evaluator, inference_program, model_average = train_net(
        args, data_shape, num_classes)

    # data reader; the training reader only cycles when a step budget is set
    train_reader = data_reader.train(
        args.batch_size,
        train_images_dir=args.train_images,
        train_list_file=args.train_list,
        cycle=args.total_step > 0,
        model=args.model)
    test_reader = data_reader.test(
        test_images_dir=args.test_images,
        test_list_file=args.test_list,
        model=args.model)

    # prepare environment
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Fixed seed when the 'ce_mode' environment variable is present.
    if 'ce_mode' in os.environ:
        fluid.default_startup_program().random_seed = 90

    exe.run(fluid.default_startup_program())

    # load init model; a non-directory path is split into dirname + filename
    if args.init_model is not None:
        model_dir = args.init_model
        model_file_name = None
        if not os.path.isdir(args.init_model):
            model_dir = os.path.dirname(args.init_model)
            model_file_name = os.path.basename(args.init_model)
        fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
        print("Init model from: %s." % args.init_model)

    train_exe = exe
    error_evaluator.reset(exe)
    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=bool(args.use_gpu), loss_name=sum_cost.name)

    fetch_vars = [sum_cost] + error_evaluator.metrics

    def train_one_batch(data):
        """Run one training step and return one scalar per fetched variable."""
        if args.parallel:
            var_names = [var.name for var in fetch_vars]
            results = train_exe.run(var_names,
                                    feed=get_feeder_data(data, place))
            # Sum each fetched array into a single value.
            results = [np.array(result).sum() for result in results]
        else:
            results = train_exe.run(feed=get_feeder_data(data, place),
                                    fetch_list=fetch_vars)
            results = [result[0] for result in results]
        return results

    def test(iter_num):
        """Evaluate on the test reader and print the sequence error."""
        error_evaluator.reset(exe)
        for data in test_reader():
            exe.run(inference_program, feed=get_feeder_data(data, place))
        _, test_seq_error = error_evaluator.eval(exe)
        print("\nTime: %s; Iter[%d]; Test seq error: %s.\n" %
              (time.time(), iter_num, str(test_seq_error[0])))

        # Note: The following logs are special for CE monitoring.
        # Other situations do not need to care about these logs.
        print("kpis\ttest_acc\t%f" % (1 - test_seq_error[0]))

    def save_model(args, exe, iter_num):
        """Save parameters to ``args.save_model_dir`` as ``model_<iter>``."""
        filename = "model_%05d" % iter_num
        fluid.io.save_params(
            exe, dirname=args.save_model_dir, filename=filename)
        print("Saved model to: %s/%s." % (args.save_model_dir, filename))

    iter_num = 0
    stop = False
    start_time = time.time()
    while not stop:
        total_loss = 0.0
        total_seq_error = 0.0
        batch_times = []
        # train a pass
        for data in train_reader():
            if args.total_step > 0 and iter_num == args.total_step + args.skip_batch_num:
                stop = True
                break
            if iter_num < args.skip_batch_num:
                print("Warm-up iteration")
            if iter_num == args.skip_batch_num:
                profiler.reset_profiler()
            start = time.time()
            results = train_one_batch(data)
            batch_times.append(time.time() - start)
            total_loss += results[0]
            # Third fetched value is accumulated as the sequence-error count
            # (presumably error_evaluator.metrics[1]; verify against the
            # model definition).
            total_seq_error += results[2]

            iter_num += 1
            # training log
            if is_due(iter_num, args.log_period):
                samples_logged = args.log_period * args.batch_size
                avg_loss = total_loss / samples_logged
                avg_seq_error = total_seq_error / samples_logged
                print("\nTime: %s; Iter[%d]; Avg loss: %.3f; Avg seq err: %.3f"
                      % (time.time(), iter_num, avg_loss, avg_seq_error))
                print("kpis\ttrain_cost\t%f" % avg_loss)
                print("kpis\ttrain_acc\t%f" % (1 - avg_seq_error))
                total_loss = 0.0
                total_seq_error = 0.0

            # evaluate
            if not args.skip_test and is_due(iter_num, args.eval_period):
                if model_average:
                    with model_average.apply(exe):
                        test(iter_num)
                else:
                    test(iter_num)

            # save model
            if is_due(iter_num, args.save_model_period):
                if model_average:
                    with model_average.apply(exe):
                        save_model(args, exe, iter_num)
                else:
                    save_model(args, exe, iter_num)

        # Without a step budget the reader does not cycle: one pass is the
        # whole training run, so do not start another pass.
        if args.total_step <= 0:
            stop = True

        end_time = time.time()
        print("kpis\ttrain_duration\t%f" % (end_time - start_time))

        # Benchmark output
        print('\nTotal examples (incl. warm-up): %d' %
              (iter_num * args.batch_size))
        # Postprocess benchmark data, dropping the warm-up batches.
        latencies = batch_times[args.skip_batch_num:]
        if not latencies:
            # np.percentile is undefined on an empty sample.
            print('No batches measured after warm-up; '
                  'skipping latency/fps summary.')
            continue
        latency_avg = np.average(latencies)
        latency_pc99 = np.percentile(latencies, 99)
        fpses = np.divide(args.batch_size, latencies)
        fps_avg = np.average(fpses)
        # The 1st percentile of fps corresponds to the 99th of latency.
        fps_pc99 = np.percentile(fpses, 1)
        print('average latency: %.5f s, 99pc latency: %.5f s' % (latency_avg,
                                                                 latency_pc99))
        print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg,
                                                                 fps_pc99))
W
wanghaoshuang 已提交
204

W
wanghaoshuang 已提交
205 206 207 208

def main():
    """Parse the command line, validate the device choice and run training.

    When ``--profile`` is set, training runs inside a profiler context:
    the CUDA profiler (CSV output to ``cuda_profiler.txt``) if GPU is used,
    otherwise the CPU profiler sorted by total time.
    """
    args = parser.parse_args()
    print_arguments(args)
    check_gpu(args.use_gpu)

    # Fast path: no profiling requested.
    if not args.profile:
        train(args)
        return

    if args.use_gpu:
        profiling_ctx = profiler.cuda_profiler("cuda_profiler.txt", 'csv')
    else:
        profiling_ctx = profiler.profiler("CPU", sorted_key='total')
    with profiling_ctx:
        train(args)
W
wanghaoshuang 已提交
219

W
wanghaoshuang 已提交
220

W
wanghaoshuang 已提交
221 222
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()