# api_train.py -- committed by Yu Yang.
import py_paddle.swig_paddle as api
import paddle.trainer.config_parser
import numpy as np


def init_parameter(network, low=-1.0, high=1.0):
    """Fill every parameter of ``network`` with uniform random values.

    Args:
        network: an ``api.GradientMachine``; each of its parameters has its
            PARAMETER_VALUE buffer overwritten.
        low: lower bound of the uniform distribution (default -1.0, matching
            the previous hard-coded value).
        high: upper bound of the uniform distribution (default 1.0, matching
            the previous hard-coded value).
    """
    assert isinstance(network, api.GradientMachine)
    for each_param in network.getParameters():
        assert isinstance(each_param, api.Parameter)
        # NOTE(review): toNumpyArrayInplace presumably returns a view over the
        # parameter's value buffer, so writes through it update the parameter.
        array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
        assert isinstance(array, np.ndarray)
        # Slice assignment writes into the existing buffer (never rebinds
        # ``array``) and draws one independent sample per element whatever
        # the array's rank; it replaces a per-element Python loop.
        array[...] = np.random.uniform(low, high, size=array.shape)


def main(num_passes=100):
    """Build a local CPU trainer from ``simple_mnist_network.py`` and run passes.

    Initializes Paddle on CPU with 4 trainer threads, parses the network
    config, creates a GradientMachine with randomly initialized parameters,
    attaches a local ParameterUpdater, and runs ``num_passes`` pass cycles.

    Args:
        num_passes: number of startPass()/finishPass() cycles to run
            (default 100, the previously hard-coded value).
    """
    api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores

    # NOTE(review): the config path is relative, so it is resolved against
    # the current working directory rather than this script's location.
    config = paddle.trainer.config_parser.parse_config(
        'simple_mnist_network.py', '')

    opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
    # This optimizer is only used to query which parameter buffer types it
    # needs; those types are passed on when creating the GradientMachine.
    temp_optimizer = api.ParameterOptimizer.create(opt_config)
    enable_types = temp_optimizer.getParameterTypes()

    m = api.GradientMachine.createFromConfigProto(
        config.model_config, api.CREATE_MODE_NORMAL, enable_types)
    assert isinstance(m, api.GradientMachine)
    init_parameter(network=m)

    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
    assert isinstance(updater, api.ParameterUpdater)
    updater.init(m)

    m.start()
    try:
        for _ in xrange(num_passes):
            updater.startPass()
            # TODO: nothing is fed between startPass() and finishPass() yet,
            # so these passes presumably do not train anything.
            updater.finishPass()
    finally:
        # Pair m.start() with m.finish() even if a pass raises.
        m.finish()


if __name__ == '__main__':
    # Script entry point: run training only when executed directly,
    # not when this module is imported.
    main()