# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from py_paddle import swig_paddle, DataProviderWrapperConverter
import paddle.trainer.config_parser
from paddle.trainer.PyDataProviderWrapper import DenseSlot, IndexSlot
import numpy
import util


def init_params(params):
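    """Randomly initialize every parameter with values drawn uniformly from [-1, 1).

    getBuf(PARAMETER_VALUE) exposes the parameter's value buffer as a
    swig_paddle.Vector, and toNumpyArrayInplace() returns a numpy view over
    that same memory, so the element-wise writes below update the parameter
    directly.
    """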
    def init_param(p):
        assert isinstance(p, swig_paddle.Parameter)
        val = p.getBuf(swig_paddle.PARAMETER_VALUE)
        assert isinstance(val, swig_paddle.Vector)
        arr = val.toNumpyArrayInplace()
        for i in xrange(len(arr)):
            arr[i] = numpy.random.uniform(-1.0, 1.0)

    for p in params:
        init_param(p)


def init_optimizers(opt_conf, params):
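    """Create one swig_paddle.ParameterOptimizer per parameter.

    Each optimizer is built from the shared OptimizationConfig and
    initialized with the parameter's size (dims[1] of its config proto) and
    its ParameterConfig. The optimizers are keyed by parameter ID and then
    flattened into a list indexed by that ID, so update_callback in main()
    can look them up by param.getID().
    """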
    opts = {}
    for param in params:
        param_conf = param.getConfig().toProto()
        opts[param.getID()] = swig_paddle.ParameterOptimizer.create(opt_conf)
        opts[param.getID()].init(param_conf.dims[1], param.getConfig())
    retv_opts = [None for _ in xrange(len(opts))]
    for k in opts:
        assert k < len(retv_opts)
        retv_opts[k] = opts[k]
    return retv_opts


def main():
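    """Train the network defined in ./testTrainConfig.py for one pass over MNIST.

    The flow is: parse the trainer config, build a GradientMachine, randomly
    initialize its parameters, create one optimizer per parameter, then loop
    over mini-batches calling forwardBackward with an update callback.
    """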
    trainer_config = paddle.trainer.config_parser.parse_config(
        "./testTrainConfig.py", "")
    opt_config = trainer_config.opt_config
    print "========Optimization Config ======="
    print opt_config
    print "==================================="
    opt_config = swig_paddle.OptimizationConfig.createFromProto(opt_config)
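    # A temporary optimizer is created only to query which parameter types
    # (value, gradient, momentum, ...) it needs, so the GradientMachine can
    # allocate the matching buffers for each parameter.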
    _temp_optimizer_ = swig_paddle.ParameterOptimizer.create(opt_config)
    enable_types = _temp_optimizer_.getParameterTypes()
    m = swig_paddle.GradientMachine.createFromConfigProto(
        trainer_config.model_config, swig_paddle.CREATE_MODE_NORMAL,
        enable_types)
    assert m is not None
    assert isinstance(m, swig_paddle.GradientMachine)
    init_params(m.getParameters())

    optimizers = init_optimizers(opt_config, m.getParameters())

    # Train one pass: startPass/finishPass bracket the pass and
    # startBatch/finishBatch bracket every mini-batch on each optimizer.
    for optimizer in optimizers:
        optimizer.startPass()
    batch_id = 0
    while True:  # Train one batch
        batch_size = 1000
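        # util.loadMNISTTrainData is a local helper (not part of swig_paddle);
        # it returns the next mini-batch as input Arguments plus a flag that
        # marks the end of the data.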
        inArgs, atEnd = util.loadMNISTTrainData(batch_size)
        if atEnd:
            break
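        # outArgs starts as an empty Arguments container; forwardBackward
        # fills it with the network outputs (slot 0 holds the per-sample
        # cost, read after the batch finishes).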
        outArgs = swig_paddle.Arguments.createArguments(0)

        for optimizer in optimizers:
            optimizer.startBatch(batch_size)

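        # forwardBackward invokes this callback once per parameter after that
        # parameter's gradient is ready, so the optimizer step runs inside
        # the backward pass.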
        def update_callback(param):
            try:
                bufs = list(param.getBufs())
                opt = optimizers[param.getID()]
                opt.update(bufs, param.getConfig())
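                # needSpecialTraversal may return an extra callback for
                # optimizers that need a second walk over the parameter
                # (e.g. deferred regularization); run it when present.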
                callback = opt.needSpecialTraversal(param.getConfig())
                if callback is not None:
                    callback(bufs, param.getConfig(), swig_paddle.NO_SPARSE_ID)

            except Exception as e:
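                # Exceptions raised inside a callback invoked from C++ may
                # not propagate cleanly through SWIG, so report them here.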
                print e

        m.forwardBackward(inArgs, outArgs, swig_paddle.PASS_TRAIN,
                          update_callback)

        for optimizer in optimizers:
            optimizer.finishBatch()

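        # Output slot 0 is the cost layer's output: a batch_size x 1 matrix
        # of per-sample costs, so the batch-average cost is sum / batch_size.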
        cost_vec = outArgs.getSlotValue(0)
        assert isinstance(cost_vec, swig_paddle.Matrix)
        cost_vec = cost_vec.copyToNumpyMat()
        print 'Finish Batch', batch_id, 'with cost ', cost_vec.sum() / batch_size
        batch_id += 1

    for optimizer in optimizers:
        optimizer.finishPass()


if __name__ == '__main__':
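    # initPaddle accepts the same command-line flags as the paddle trainer
    # binary; here: CPU only, a single trainer thread.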
    swig_paddle.initPaddle("--use_gpu=0", "--trainer_count=1")
    main()