api_train.py 2.7 KB
Newer Older
Y
Yu Yang 已提交
1
import py_paddle.swig_paddle as api
Y
Yu Yang 已提交
2 3
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
Y
Yu Yang 已提交
4 5
import paddle.trainer.config_parser
import numpy as np
Y
Yu Yang 已提交
6
from mnist_util import read_from_mnist
Y
Yu Yang 已提交
7 8 9 10 11 12 13 14 15 16


def init_parameter(network):
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
        assert isinstance(array, np.ndarray)
        for i in xrange(len(array)):
            array[i] = np.random.uniform(-1.0, 1.0)
Y
Yu Yang 已提交
17 18


Y
Yu Yang 已提交
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
def generator_to_batch(generator, batch_size):
    ret_val = list()
    for each_item in generator:
        ret_val.append(each_item)
        if len(ret_val) == batch_size:
            yield ret_val
            ret_val = list()
    if len(ret_val) != 0:
        yield ret_val


def input_order_converter(generator):
    for each_item in generator:
        yield each_item['pixel'], each_item['label']


Y
Yu Yang 已提交
35 36
def main():
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
Y
Yu Yang 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50
    config = paddle.trainer.config_parser.parse_config(
        'simple_mnist_network.py', '')

    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()

    m = api.GradientMachine.createFromConfigProto(
        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
    assert isinstance(m, api.GradientMachine)
    init_parameter(network=m)
    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)
    updater.init(m)
Y
Yu Yang 已提交
51 52 53 54 55 56

    converter = DataProviderConverter(
        input_types=[dp.dense_vector(784), dp.integer_value(10)])

    train_file = './data/raw_data/train'

Y
Yu Yang 已提交
57 58 59 60
    m.start()

    for _ in xrange(100):
        updater.startPass()
Y
Yu Yang 已提交
61
        outArgs = api.Arguments.createArguments(0)
Y
Yu Yang 已提交
62 63
        train_data_generator = input_order_converter(
            read_from_mnist(train_file))
Y
Yu Yang 已提交
64
        for batch_id, data_batch in enumerate(
Y
Yu Yang 已提交
65
                generator_to_batch(train_data_generator, 2048)):
Y
Yu Yang 已提交
66
            trainRole = updater.startBatch(len(data_batch))
Y
Yu Yang 已提交
67

Y
Yu Yang 已提交
68
            def updater_callback(param):
Y
Yu Yang 已提交
69 70 71
                updater.update(param)

            m.forwardBackward(
Y
Yu Yang 已提交
72
                converter(data_batch), outArgs, trainRole, updater_callback)
Y
Yu Yang 已提交
73 74 75 76 77 78 79

            cost_vec = outArgs.getSlotValue(0)
            cost_vec = cost_vec.copyToNumpyMat()
            cost = cost_vec.sum() / len(data_batch)
            print 'Batch id', batch_id, 'with cost=', cost
            updater.finishBatch(cost)

Y
Yu Yang 已提交
80 81
        updater.finishPass()

Y
Yu Yang 已提交
82
    m.finish()
Y
Yu Yang 已提交
83 84 85 86


if __name__ == '__main__':
    main()