ctc_train.py 3.9 KB
Newer Older
W
wanghaoshuang 已提交
1
"""Trainer for OCR CTC model."""
W
wanghaoshuang 已提交
2
import paddle.fluid as fluid
W
wanghaoshuang 已提交
3
import dummy_reader
W
wanghaoshuang 已提交
4
import ctc_reader
W
wanghaoshuang 已提交
5
import argparse
W
wanghaoshuang 已提交
6
from load_model import load_param
W
wanghaoshuang 已提交
7 8 9 10
import functools
import sys
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
W
wanghaoshuang 已提交
11
import time
W
wanghaoshuang 已提交
12 13 14 15 16

# Command-line interface. Every flag is registered through ``add_arguments``
# (imported from ``utility``) as (name, type, default, help text); the module
# docstring doubles as the parser description.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',         int,   32,     "Minibatch size.")
add_arg('pass_num',           int,   100,    "# of training epochs.")
add_arg('log_period',         int,   1000,   "Log period.")
add_arg('learning_rate',      float, 1.0e-3, "Learning rate.")
add_arg('l2',                 float, 0.0004, "L2 regularizer.")
add_arg('max_clip',           float, 10.0,   "Max clip threshold.")
add_arg('min_clip',           float, -10.0,  "Min clip threshold.")
add_arg('momentum',           float, 0.9,    "Momentum.")
add_arg('rnn_hidden_size',    int,   200,    "Hidden size of rnn layers.")
# The two help fragments are concatenated by the parser, so the first one must
# end with a space to keep the words apart.
add_arg('device',             int,   0,      "Device id. '-1' means running on CPU "
                                             "while '0' means GPU-0.")
add_arg('min_average_window', int,   10000,  "Min average window.")
add_arg('max_average_window', int,   15625,  "Max average window.")
add_arg('average_window',     float, 0.15,   "Average window.")
# NOTE(review): how a bool flag is parsed from the command line depends on
# ``add_arguments``, which is not visible here -- confirm before relying on it.
add_arg('parallel',           bool,  False,  "Whether use parallel training.")
# yapf: enable

W
wanghaoshuang 已提交
33 34 35 36 37 38 39
def load_parameter(place,
                   name_map='./name.map',
                   model_dir='./data/model/results_without_avg_window/pass-00000/'):
    """Copy parameter values returned by ``load_param`` into the global scope.

    Args:
        place: Device place forwarded to each tensor's ``set`` call.
        name_map: Path passed to ``load_param`` as its first argument.
            Defaults to the location that used to be hard-coded here.
        model_dir: Path passed to ``load_param`` as its second argument.
            Defaults to the directory that used to be hard-coded here.
    """
    params = load_param(name_map, model_dir)
    # The global scope does not change while iterating, so look it up once.
    scope = fluid.global_scope()
    for name in params:
        # Each name is expected to match a variable that already exists in the
        # scope, i.e. the startup program must have run before this call.
        tensor = scope.find_var(name).get_tensor()
        tensor.set(params[name], place)


W
wanghaoshuang 已提交
40 41 42 43 44 45 46
def train(args, data_reader=dummy_reader):
    """Train the OCR CTC model and report the test error after every pass.

    Args:
        args: Parsed command-line arguments. This function itself reads
            ``batch_size``, ``device``, ``pass_num`` and ``log_period``; the
            whole object is also forwarded to ``ctc_train_net``.
        data_reader: Module-like object providing ``num_classes()``,
            ``data_shape()``, ``train(batch_size)`` and ``test()``.
    """
    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()
    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)
    sum_cost, error_evaluator, inference_program, model_average = \
        ctc_train_net(images, label, args, num_classes)

    # data reader
    train_reader = data_reader.train(args.batch_size)
    test_reader = data_reader.test()
    # prepare environment
    place = fluid.CPUPlace()
    if args.device >= 0:
        place = fluid.CUDAPlace(args.device)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    # Warm start from stored parameters is disabled by default:
    #load_parameter(place)

    for pass_id in range(args.pass_num):
        error_evaluator.reset(exe)
        total_loss = 0.0
        total_seq_error = 0.0
        # train a pass; batch ids are 1-based so that they equal the number of
        # batches accumulated so far.
        for batch_id, data in enumerate(train_reader(), 1):
            batch_loss, _, batch_seq_error = exe.run(
                fluid.default_main_program(),
                feed=get_feeder_data(data, place),
                fetch_list=[sum_cost] + error_evaluator.metrics)
            total_loss += batch_loss[0]
            total_seq_error += batch_seq_error[0]
            # Progress dot on batches 1, 101, 201, ...
            if (batch_id - 1) % 100 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()
            # Written as (batch_id - 1) % period == 0 rather than
            # batch_id % period == 1: both agree for period >= 2, but the
            # latter is never true for period == 1 and would silence logging.
            if (batch_id - 1) % args.log_period == 0:
                # Running averages are per sample and assume every batch so
                # far held ``args.batch_size`` samples.
                samples_seen = batch_id * args.batch_size
                print("\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
                    time.time(), pass_id, batch_id,
                    total_loss / samples_seen,
                    total_seq_error / samples_seen))
                sys.stdout.flush()

        # Evaluate on the test set inside the model-average context.
        # NOTE(review): presumably this swaps in the averaged parameters for
        # the duration of the block -- confirm against the ModelAverage docs.
        with model_average.apply(exe):
            error_evaluator.reset(exe)
            for data in test_reader():
                exe.run(inference_program, feed=get_feeder_data(data, place))
            _, test_seq_error = error_evaluator.eval(exe)

            print("\nEnd pass[%d]; Test seq error: %s.\n" % (
                pass_id, str(test_seq_error[0])))
W
wanghaoshuang 已提交
91 92 93 94

def main():
    """Parse the command-line flags, echo them, and train on ``ctc_reader``."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    train(cli_args, data_reader=ctc_reader)


if __name__ == "__main__":
    main()