api_train_v2.py 2.1 KB
Newer Older
Y
Yu Yang 已提交
1
import numpy
Q
qiaolongfei 已提交
2 3 4
import paddle.v2 as paddle
from paddle.trainer.PyDataProvider2 import dense_vector, integer_value

Y
Yu Yang 已提交
5 6 7 8 9 10 11 12 13 14 15
import mnist_util


def train_reader(train_file='./data/raw_data/train'):
    """Yield training samples produced by ``mnist_util.read_from_mnist``.

    This is a generator function: ``mnist_util.read_from_mnist`` is not
    invoked until the first sample is requested from the returned iterator.

    :param train_file: path handed to ``mnist_util.read_from_mnist``. The
        default is the previously hard-coded ``./data/raw_data/train``, so
        existing zero-argument callers behave exactly as before.
    """
    for item in mnist_util.read_from_mnist(train_file):
        yield item


def main():
    """Build a small fully connected classifier and train it with Adam.

    Network: 784-dim dense input -> fc(200) -> fc(200) -> fc(10, softmax),
    trained with a classification cost against a 10-class integer label.
    Every parameter is overwritten with numpy uniform random values in
    [-1.0, 1.0) before training. Samples come from ``train_reader``.
    """
    paddle.init(use_gpu=False, trainer_count=1)

    # define network topology
    images = paddle.layer.data(name='pixel', data_type=dense_vector(784))
    label = paddle.layer.data(name='label', data_type=integer_value(10))
    hidden1 = paddle.layer.fc(input=images, size=200)
    hidden2 = paddle.layer.fc(input=hidden1, size=200)
    inference = paddle.layer.fc(input=hidden2,
                                size=10,
                                act=paddle.activation.Softmax())
    cost = paddle.layer.classification_cost(input=inference, label=label)

    parameters = paddle.parameters.create(cost)
    # Replace the initial values of every parameter with uniform noise.
    for param_name in parameters.keys():
        array = parameters.get(param_name)
        array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
        parameters.set(parameter_name=param_name, value=array)

    adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)

    def event_handler(event):
        """Print cost and one weight's mean on each EndIteration event."""
        # All other event types are ignored.
        if not isinstance(event, paddle.event.EndIteration):
            return
        # NOTE(review): '___fc_2__.w0' appears to be the auto-generated weight
        # name of the third fc layer (`inference`); confirm it still matches
        # if layers are added, removed, or given explicit names.
        para = parameters.get('___fc_2__.w0')
        # Call form of print: same output on Python 2, valid on Python 3.
        print("Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
            event.pass_id, event.batch_id, event.cost, para.mean()))

    trainer = paddle.trainer.SGD(update_equation=adam_optimizer)

    trainer.train(train_data_reader=train_reader,
                  topology=cost,
                  parameters=parameters,
                  event_handler=event_handler,
                  # TODO: batch size should be moved into the data reader.
                  batch_size=32,
                  # TODO: data_types will be removed; it should be derived
                  # from the network topology instead.
                  data_types={
                      'pixel': images.data_type,
                      'label': label.data_type
                  })


if "__main__" == __name__:
    # Start training only when run as a script, never on import.
    main()