# train.py -- training script for the PaddlePaddle CTC example.
import logging
import argparse
import gzip
import paddle.v2 as paddle
from model import Model
from data_provider import get_file_list, AsciiDic, ImageDataset

# Command-line options for the training script.  The parser is built at import
# time.  Three path/shape options are required; the rest have defaults.
parser = argparse.ArgumentParser(description="PaddlePaddle CTC example")
parser.add_argument(
    '--image_shape',
    type=str,
    required=True,
    help="image's shape, format is like '173,46'")
parser.add_argument(
    '--train_file_list',
    type=str,
    required=True,
    help='path of the file which contains path list of train image files')
parser.add_argument(
    '--test_file_list',
    type=str,
    required=True,
    help='path of the file which contains path list of test image files')
parser.add_argument(
    '--batch_size', type=int, default=5, help='size of a mini-batch')
parser.add_argument(
    '--model_output_prefix',
    type=str,
    default='model.ctc',
    help='prefix of path for model to store (default: ./model.ctc)')
parser.add_argument(
    '--trainer_count', type=int, default=4, help='number of training threads')
parser.add_argument(
    '--save_period_by_batch',
    type=int,
    default=150,
    help='save model to disk every N batches')
parser.add_argument(
    '--num_passes',
    type=int,
    default=10,
    # argparse expands %(default)s, so the help text always reports the real
    # default (the old text claimed 1 while the default was 10).
    help='number of passes to train (default: %(default)s)')

# Parsed once at import time; main() reads its options from this module-level
# namespace.
args = parser.parse_args()

def main():
    """Train the model built by ``Model`` using the module-level ``args``.

    Prints the run configuration, builds the dataset, model and trainer, and
    then calls ``trainer.train``.  During training the event handler:

    * prints pass/batch/cost/metrics on every batch whose id is a multiple
      of 100;
    * on every batch whose id is a positive multiple of
      ``args.save_period_by_batch``, runs ``trainer.test`` on the test reader,
      prints the result and writes the parameters to a gzip-compressed tar
      file named from ``args.model_output_prefix``, the pass id and batch id.
    """
    image_shape = tuple(int(dim) for dim in args.image_shape.split(','))

    # Echo the effective configuration, one "name value" pair per line.
    settings = (
        ('image_shape', image_shape),
        ('batch_size', args.batch_size),
        ('train_file_list', args.train_file_list),
        ('test_file_list', args.test_file_list),
    )
    for label, value in settings:
        print('%s %s' % (label, value))

    # No inference file list is supplied when training (third argument).
    dataset = ImageDataset(
        get_file_list(args.train_file_list),
        get_file_list(args.test_file_list),
        None,
        fixed_shape=image_shape,
        is_infer=False)

    paddle.init(use_gpu=True, trainer_count=args.trainer_count)

    model = Model(AsciiDic().size(), image_shape, is_infer=False)
    params = paddle.parameters.create(model.cost)
    trainer = paddle.trainer.SGD(
        cost=model.cost,
        parameters=params,
        update_equation=paddle.optimizer.Momentum(momentum=0),
        extra_layers=model.eval)

    def event_handler(event):
        # Only end-of-iteration events are handled.
        if not isinstance(event, paddle.event.EndIteration):
            return

        batch_id = event.batch_id
        if batch_id % 100 == 0:
            print("Pass %d, batch %d, Samples %d, Cost %f, Eval %s" % (
                event.pass_id, batch_id, batch_id * args.batch_size,
                event.cost, event.metrics))

        # Checkpoint only on positive multiples of the save period.
        if batch_id <= 0 or batch_id % args.save_period_by_batch != 0:
            return

        test_reader = paddle.batch(dataset.test, batch_size=10)
        result = trainer.test(
            reader=test_reader, feeding={'image': 0,
                                         'label': 1})
        print("Test %d-%d, Cost %f, Eval %s" % (
            event.pass_id, batch_id, result.cost, result.metrics))

        checkpoint_path = "{}-pass-{}-batch-{}-test.tar.gz".format(
            args.model_output_prefix, event.pass_id, batch_id)
        with gzip.open(checkpoint_path, 'w') as out:
            params.to_tar(out)

    shuffled_train = paddle.reader.shuffle(dataset.train, buf_size=500)
    train_reader = paddle.batch(shuffled_train, batch_size=args.batch_size)
    trainer.train(
        reader=train_reader,
        feeding={'image': 0,
                 'label': 1},
        event_handler=event_handler,
        num_passes=args.num_passes)


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()