# api_train_v2.py -- last committed by Yu Yang.
from paddle.trainer_config_helpers import *
from paddle.trainer.PyDataProvider2 import dense_vector, integer_value
import paddle.v2 as paddle_v2
import numpy
import mnist_util


def train_reader():
    """Yield every sample produced by ``mnist_util.read_from_mnist``.

    Samples come from the raw MNIST training file. Because this is a
    generator, ``read_from_mnist`` is not called until the first sample is
    requested.
    """
    path = './data/raw_data/train'
    for sample in mnist_util.read_from_mnist(path):
        yield sample


def network_config():
    """Declare a 784-200-200-10 fully-connected classifier.

    Two hidden ``fc_layer``s of width 200 feed a 10-way softmax layer. A
    classification cost against the ``label`` input is registered as the
    network output. Layers are created in the same order as before (pixel
    input, hidden, hidden, softmax, label input), so auto-generated layer
    names stay stable.
    """
    layer = data_layer(name='pixel', size=784)
    for _ in range(2):
        layer = fc_layer(input=layer, size=200)
    probs = fc_layer(input=layer, size=10, act=SoftmaxActivation())
    label_input = data_layer(name='label', size=10)
    outputs(classification_cost(input=probs, label=label_input))


def event_handler(event):
    """Print pass id, batch id and cost after each finished training batch.

    Events that are not ``paddle_v2.trainer.CompleteTrainOneBatch`` are
    ignored.

    :param event: event object emitted by the trainer during training.
    """
    if isinstance(event, paddle_v2.trainer.CompleteTrainOneBatch):
        # A parenthesized single-argument print gives the same output under
        # Python 2's print statement and Python 3's print function.
        print("Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id,
                                              event.cost))


def main():
    """Train the MNIST classifier on CPU using a Python-side NAG update rule.

    Initializes Paddle, builds the model from ``network_config``, and fills
    every parameter with values drawn uniformly from [-1.0, 1.0). It then
    runs ``SGDTrainer`` with the custom ``nag`` update equation, reporting
    progress through ``event_handler``.
    """
    paddle_v2.init(use_gpu=False, trainer_count=1)
    model_config = parse_network_config(network_config)
    pool = paddle_v2.parameters.create(model_config)
    for param_name in pool.get_names():
        array = pool.get_parameter(param_name)
        array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)

    def nag(v, g, vel_t_1):
        """
        NAG (Nesterov accelerated gradient) update, applied in place.

        Paddle's C++ core does not implement this optimizer, so it is written
        here in Python, following https://arxiv.org/pdf/1212.0901v2.pdf
        eq. 6 and eq. 7:

            vel_t = mu * vel_{t-1} - e * g
            v_t   = v_{t-1} + mu**2 * vel_{t-1} - (1 + mu) * e * g

        :param v: parameter value; overwritten with the updated value
        :param g: parameter gradient
        :param vel_t_1: velocity from step t-1; overwritten with vel_t
        :return: None -- ``v`` and ``vel_t_1`` are updated in place
        """
        # NOTE(review): v, g and vel_t_1 are presumably same-shaped numpy
        # arrays supplied by SGDTrainer -- confirm.
        mu = 0.09  # momentum coefficient (NOTE(review): 0.9 is more typical; confirm)
        e = 0.0001  # step size (epsilon in the paper)

        vel_t = mu * vel_t_1 - e * g

        # Eq. 7 uses the *previous* velocity in the momentum term, so it must
        # be read before vel_t_1 is overwritten below. Using vel_t here would
        # scale the momentum term by mu**3 and inflate the gradient step.
        v[:] = v + (mu**2) * vel_t_1 - (1 + mu) * e * g
        vel_t_1[:] = vel_t

    trainer = paddle_v2.trainer.SGDTrainer(update_equation=nag)

    trainer.train(train_data_reader=train_reader,
                  topology=model_config,
                  parameters=pool,
                  event_handler=event_handler,
                  # TODO: batch size should move into the data reader.
                  batch_size=32,
                  # TODO: data_types will be removed; input types should be
                  # derived from the network topology instead.
                  data_types={
                      'pixel': dense_vector(784),
                      'label': integer_value(10)
                  })


if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()