api_train.py 2.2 KB
Newer Older
Y
Yu Yang 已提交
1
import py_paddle.swig_paddle as api
Y
Yu Yang 已提交
2 3
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
Y
Yu Yang 已提交
4 5
import paddle.trainer.config_parser
import numpy as np
Y
Yu Yang 已提交
6
from mnist_util import read_from_mnist
Y
Yu Yang 已提交
7 8 9 10 11 12 13 14 15 16


def init_parameter(network, low=-1.0, high=1.0):
    """Fill every parameter value of ``network`` with uniform random numbers.

    For each parameter, the PARAMETER_VALUE buffer is fetched as a numpy
    array via ``toNumpyArrayInplace()`` and overwritten element-wise.

    :param network: gradient machine whose parameters are initialized.
    :type network: api.GradientMachine
    :param low: lower bound of the uniform distribution (default -1.0).
    :param high: upper bound of the uniform distribution (default 1.0).
    """
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        # NOTE(review): presumably a view onto the parameter's own storage,
        # given the name toNumpyArrayInplace -- confirm against py_paddle.
        array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
        assert isinstance(array, np.ndarray)
        # One vectorized draw instead of a per-element Python loop. Assigning
        # through `array[...]` writes into the existing buffer rather than
        # rebinding the name, and `size=array.shape` gives every element its
        # own sample. For a 1-D buffer this consumes the global RNG in the
        # same order as the old element-by-element loop.
        array[...] = np.random.uniform(low, high, size=array.shape)
Y
Yu Yang 已提交
17 18


Y
Yu Yang 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
def generator_to_batch(generator, batch_size):
    """Group items from ``generator`` into lists of ``batch_size`` items.

    Each full batch is yielded as soon as it reaches exactly ``batch_size``
    items. Any leftover items are yielded as a final, shorter batch. No empty
    batch is ever yielded.
    """
    pending = []
    for item in generator:
        pending.append(item)
        if len(pending) != batch_size:
            continue
        yield pending
        pending = []
    # Flush the incomplete tail batch, if any.
    if pending:
        yield pending


def input_order_converter(generator):
    """Turn each sample dict into a ``(pixel, label)`` tuple.

    The tuple order matches the ``input_types`` list handed to
    ``DataProviderConverter`` in ``main``: dense pixel vector first, then
    the integer label.
    """
    for record in generator:
        pixel = record['pixel']
        label = record['label']
        yield pixel, label


Y
Yu Yang 已提交
35 36
def main(num_passes=100, batch_size=128,
         train_file='./data/raw_data/train',
         config_file='simple_mnist_network.py'):
    """Build the MNIST network from its trainer config and run the pass loop.

    :param num_passes: number of passes over the training file (default 100).
    :param batch_size: maximum samples per mini-batch (default 128).
    :param train_file: path handed to ``read_from_mnist``
        (default ``'./data/raw_data/train'``).
    :param config_file: trainer config parsed by ``parse_config``
        (default ``'simple_mnist_network.py'``).
    """
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores

    config = paddle.trainer.config_parser.parse_config(config_file, '')

    # A throwaway optimizer is created only to ask which parameter buffer
    # types it needs; those types are passed on when the gradient machine
    # is created.
    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    m = api.GradientMachine.createFromConfigProto(
        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
    assert isinstance(m, api.GradientMachine)
    init_parameter(network=m)

    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)
    updater.init(m)

    # Slot order must match input_order_converter's output: (pixel, label).
    converter = DataProviderConverter(
        input_types=[dp.dense_vector(784), dp.integer_value(10)])

    m.start()
    try:
        for _ in xrange(num_passes):
            updater.startPass()
            train_data_generator = input_order_converter(
                read_from_mnist(train_file))
            for data_batch in generator_to_batch(train_data_generator,
                                                 batch_size):
                # TODO(review): inArgs is built but never consumed -- no
                # forward/backward pass or parameter update is performed
                # yet, so this loop only exercises data conversion.
                inArgs = converter(data_batch)
            updater.finishPass()
    finally:
        # Pair every m.start() with m.finish(), even if reading or
        # converting the data raises part-way through a pass.
        m.finish()
Y
Yu Yang 已提交
70 71 72 73


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()