dataprovider.py 1.6 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import cPickle
from paddle.trainer.PyDataProvider2 import *

D
dangqingqing 已提交
19

D
dangqingqing 已提交
20 21 22 23 24 25 26 27 28 29
def initializer(settings, mean_path, is_train, **kwargs):
    """Configure ``settings`` for the image data provider.

    Args:
        settings: Provider settings object handed in by PyDataProvider2.
            The attributes set here (``is_train``, ``mean``) are read back
            by ``process``.
        mean_path: Path to a NumPy file holding a ``'mean'`` entry that is
            subtracted from every image (presumably a ``.npz`` archive
            written with ``np.savez(..., mean=...)`` -- confirm).
        is_train: When truthy, ``process`` randomly flips images as data
            augmentation.
        **kwargs: Extra provider arguments; accepted and ignored.
    """
    settings.is_train = is_train
    # Flattened length of one 3 x 32 x 32 image.
    settings.input_size = 3 * 32 * 32

    mean_file = np.load(mean_path)
    try:
        settings.mean = mean_file['mean']
    finally:
        # For .npz input np.load returns an NpzFile that keeps the archive
        # open until closed; indexing above has already read the array into
        # memory, so release the handle now rather than relying on GC.
        # A plain ndarray result has no ``close`` and needs no cleanup.
        close = getattr(mean_file, 'close', None)
        if close is not None:
            close()

    settings.input_types = {
        'image': dense_vector(settings.input_size),
        # Integer class label with a value range of 10.
        'label': integer_value(10)
    }


30
@provider(init_hook=initializer, pool_size=50000)
def process(settings, file_list):
    """Yield image/label samples read from pickled batch files.

    Args:
        settings: Provider settings populated by ``initializer``; this
            reads ``settings.is_train`` and ``settings.mean``.
        file_list: Path to a text file naming one pickled batch file per
            line. Blank lines are skipped.

    Yields:
        dict with ``'image'`` (mean-subtracted, flattened, float32 array)
        and ``'label'`` (int).
    """
    with open(file_list, 'r') as fdata:
        for line in fdata:
            batch_path = line.strip()
            if not batch_path:
                # Tolerate blank/trailing lines instead of failing on open('').
                continue
            # SECURITY NOTE: unpickling can execute arbitrary code; file_list
            # must only reference trusted batch files.
            with open(batch_path, 'rb') as fo:
                batch = cPickle.load(fo)
            images = batch['data']
            labels = batch['labels']
            for im, lab in zip(images, labels):
                # The random draw only happens in training (short-circuit).
                if settings.is_train and np.random.randint(2):
                    # Reverse the last axis of the (3, 32, 32) view: a
                    # horizontal flip assuming channel-major (CHW) pixel
                    # layout -- TODO confirm against the batch format.
                    im = im.reshape(3, 32, 32)[:, :, ::-1].flatten()
                im = im - settings.mean
                yield {'image': im.astype('float32'), 'label': int(lab)}