test_storage.py 3.4 KB
Newer Older
S
superjom 已提交
1
import random
import shutil
import time
import unittest

import numpy as np
from PIL import Image

import storage
S
superjom 已提交
9

S
superjom 已提交
10

S
superjom 已提交
11
class StorageTest(unittest.TestCase):
    """Round-trip tests for the ``storage`` writer/reader bindings.

    Each test writes records through ``storage.StorageWriter`` in "train"
    mode and reads them back through ``storage.StorageReader``.
    """

    def setUp(self):
        self.dir = "./tmp/storage_test"
        # Start every test from an empty directory so records left behind by
        # an earlier test (or an earlier run) cannot leak into the reads --
        # e.g. the tag-count assertion in test_image.
        shutil.rmtree(self.dir, ignore_errors=True)
        # NOTE(review): sync_cycle=1 presumably persists after every write so
        # readers opened later see the data -- confirm against storage.
        self.writer = storage.StorageWriter(
            self.dir, sync_cycle=1).as_mode("train")

    def test_scalar(self):
        """Write ten scalar records and check that they read back unchanged."""
        tag = "model/scalar/min"
        num_records = 10

        print("test write")
        scalar_writer = self.writer.scalar(tag)
        for i in range(num_records):
            scalar_writer.add_record(i, float(i))

        print("test read")
        self.reader = storage.StorageReader(self.dir).as_mode("train")
        scalar_reader = self.reader.scalar(tag)
        self.assertEqual(scalar_reader.caption(), "train")
        records = scalar_reader.records()
        ids = scalar_reader.ids()
        self.assertTrue(
            np.equal(records, [float(i) for i in range(num_records)]).all())
        self.assertTrue(np.equal(ids, list(range(num_records))).all())
        # Single-argument print() behaves the same on Python 2 and 3.
        print("records %s" % (records,))
        print("ids %s" % (ids,))

    def test_image(self):
        """Sample random images over several passes and check one record."""
        tag = "layer1/layer2/image0"
        # NOTE(review): arguments presumably are (tag, num_samples,
        # step_cycle) -- confirm against storage.
        image_writer = self.writer.image(tag, 10, 1)
        num_passes = 10
        num_samples = 100
        shape = [10, 10, 3]

        for _ in range(num_passes):
            image_writer.start_sampling()
            for _ in range(num_samples):
                index = image_writer.is_sample_taken()
                # -1 marks an instance the sampler did not keep.
                if index != -1:
                    data = np.random.random(shape) * 256
                    image_writer.set_sample(index, shape, list(data.flatten()))
            image_writer.finish_sampling()

        self.reader = storage.StorageReader(self.dir).as_mode("train")
        image_reader = self.reader.image(tag)
        self.assertEqual(image_reader.caption(), tag)
        # One record is expected per finished sampling pass.
        self.assertEqual(image_reader.num_records(), num_passes)

        image_record = image_reader.record(0, 1)
        self.assertTrue(np.equal(image_record.shape(), shape).all())
        data = image_record.data()
        self.assertEqual(len(data), np.prod(shape))

        image_tags = self.reader.tags("image")
        self.assertTrue(image_tags)
        self.assertEqual(len(image_tags), 1)

    def test_check_image(self):
        """Check that the storage keeps real image data consistent."""
        print("check image")
        tag = "layer1/check/image1"
        image_writer = self.writer.image(tag, 10, 1)

        # NOTE(review): the path is relative to the working directory, so the
        # test must be run from the directory containing dog.jpg.
        # convert("RGB") guarantees three channels, matching ``shape`` even
        # if the file is grayscale or carries an alpha channel.
        image = Image.open("./dog.jpg").convert("RGB")
        shape = [image.size[1], image.size[0], 3]
        origin_data = np.array(image.getdata()).flatten()

        image_writer.start_sampling()
        index = image_writer.is_sample_taken()
        # Fail loudly rather than hand a -1 (not sampled) index to set_sample().
        self.assertNotEqual(index, -1)
        image_writer.set_sample(index, shape, list(origin_data))
        image_writer.finish_sampling()

        # Open the reader only after the write has finished so that it is
        # created against the completed record.
        self.reader = storage.StorageReader(self.dir).as_mode("train")
        image_reader = self.reader.image(tag)
        image_record = image_reader.record(0, 0)
        data = image_record.data()
        record_shape = image_record.shape()
        self.assertTrue(np.equal(record_shape, shape).all())

        # getdata() yields one (R, G, B) tuple per pixel, i.e. a
        # (height * width, channels) layout; compare in that layout.
        pil_image_shape = (record_shape[0] * record_shape[1], record_shape[2])
        data = np.array(data, dtype="uint8").reshape(pil_image_shape)
        # The round-tripped bytes must still form a displayable image.
        Image.fromarray(data.reshape(record_shape))

        self.assertTrue(
            np.equal(origin_data.reshape(pil_image_shape), data).all())

if __name__ == "__main__":
    # Script entry point: discover and run every TestCase in this module.
    unittest.main()