api_train_v2.py 1.9 KB
Newer Older
Y
Yu Yang 已提交
1
import numpy
Q
qiaolongfei 已提交
2 3
import paddle.v2 as paddle

Y
Yu Yang 已提交
4 5

def main():
Y
Yu Yang 已提交
6
    paddle.init(use_gpu=False, trainer_count=1)
Q
qiaolongfei 已提交
7 8

    # define network topology
9 10 11 12
    images = paddle.layer.data(
        name='pixel', type=paddle.data_type.dense_vector(784))
    label = paddle.layer.data(
        name='label', type=paddle.data_type.integer_value(10))
Q
qiaolongfei 已提交
13 14 15 16 17 18 19
    hidden1 = paddle.layer.fc(input=images, size=200)
    hidden2 = paddle.layer.fc(input=hidden1, size=200)
    inference = paddle.layer.fc(input=hidden2,
                                size=10,
                                act=paddle.activation.Softmax())
    cost = paddle.layer.classification_cost(input=inference, label=label)

Q
qiaolongfei 已提交
20
    parameters = paddle.parameters.create(cost)
Y
Yu Yang 已提交
21
    for param_name in parameters.keys():
Y
Yu Yang 已提交
22
        array = parameters.get(param_name)
Y
Yu Yang 已提交
23
        array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
Y
Yu Yang 已提交
24
        parameters.set(parameter_name=param_name, value=array)
Y
Yu Yang 已提交
25

Q
qiaolongfei 已提交
26
    adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
Y
Yu Yang 已提交
27 28

    def event_handler(event):
Y
Yu Yang 已提交
29
        if isinstance(event, paddle.event.EndIteration):
Q
qiaolongfei 已提交
30
            para = parameters.get('___fc_2__.w0')
Y
Yu Yang 已提交
31 32
            print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
                event.pass_id, event.batch_id, event.cost, para.mean())
Y
Yu Yang 已提交
33

Y
Yu Yang 已提交
34 35
        else:
            pass
36

Y
Yu Yang 已提交
37
    trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
Y
Yu Yang 已提交
38

Y
Yu Yang 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51
    trainer.train(
        train_reader=paddle.reader.batched(
            paddle.reader.shuffle(paddle.dataset.mnist.train_creator(),
                                  buf_size=8192), batch_size=32),
        topology=cost,
        parameters=parameters,
        event_handler=event_handler,
        data_types=[  # data_types will be removed, It should be in
            # network topology
            ('pixel', images.type),
            ('label', label.type)],
        reader_dict={'pixel': 0, 'label': 1}
    )
Y
Yu Yang 已提交
52 53 54 55


if __name__ == '__main__':
    main()