train.py 2.7 KB
Newer Older
D
dzhwinter 已提交
1
import os
2 3 4
import paddle.v2 as paddle
import paddle.v2.dataset.uci_housing as uci_housing

D
dzhwinter 已提交
5
# Use the GPU whenever WITH_GPU is set to anything other than '0' (unset counts as '0').
with_gpu = os.environ.get('WITH_GPU', '0') != '0'
D
dongzhihong 已提交
6

G
gongweibao 已提交
7

8 9
def main():
    """Train a one-layer linear regressor on UCI Housing, then run inference.

    Steps:
      1. Initialise PaddlePaddle; GPU use follows the module-level
         ``with_gpu`` flag.
      2. Build the network: a 13-dim dense input ``x`` feeds one fully
         connected unit with linear activation (``y_predict``), scored by
         square-error cost against the 1-dim label ``y``.
      3. Serialize the inference topology (prediction layer only) to
         ``inference_topology.pkl``.
      4. Train for 30 passes with batch size 2. Print the cost every 100
         batches, evaluate on the test set at the end of every pass, and
         save parameters to ``params_pass_<N>.tar`` every 10th pass.
      5. Predict the first 5 test samples with the in-memory trained
         parameters and print label vs. prediction.
    """
    # init
    paddle.init(use_gpu=with_gpu, trainer_count=1)

    # network config: 13 input features -> a single linear output unit.
    x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
    y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear())
    y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
    cost = paddle.layer.square_error_cost(input=y_predict, label=y)

    # Save the inference topology to protobuf. Only the prediction layer is
    # included, so the label/cost layers are not needed at inference time.
    inference_topology = paddle.topology.Topology(layers=y_predict)
    with open("inference_topology.pkl", 'wb') as f:
        inference_topology.serialize_for_inference(f)

    # create parameters
    parameters = paddle.parameters.create(cost)

    # create optimizer (momentum=0: no momentum term is applied)
    optimizer = paddle.optimizer.Momentum(momentum=0)

    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # Map each data layer name to its position in a sample tuple:
    # features at index 0, label at index 1.
    feeding = {'x': 0, 'y': 1}

    def event_handler(event):
        """Print training progress; checkpoint and evaluate at end of pass."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print("Pass %d, Batch %d, Cost %f" % (
                    event.pass_id, event.batch_id, event.cost))

        if isinstance(event, paddle.event.EndPass):
            if event.pass_id % 10 == 0:
                # The parameter archive is a tar file, i.e. binary data, so
                # it must be written in binary mode (text mode can corrupt
                # it on platforms that translate newlines).
                with open('params_pass_%d.tar' % event.pass_id, 'wb') as f:
                    trainer.save_parameter_to_tar(f)
            result = trainer.test(
                reader=paddle.batch(uci_housing.test(), batch_size=2),
                feeding=feeding)
            print("Test %d, Cost %f" % (event.pass_id, result.cost))

    # training
    trainer.train(
        reader=paddle.batch(
            paddle.reader.shuffle(uci_housing.train(), buf_size=500),
            batch_size=2),
        feeding=feeding,
        event_handler=event_handler,
        num_passes=30)

    # inference: take the first few test samples, splitting features from labels.
    num_samples = 5
    test_reader = uci_housing.test()
    test_data = []
    test_label = []
    for item in test_reader():
        # paddle.infer expects one tuple per sample holding only the inputs.
        test_data.append((item[0], ))
        test_label.append(item[1])
        if len(test_data) == num_samples:
            break

    # load parameters from tar file.
    # users can remove the comments and change the model name
    # with open('params_pass_20.tar', 'rb') as f:
    #     parameters = paddle.parameters.Parameters.from_tar(f)

    probs = paddle.infer(
        output_layer=y_predict, parameters=parameters, input=test_data)

    for label, prob in zip(test_label, probs):
        print("label=" + str(label[0]) + ", predict=" + str(prob[0]))

G
gongweibao 已提交
81

82 83
if __name__ == "__main__":
    # Only train/infer when run as a script, not when imported.
    main()