api_train_v2.py 1.6 KB
Newer Older
Q
qiaolongfei 已提交
1 2
import paddle.v2 as paddle

Y
Yu Yang 已提交
3 4 5 6 7 8 9 10 11 12 13
import mnist_util


def train_reader(train_file='./data/raw_data/train'):
    """Yield training samples read from an MNIST-format file.

    This is a generator function, so nothing is read until the returned
    generator is iterated. The trainer can call it with no arguments to
    get a fresh pass over the default training file.

    Args:
        train_file: Path to the training data. Defaults to the location
            this script has always used, so existing zero-argument callers
            are unaffected.

    Yields:
        Each item produced by ``mnist_util.read_from_mnist(train_file)``,
        unchanged.
    """
    # Explicit loop rather than ``yield from`` so the module stays
    # importable under Python 2, which the rest of this file targets.
    for item in mnist_util.read_from_mnist(train_file):
        yield item


def main():
    """Train a small fully connected classifier with the PaddlePaddle v2 API.

    Builds a 784 -> 200 -> 200 -> 10 network with a softmax output over
    the ``pixel`` and ``label`` data layers, attaches a classification
    cost, and trains it with an Adam update rule (learning rate 0.01) on
    samples from ``train_reader``. Every 100th batch of each pass, the
    pass id, batch id, cost and metrics are printed.
    """
    paddle.init(use_gpu=False, trainer_count=1)

    # Define the network topology.
    images = paddle.layer.data(
        name='pixel', type=paddle.data_type.dense_vector(784))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))
    hidden1 = paddle.layer.fc(input=images, size=200)
    hidden2 = paddle.layer.fc(input=hidden1, size=200)
    inference = paddle.layer.fc(input=hidden2,
                                size=10,
                                act=paddle.activation.Softmax())
    cost = paddle.layer.classification_cost(input=inference, label=label)

    parameters = paddle.parameters.create(cost)

    adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)

    def event_handler(event):
        """Print progress on every 100th EndIteration event; ignore all others."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                # Single-argument print(...) prints the same under Python 2
                # and is also valid Python 3 syntax.
                print("Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics))

    trainer = paddle.trainer.SGD(update_equation=adam_optimizer)

    trainer.train(
        train_data_reader=train_reader,
        cost=cost,
        parameters=parameters,
        event_handler=event_handler,
        # TODO: batch size should be moved into the data reader.
        batch_size=32,
        # NOTE(review): presumably maps each data layer's name to the index
        # of its field within the items yielded by train_reader; confirm
        # against mnist_util.read_from_mnist.
        reader_dict={images.name: 0,
                     label.name: 1})
Y
Yu Yang 已提交
50 51 52 53


if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()