# train.py
import paddle.v2 as paddle
import data_provider
import vgg_ssd_net
import os, sys
import gzip
import tarfile
from config.pascal_voc_conf import cfg


def train(train_file_list, dev_file_list, data_args, init_model_path,
          log_period=1, checkpoint_dir='checkpoints'):
    """Build the network from vgg_ssd_net, train it, and checkpoint every pass.

    Args:
        train_file_list: file list passed to ``data_provider.train``.
        dev_file_list: file list passed to ``data_provider.test``; used to
            evaluate the model at the end of each pass.
        data_args: settings object forwarded to the data providers.
        init_model_path: path to a gzipped parameter tar used to initialise
            the parameters, or ``None`` to keep freshly created parameters.
        log_period: print a full cost/mAP line every ``log_period`` batches
            and a progress dot otherwise. The default of 1 logs every batch.
        checkpoint_dir: directory that receives
            ``params_pass_%05d.tar.gz`` after each pass; created if missing.

    Raises:
        ValueError: if ``log_period`` is smaller than 1.
        IOError: if ``init_model_path`` is given but is not an existing file.
    """
    if log_period < 1:
        raise ValueError('log_period must be >= 1, got %r' % (log_period, ))

    optimizer = paddle.optimizer.Momentum(
        momentum=cfg.TRAIN.MOMENTUM,
        learning_rate=cfg.TRAIN.LEARNING_RATE,
        regularization=paddle.optimizer.L2Regularization(
            rate=cfg.TRAIN.L2REGULARIZATION),
        learning_rate_decay_a=cfg.TRAIN.LEARNING_RATE_DECAY_A,
        learning_rate_decay_b=cfg.TRAIN.LEARNING_RATE_DECAY_B,
        learning_rate_schedule=cfg.TRAIN.LEARNING_RATE_SCHEDULE)

    # net_conf('train') returns the training cost and the detection output;
    # the latter is attached to the trainer via extra_layers below.
    cost, detect_out = vgg_ssd_net.net_conf('train')

    parameters = paddle.parameters.create(cost)
    if init_model_path is not None:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.isfile(init_model_path):
            raise IOError('Invalid model: %s' % init_model_path)
        # Close the gzip handle once the parameters have been loaded.
        with gzip.open(init_model_path) as model_file:
            parameters.init_from_tar(model_file)

    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 extra_layers=[detect_out],
                                 update_equation=optimizer)

    # Maps each data layer name to its position in the reader's sample tuple.
    feeding = {'image': 0, 'bbox': 1}

    train_reader = paddle.batch(
        data_provider.train(data_args, train_file_list),
        batch_size=cfg.TRAIN.BATCH_SIZE)  # generate a batch image each time

    dev_reader = paddle.batch(
        data_provider.test(data_args, dev_file_list),
        batch_size=cfg.TRAIN.BATCH_SIZE)

    # Create the checkpoint directory up front so the first EndPass save
    # does not fail after a whole pass of training.
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % log_period == 0:
                print("\nPass %d, Batch %d, TrainCost %f, Detection mAP=%f" %
                      (event.pass_id, event.batch_id, event.cost,
                       event.metrics['detection_evaluator']))
            else:
                sys.stdout.write('.')
                sys.stdout.flush()

        if isinstance(event, paddle.event.EndPass):
            checkpoint_path = os.path.join(
                checkpoint_dir, 'params_pass_%05d.tar.gz' % event.pass_id)
            with gzip.open(checkpoint_path, 'w') as f:
                trainer.save_parameter_to_tar(f)

            # Evaluate on the dev set after each pass.
            result = trainer.test(reader=dev_reader, feeding=feeding)
            print("\nTest with Pass %d, TestCost: %f, Detection mAP=%g" %
                  (event.pass_id, result.cost,
                   result.metrics['detection_evaluator']))

    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        num_passes=cfg.TRAIN.NUM_PASS,
        feeding=feeding)


if __name__ == "__main__":
    # Script entry point with a hard-coded run configuration: GPU training
    # with 4 trainers, image/list files under ./data, and initial parameters
    # loaded from ./vgg/vgg_model.tar.gz.
    paddle.init(use_gpu=True, trainer_count=4)
    data_args = data_provider.Settings(
        data_dir='./data',
        label_file='label_list',
        resize_h=cfg.IMG_HEIGHT,
        resize_w=cfg.IMG_WIDTH,
        # NOTE(review): presumably per-channel pixel means subtracted by the
        # data provider; channel order not visible here — confirm.
        mean_value=[104, 117, 124])
    train(
        train_file_list='./data/trainval.txt',
        dev_file_list='./data/test.txt',
        data_args=data_args,
        init_model_path='./vgg/vgg_model.tar.gz')