ctc_train.py 6.9 KB
Newer Older
W
wanghaoshuang 已提交
1
"""Trainer for OCR CTC model."""
W
wanghaoshuang 已提交
2
import paddle.fluid as fluid
3 4
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
W
wanghaoshuang 已提交
5
import ctc_reader
W
wanghaoshuang 已提交
6 7 8
import argparse
import functools
import sys
W
wanghaoshuang 已提交
9
import time
10 11
import os
import numpy as np
W
wanghaoshuang 已提交
12 13 14 15

# Command-line interface. Every option is registered through `add_arg`, which
# is `add_arguments` (from utility) pre-bound to this parser; the positional
# values are: option name, type, default, help text.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',        int,   32,         "Minibatch size.")
add_arg('pass_num',          int,   100,        "Number of training epochs.")
add_arg('log_period',        int,   1000,       "Log period.")
add_arg('save_model_period', int,   15000,      "Save model period.")
add_arg('eval_period',       int,   15000,      "Evaluate period.")
add_arg('save_model_dir',    str,   "./models", "The directory the model to be saved to.")
add_arg('init_model',        str,   None,       "The init model file or directory.")
add_arg('learning_rate',     float, 1.0e-3,    "Learning rate.")
add_arg('l2',                float, 0.0004,    "L2 regularizer.")
add_arg('max_clip',          float, 10.0,      "Max clip threshold.")
add_arg('min_clip',          float, -10.0,     "Min clip threshold.")
add_arg('momentum',          float, 0.9,       "Momentum.")
add_arg('rnn_hidden_size',   int,   200,       "Hidden size of rnn layers.")
add_arg('device',            int,   0,         "Device id. '-1' means running on CPU "
                                               "while '0' means GPU-0.")
add_arg('min_average_window',int,   10000,     "Min average window.")
add_arg('max_average_window',int,   15625,     "Max average window.")
add_arg('average_window',    float, 0.15,      "Average window.")
add_arg('parallel',          bool,  True,     "Whether use parallel training.")
# Adjacent string literals are concatenated, so each first fragment ends with
# a space to keep the two sentences separated in the rendered help.
add_arg('train_images',      str,   None,    "The directory of training images. "
        "None means using the default training images of reader.")
add_arg('train_list',        str,   None,    "The list file of training images. "
        "None means using the default train_list file of reader.")
add_arg('test_images',      str,    None,    "The directory of test images. "
        "None means using the default test images of reader.")
add_arg('test_list',        str,    None,   "The list file of test images. "
        "None means using the default test_list file of reader.")
add_arg('num_classes',      int,    None,      "The number of classes. "
        "None means using the default num_classes from reader.")
W
wanghaoshuang 已提交
45 46
# yapf: enable

W
wanghaoshuang 已提交
47

48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
def train_one_batch(args, exe, data, fetch_vars, data_place):
    """Run one training step and return one value per fetched variable.

    Args:
        args: Parsed arguments; only ``args.parallel`` is read here.
        exe: Executor used to run the step. When ``args.parallel`` is set it
            is called as ``exe.run(names, feed_dict=...)``, otherwise as
            ``exe.run(feed=..., fetch_list=...)``.
        data: One batch, passed on to ``get_feeder_data``.
        fetch_vars: Variables to fetch; each must expose ``.name``.
        data_place: Place passed on to ``get_feeder_data``.

    Returns:
        A list with one entry per fetch variable: the sum over the fetched
        array in parallel mode, or the first element of the fetched result
        otherwise.
    """
    fetch_names = [fetch_var.name for fetch_var in fetch_vars]
    feed = get_feeder_data(data, data_place)

    if not args.parallel:
        fetched = exe.run(feed=feed, fetch_list=fetch_vars)
        return [item[0] for item in fetched]

    # NOTE(review): presumably the parallel executor returns one value per
    # device, hence the sum -- confirm against the executor in use.
    fetched = exe.run(fetch_names, feed_dict=feed)
    return [np.array(item).sum() for item in fetched]

def test(args, test_reader, exe, inference_program, error_evaluator, pass_id, batch_id, data_place):
    """Evaluate on the test reader and print the sequence error.

    Resets ``error_evaluator``, runs ``inference_program`` on every batch
    produced by ``test_reader()``, then prints the first element of the
    second value returned by ``error_evaluator.eval(exe)``.

    Args:
        args: Parsed arguments (not read by this function).
        test_reader: Callable returning an iterable of test batches.
        exe: Executor used for both inference and evaluator state.
        inference_program: Program to run for each test batch.
        error_evaluator: Evaluator exposing ``reset(exe)`` and ``eval(exe)``.
        pass_id: Current pass index, used only in the log line.
        batch_id: Current batch index, used only in the log line.
        data_place: Place passed on to ``get_feeder_data``.
    """
    error_evaluator.reset(exe)
    for data in test_reader():
        exe.run(inference_program, feed=get_feeder_data(data, data_place))
    _, test_seq_error = error_evaluator.eval(exe)
    # Parenthesised single-argument form: prints the same text whether
    # `print` is the Python 2 statement or the Python 3 function.
    print("\nTime: %s; Pass[%d]-batch[%d]; Test seq error: %s.\n" % (
        time.time(), pass_id, batch_id, str(test_seq_error[0])))
W
wanghaoshuang 已提交
67

68 69 70 71 72 73 74
def save_model(args, exe, pass_id, batch_id):
    """Save the current parameters under ``args.save_model_dir``.

    The file name is ``model_<pass_id zero-padded to 5 digits>_<batch_id>``.

    Args:
        args: Parsed arguments; only ``args.save_model_dir`` is read.
        exe: Executor passed to ``fluid.io.save_params``.
        pass_id: Current pass index.
        batch_id: Current batch index.
    """
    filename = "model_%05d_%d" % (pass_id, batch_id)
    fluid.io.save_params(exe, dirname=args.save_model_dir, filename=filename)
    # Parenthesised single-argument form: prints the same text whether
    # `print` is the Python 2 statement or the Python 3 function.
    print("Saved model to: %s/%s." % (args.save_model_dir, filename))


def _run_with_model_average(model_average, exe, func, *func_args):
    """Call ``func(*func_args)``, inside ``model_average.apply(exe)`` if set.

    When ``model_average`` is falsy the function is called directly.
    """
    if model_average:
        with model_average.apply(exe):
            return func(*func_args)
    return func(*func_args)


def train(args, data_reader=ctc_reader):
    """OCR CTC training.

    Builds the network with ``ctc_train_net``, optionally loads initial
    parameters, then runs ``args.pass_num`` passes over the training reader,
    logging, evaluating and saving every ``log_period`` / ``eval_period`` /
    ``save_model_period`` batches.

    Args:
        args: Parsed command-line arguments.
        data_reader: Object providing ``num_classes()``, ``data_shape()``,
            ``train(...)`` and ``test(...)``; defaults to ``ctc_reader``.
    """
    num_classes = data_reader.num_classes() if args.num_classes is None else args.num_classes
    data_shape = data_reader.data_shape()
    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1)
    sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(
        images, label, args, num_classes)

    # data reader
    train_reader = data_reader.train(
        args.batch_size,
        train_images_dir=args.train_images,
        train_list_file=args.train_list)
    test_reader = data_reader.test(
        test_images_dir=args.test_images,
        test_list_file=args.test_list)

    # prepare environment: a negative device id selects the CPU
    use_cuda = args.device >= 0
    place = fluid.CUDAPlace(args.device) if use_cuda else fluid.CPUPlace()

    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    error_evaluator.reset(exe)

    # load init model; `init_model` may name a directory or a single file
    if args.init_model is not None:
        model_dir = args.init_model
        model_file_name = None
        if not os.path.isdir(args.init_model):
            model_dir = os.path.dirname(args.init_model)
            model_file_name = os.path.basename(args.init_model)
        fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
        print("Init model from: %s." % args.init_model)

    train_exe = exe
    if args.parallel:
        # Follow the device choice instead of always requesting CUDA, so
        # `--device -1` keeps parallel training on the CPU as well.
        train_exe = fluid.ParallelExecutor(
            use_cuda=use_cuda,
            loss_name=sum_cost.name)

    for pass_id in range(args.pass_num):
        batch_id = 1
        total_loss = 0.0
        total_seq_error = 0.0
        # train a pass
        for data in train_reader():
            results = train_one_batch(args, train_exe, data,
                                      [sum_cost] + error_evaluator.metrics,
                                      place)
            total_loss += results[0]
            # NOTE(review): index 2 assumes the second entry of
            # `error_evaluator.metrics` is the sequence error -- confirm
            # against ctc_train_net.
            total_seq_error += results[2]
            # training log
            if batch_id % args.log_period == 0:
                num_samples = batch_id * args.batch_size
                print("\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq err: %s" % (
                    time.time(),
                    pass_id, batch_id, total_loss / num_samples, total_seq_error / num_samples))
                sys.stdout.flush()

            # evaluate
            if batch_id % args.eval_period == 0:
                _run_with_model_average(
                    model_average, exe, test,
                    args, test_reader, exe, inference_program,
                    error_evaluator, pass_id, batch_id, place)

            # save model
            if batch_id % args.save_model_period == 0:
                _run_with_model_average(
                    model_average, exe, save_model,
                    args, exe, pass_id, batch_id)

            batch_id += 1
W
wanghaoshuang 已提交
154

W
wanghaoshuang 已提交
155 156 157 158

def main():
    """Parse the command line, echo the arguments and start training."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    train(cli_args, data_reader=ctc_reader)
W
wanghaoshuang 已提交
160 161 162

# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()