cifar10.py 1.8 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

D
dangqingqing 已提交
15 16 17 18 19 20 21 22 23
import os
import numpy as np
import cPickle

# Directory name prefixed to every batch file path built in create_data().
DATA = "cifar-10-batches-py"
# CHANNEL * HEIGHT * WIDTH is the length of the mean vector accumulated in
# create_mean() (presumably the flattened size of one image -- confirm).
CHANNEL = 3
HEIGHT = 32
WIDTH = 32

D
dangqingqing 已提交
24

D
dangqingqing 已提交
25 26 27 28 29 30
def create_mean(dataset):
    """Compute the mean vector over ``dataset`` and cache it in "mean.meta".

    Does nothing when "mean.meta" already exists in the current working
    directory. Otherwise every file in ``dataset`` is loaded, its ``'data'``
    entry is summed over axis 0 into a vector of
    ``CHANNEL * HEIGHT * WIDTH`` elements, and the sum is divided by the
    total number of rows seen. The result is pickled to "mean.meta" as
    ``{"mean": <ndarray>, "size": <int>}``.

    Args:
        dataset: iterable of batch file paths. Each file must load to a
            mapping whose ``'data'`` entry sums over axis 0 to
            ``CHANNEL * HEIGHT * WIDTH`` values.

    Raises:
        ValueError: if the batches contain no samples, so that a NaN mean is
            never written to (and then permanently cached in) "mean.meta".
    """
    if os.path.isfile("mean.meta"):
        return

    mean = np.zeros(CHANNEL * HEIGHT * WIDTH)
    num = 0
    for path in dataset:
        # NOTE(review): this relies on np.load() unpickling the batch file;
        # recent NumPy releases only do that with allow_pickle=True --
        # confirm against the NumPy version this script is run with.
        batch = np.load(path)
        samples = batch['data']
        mean += samples.sum(0)
        num += len(samples)

    if num == 0:
        raise ValueError("cannot compute mean: dataset contains no samples")

    mean /= num
    print(mean.size)
    data = {"mean": mean, "size": mean.size}
    # HIGHEST_PROTOCOL is a binary pickle format, so the file must be opened
    # in binary mode; the with-block guarantees it is flushed and closed.
    with open("mean.meta", 'wb') as meta_file:
        cPickle.dump(data, meta_file, protocol=cPickle.HIGHEST_PROTOCOL)
D
dangqingqing 已提交
38 39 40


def create_data():
    """Create the mean file and the train/test list files.

    Builds the batch file paths under ``DATA`` (five training batches named
    ``data_batch_1`` .. ``data_batch_5`` and one ``test_batch``), calls
    ``create_mean`` on the training batches, then writes, in the current
    working directory:

    * "train.txt" / "test.txt": one ``data/``-prefixed batch path per line.
    * "train.list" / "test.list": the single line ``data/train.txt`` or
      ``data/test.txt`` respectively.

    Each pair of files is only written when its ``.txt`` file does not
    already exist.
    """
    train_set = [DATA + "/data_batch_%d" % i for i in range(1, 6)]
    test_set = [DATA + "/test_batch"]

    # create mean values
    create_mean(train_set)

    # create dataset lists
    if not os.path.isfile("train.txt"):
        with open("train.txt", "w") as txt_file:
            txt_file.write("\n".join("data/" + path for path in train_set))
        with open("train.list", "w") as list_file:
            list_file.write("data/train.txt")

    # The original guard tested "text.txt" (a typo), which never exists, so
    # the test lists were rewritten on every run.
    if not os.path.isfile("test.txt"):
        with open("test.txt", "w") as txt_file:
            txt_file.write("\n".join("data/" + path for path in test_set))
        with open("test.list", "w") as list_file:
            list_file.write("data/test.txt")

D
dangqingqing 已提交
58

D
dangqingqing 已提交
59 60
# When run as a script, generate the mean file and the dataset list files.
if __name__ == '__main__':
    create_data()