ctc_train.py 3.6 KB
Newer Older
W
wanghaoshuang 已提交
1 2
"""Trainer for OCR CTC model."""
import paddle.v2 as paddle
W
wanghaoshuang 已提交
3
import paddle.fluid as fluid
W
wanghaoshuang 已提交
4
import dummy_reader
W
wanghaoshuang 已提交
5
import ctc_reader
W
wanghaoshuang 已提交
6
import argparse
W
wanghaoshuang 已提交
7
from load_model import load_param
W
wanghaoshuang 已提交
8 9 10 11 12 13 14 15 16
import functools
import sys
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net

# Command-line options. `train` reads batch_size, pass_num, log_period and
# device directly; the remaining options are presumably consumed by
# ctc_train_net, which receives the whole namespace (TODO confirm).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',     int,   32,     "Minibatch size.")
add_arg('pass_num',       int,   100,    "# of training epochs.")
# log_period is counted in training batches (see train()).
add_arg('log_period',     int,   1000,   "Log period.")
add_arg('learning_rate',  float, 1.0e-3, "Learning rate.")
add_arg('l2',             float, 0.0004, "L2 regularizer.")
add_arg('max_clip',       float, 10.0,   "Max clip threshold.")
add_arg('min_clip',       float, -10.0,  "Min clip threshold.")
add_arg('momentum',       float, 0.9,    "Momentum.")
add_arg('rnn_hidden_size',int,   200,    "Hidden size of rnn layers.")
# The two literals below are concatenated implicitly, so the separating
# spaces must live inside the strings.
add_arg('device',         int,   0,      "Device id. '-1' means running on CPU "
                                         "while '0' means GPU-0.")
# yapf: enable

def load_parameter(place,
                   name_map_path='./name.map',
                   model_dir='./data/model/results_without_avg_window/pass-00000/'):
    """Overwrite parameters in the global scope with values from a saved model.

    Args:
        place: fluid place (CPU/CUDA) the parameter tensors are set on.
        name_map_path: first argument forwarded to ``load_param``; presumably
            a file mapping saved parameter names to fluid names (TODO confirm
            against load_model).
        model_dir: second argument forwarded to ``load_param``; the directory
            holding the saved parameter values.

    Raises:
        ValueError: if a parameter returned by ``load_param`` has no
            corresponding variable in ``fluid.global_scope()``.
    """
    # Assumes load_param returns a dict-like of name -> value accepted by
    # tensor.set (TODO confirm against load_model).
    params = load_param(name_map_path, model_dir)
    scope = fluid.global_scope()
    for name in params:
        var = scope.find_var(name)
        # find_var yields None for unknown names; fail with a clear message
        # instead of an AttributeError on None.
        if var is None:
            raise ValueError("Parameter '%s' loaded from '%s' was not found "
                             "in the global scope." % (name, model_dir))
        var.get_tensor().set(params[name], place)


W
wanghaoshuang 已提交
37 38 39 40 41 42 43
def train(args, data_reader=dummy_reader):
    """OCR CTC training"""
    num_classes = data_reader.num_classes()
    data_shape = data_reader.data_shape()
    # define network
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int32', lod_level=1)
44
    sum_cost, error_evaluator = ctc_train_net(images, label, args, num_classes)
W
wanghaoshuang 已提交
45 46 47 48 49 50 51 52 53 54
    # data reader
    train_reader = data_reader.train(args.batch_size)
    test_reader = data_reader.test()
    # prepare environment
    place = fluid.CPUPlace()
    if args.device >= 0:
        place = fluid.CUDAPlace(args.device)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

55
    #load_parameter(place)
W
wanghaoshuang 已提交
56

W
wanghaoshuang 已提交
57
    inference_program = fluid.io.get_inference_program(error_evaluator)
W
wanghaoshuang 已提交
58

W
wanghaoshuang 已提交
59 60
    for pass_id in range(args.pass_num):
        error_evaluator.reset(exe)
W
wanghaoshuang 已提交
61 62 63
        batch_id = 1
        total_loss = 0.0
        total_seq_error = 0.0
W
wanghaoshuang 已提交
64 65
        # train a pass
        for data in train_reader():
W
wanghaoshuang 已提交
66
            batch_loss, _, batch_seq_error = exe.run(
W
wanghaoshuang 已提交
67 68
                fluid.default_main_program(),
                feed=get_feeder_data(data, place),
69
                fetch_list=[sum_cost] + error_evaluator.metrics)
W
wanghaoshuang 已提交
70 71 72 73 74 75 76
            total_loss += batch_loss[0]
            total_seq_error += batch_seq_error[0]
            if batch_id % 10 == 1:
                print '.',
                sys.stdout.flush()
            if batch_id % args.log_period == 1:
                print "\nPass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
77
                    pass_id, batch_id, total_loss / (batch_id * args.batch_size), total_seq_error / (batch_id * args.batch_size))
W
wanghaoshuang 已提交
78 79 80 81 82 83 84
                sys.stdout.flush()
            batch_id += 1

        # evaluate model on test data
        error_evaluator.reset(exe)
        for data in test_reader():
            exe.run(inference_program, feed=get_feeder_data(data, place))
W
wanghaoshuang 已提交
85 86 87
        _, test_seq_error = error_evaluator.eval(exe)
        print "\nEnd pass[%d]; Test seq error: %s.\n" % (
            pass_id, str(test_seq_error[0]))
W
wanghaoshuang 已提交
88 89 90 91

def main():
    """Parse the command line, print the arguments, and train on ctc_reader."""
    args = parser.parse_args()
    print_arguments(args)
    train(args, data_reader=ctc_reader)


if __name__ == "__main__":
    main()