# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

import paddle.v2 as paddle

from api_v2_resnet import resnet_cifar10


def main():
    """Train a CIFAR-10 image classifier with the PaddlePaddle v2 API.

    Builds a depth-32 ``resnet_cifar10`` feature extractor over flat
    3 * 32 * 32 image vectors and adds a 10-way softmax classifier. It then
    trains with momentum SGD for 5 passes over the shuffled CIFAR-10
    training set, evaluating on the CIFAR-10 test set after every pass.
    Progress is written to stdout.
    """
    datadim = 3 * 32 * 32
    classdim = 10
    # One mini-batch size shared by the optimizer settings and both readers;
    # previously the literal 128 was repeated in five places.
    batch_size = 128
    # Maps data-layer names to their position in each reader sample
    # (image first, label second); used for both training and testing.
    feeding = {'image': 0, 'label': 1}

    # PaddlePaddle init
    paddle.init(use_gpu=False, trainer_count=1)

    image = paddle.layer.data(
        name="image", type=paddle.data_type.dense_vector(datadim))

    # Add neural network config
    # option 1. resnet
    net = resnet_cifar10(image, depth=32)
    # option 2. vgg
    # net = vgg_bn_drop(image)
    # NOTE(review): vgg_bn_drop is not imported in this file; enabling
    # option 2 also requires importing it.

    out = paddle.layer.fc(input=net,
                          size=classdim,
                          act=paddle.activation.Softmax())

    lbl = paddle.layer.data(
        name="label", type=paddle.data_type.integer_value(classdim))
    cost = paddle.layer.classification_cost(input=out, label=lbl)

    # Create parameters
    parameters = paddle.parameters.create(cost)

    # Create optimizer
    # The learning rate is divided by, and the L2 rate multiplied by, the
    # batch size (same values as the original 0.1 / 128.0 and 0.0002 * 128).
    # NOTE(review): learning_rate_decay_b looks like 50000 samples x 100
    # passes, while num_passes below is 5; confirm the decay horizon.
    momentum_optimizer = paddle.optimizer.Momentum(
        momentum=0.9,
        regularization=paddle.optimizer.L2Regularization(
            rate=0.0002 * batch_size),
        learning_rate=0.1 / batch_size,
        learning_rate_decay_a=0.1,
        learning_rate_decay_b=50000 * 100,
        learning_rate_schedule='discexp',
        batch_size=batch_size)

    # End batch and end pass event handler
    def event_handler(event):
        """Report training progress and run a test pass after each pass."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                # Single-argument print() behaves identically on Python 2
                # and Python 3, unlike the Python-2-only print statement.
                print("\nPass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.EndPass):
            # `trainer` is a closure variable bound below, before
            # trainer.train() can deliver any event to this handler.
            result = trainer.test(
                reader=paddle.batch(
                    paddle.dataset.cifar.test10(), batch_size=batch_size),
                feeding=feeding)
            print("\nTest with Pass %d, %s" % (event.pass_id, result.metrics))

    # Create trainer
    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=momentum_optimizer)
    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.cifar.train10(), buf_size=50000),
            batch_size=batch_size),
        num_passes=5,
        event_handler=event_handler,
        feeding=feeding)


if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()