import paddle.v2 as paddle
import gzip
def softmax_regression(img):
    """Softmax regression: a single fully-connected layer mapping the
    input image directly to probabilities over the 10 digit classes."""
    return paddle.layer.fc(input=img,
                           size=10,
                           act=paddle.activation.Softmax())


def multilayer_perceptron(img):
    """MLP classifier: two ReLU hidden layers (128 and 64 units) followed
    by a 10-way softmax output, one unit per digit class."""
    # First hidden layer.
    fc1 = paddle.layer.fc(input=img, size=128, act=paddle.activation.Relu())
    # Second hidden layer.
    fc2 = paddle.layer.fc(input=fc1,
                          size=64,
                          act=paddle.activation.Relu())
    # Output layer: size 10 because there are 10 unique digits.
    return paddle.layer.fc(input=fc2,
                           size=10,
                           act=paddle.activation.Softmax())


def convolutional_neural_network(img):
    """CNN classifier: two conv+pool stages, one tanh FC layer, then a
    10-way softmax output."""
    # First conv/pool stage: 1 input channel -> 20 feature maps.
    stage1 = paddle.networks.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        num_channel=1,
        pool_size=2,
        pool_stride=2,
        act=paddle.activation.Tanh())
    # Second conv/pool stage: 20 channels -> 50 feature maps.
    stage2 = paddle.networks.simple_img_conv_pool(
        input=stage1,
        filter_size=5,
        num_filters=50,
        num_channel=20,
        pool_size=2,
        pool_stride=2,
        act=paddle.activation.Tanh())
    # Fully-connected layer on top of the conv features.
    hidden = paddle.layer.fc(input=stage2,
                             size=128,
                             act=paddle.activation.Tanh())
    # Softmax output: size 10, one unit per digit class.
    return paddle.layer.fc(input=hidden,
                           size=10,
                           act=paddle.activation.Softmax())


Y
Yu Yang 已提交
58
def main():
    """Train an MNIST digit classifier with PaddlePaddle v2 and run
    inference on the first 100 test images.

    Side effects: reads/writes the checkpoint file ``params.tar.gz`` in the
    current directory and prints training/testing progress to stdout.
    """
    paddle.init(use_gpu=False, trainer_count=1)

    # Define network topology: 784-dim flattened 28x28 pixel vector in,
    # integer class label in [0, 10) for supervision.
    images = paddle.layer.data(
        name='pixel', type=paddle.data_type.dense_vector(784))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))

    # Here we can build the prediction network in different ways. Please
    # choose one by uncommenting the corresponding line.
    predict = softmax_regression(images)
    #predict = multilayer_perceptron(images)
    #predict = convolutional_neural_network(images)

    cost = paddle.layer.classification_cost(input=predict, label=label)

    # Resume from a previous checkpoint when one exists; otherwise start
    # from freshly initialized parameters.
    try:
        with gzip.open('params.tar.gz', 'r') as f:
            parameters = paddle.parameters.Parameters.from_tar(f)
    except IOError:
        parameters = paddle.parameters.create(cost)

    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1 / 128.0,
        momentum=0.9,
        regularization=paddle.optimizer.L2Regularization(rate=0.0005 * 128))

    trainer = paddle.trainer.SGD(cost=cost,
                                 parameters=parameters,
                                 update_equation=optimizer)

    # One (pass_id, test cost, classification error) tuple per finished pass.
    lists = []

    def event_handler(event):
        # Called by the trainer after every batch and every pass.
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 1000 == 0:
                result = trainer.test(reader=paddle.batch(
                    paddle.dataset.mnist.test(), batch_size=256))

                print("Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics,
                    result.metrics))

                # Checkpoint the parameters so training can be resumed.
                with gzip.open('params.tar.gz', 'w') as f:
                    parameters.to_tar(f)

        elif isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=paddle.batch(
                paddle.dataset.mnist.test(), batch_size=128))
            print("Test with Pass %d, Cost %f, %s\n" % (
                event.pass_id, result.cost, result.metrics))
            lists.append((event.pass_id, result.cost,
                          result.metrics['classification_error_evaluator']))

    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.mnist.train(), buf_size=8192),
            batch_size=128),
        event_handler=event_handler,
        num_passes=100)

    # Find the pass with the lowest test cost.  min() avoids the full sort
    # the original did, and `entry` avoids shadowing the builtin `list`.
    best = min(lists, key=lambda entry: float(entry[1]))
    print('Best pass is %s, testing Avgcost is %s' % (best[0], best[1]))
    print('The classification accuracy is %.2f%%' %
          (100 - float(best[2]) * 100))

    # Collect the features of the first 100 test images (labels dropped)
    # as inference input.
    test_creator = paddle.dataset.mnist.test()
    test_data = []
    for item in test_creator():
        test_data.append((item[0], ))
        if len(test_data) == 100:
            break

    # output is a softmax layer. It returns probabilities.
    # Shape should be (100, 10)
    probs = paddle.infer(output=predict, parameters=parameters, input=test_data)
    print(probs.shape)


if __name__ == '__main__':
    main()