test_storage.py 3.4 KB
Newer Older
S
superjom 已提交
1
import random
S
superjom 已提交
2
import time
S
superjom 已提交
3
import unittest
S
superjom 已提交
4
from PIL import Image
S
superjom 已提交
5 6 7 8

import numpy as np

import storage
S
superjom 已提交
9

S
superjom 已提交
10

S
superjom 已提交
11
class StorageTest(unittest.TestCase):
    """Round-trip tests for ``storage.StorageWriter`` / ``storage.StorageReader``.

    Each test writes data through a writer in "train" mode and reads it back
    with a reader opened on the same directory in the same mode.
    """

    def setUp(self):
        # Give every test its own fresh log directory.  The previous shared,
        # persistent "./tmp/storage_test" let tests (and earlier runs) see each
        # other's data: test_check_image writes a second image tag, which
        # breaks the single-tag assertion in test_image under unittest's
        # alphabetical test ordering.
        import os
        import tempfile

        self._tmp_root = tempfile.mkdtemp(prefix="storage_test_")
        # Like the original setUp, only hand the writer a path; the leaf
        # directory is not pre-created here.
        self.dir = os.path.join(self._tmp_root, "storage_test")
        self.writer = storage.StorageWriter(
            self.dir, sync_cycle=1).as_mode("train")

    def tearDown(self):
        import shutil

        # Best-effort cleanup: a leftover temp dir must not fail the test.
        shutil.rmtree(self._tmp_root, ignore_errors=True)

    def test_scalar(self):
        """Write ten scalar points, then read them back and compare."""
        print('test write')
        scalar = self.writer.scalar("model/scalar/min")
        for i in range(10):
            scalar.add_record(i, float(i))

        print('test read')
        self.reader = storage.StorageReader(self.dir).as_mode("train")
        scalar = self.reader.scalar("model/scalar/min")
        # The original test expects the caption to be the mode name.
        self.assertEqual(scalar.caption(), "train")
        records = scalar.records()
        ids = scalar.ids()
        expected = [float(i) for i in range(10)]
        # Check the count first so a short read fails clearly instead of
        # producing a confusing broadcast result in np.equal.
        self.assertEqual(len(records), len(expected))
        self.assertEqual(len(ids), len(expected))
        self.assertTrue(np.equal(records, expected).all())
        self.assertTrue(np.equal(ids, expected).all())
        print('records {}'.format(records))
        print('ids {}'.format(ids))

    def test_image(self):
        """Sample random images over several passes and read one record back."""
        tag = "layer1/layer2/image0"
        image_writer = self.writer.image(tag, 10, 1)
        num_passes = 10
        num_samples = 100
        shape = [10, 10, 3]

        for _ in range(num_passes):
            image_writer.start_sampling()
            for _ in range(num_samples):
                index = image_writer.is_sample_taken()
                # An index of -1 means the writer declined this candidate
                # (presumably reservoir sampling); only store taken samples.
                if index != -1:
                    data = np.random.random(shape) * 256
                    image_writer.set_sample(index, shape, list(data.flatten()))
            image_writer.finish_sampling()

        self.reader = storage.StorageReader(self.dir).as_mode("train")
        image_reader = self.reader.image(tag)
        self.assertEqual(image_reader.caption(), tag)
        # One record per sampling pass.
        self.assertEqual(image_reader.num_records(), num_passes)

        image_record = image_reader.record(0, 1)
        self.assertTrue(np.equal(image_record.shape(), shape).all())
        data = image_record.data()
        self.assertEqual(len(data), np.prod(shape))

        image_tags = self.reader.tags("image")
        self.assertTrue(image_tags)
        self.assertEqual(len(image_tags), 1)

    def test_check_image(self):
        """Check that storage keeps real image pixel data unchanged on a round trip."""
        print('check image')
        tag = "layer1/check/image1"
        image_writer = self.writer.image(tag, 10, 1)

        # Force RGB so the pixel data really is height x width x 3, matching
        # the shape declared below (a grayscale/RGBA file would not be).
        # Converting loads the pixels, so the file handle can be closed here.
        with Image.open("./dog.jpg") as source:
            image = source.convert("RGB")
        shape = [image.size[1], image.size[0], 3]
        origin_data = np.array(image.getdata()).flatten()

        # NOTE(review): the reader is opened before the write, presumably to
        # exercise reading data synced after the reader was created; confirm.
        self.reader = storage.StorageReader(self.dir).as_mode("train")

        image_writer.start_sampling()
        index = image_writer.is_sample_taken()
        # The sample must be taken, or record(0, 0) below has nothing to read.
        self.assertNotEqual(index, -1)
        image_writer.set_sample(index, shape, list(origin_data))
        image_writer.finish_sampling()

        # Read the record back and check the original pixels survive.
        image_reader = self.reader.image(tag)
        image_record = image_reader.record(0, 0)
        data = image_record.data()
        shape = image_record.shape()

        # PIL's getdata() layout: one row per pixel, one column per channel.
        pil_image_shape = (shape[0] * shape[1], shape[2])
        data = np.array(data, dtype='uint8').reshape(pil_image_shape)
        print('origin {}'.format(origin_data.flatten()))
        print('data {}'.format(data.flatten()))
        # Rebuilding a PIL image checks that the stored data is a displayable
        # H x W x C array of the original dimensions.
        restored = Image.fromarray(data.reshape(shape))
        self.assertEqual(restored.size, image.size)

        self.assertTrue(
            np.equal(origin_data.reshape(pil_image_shape), data).all())







S
superjom 已提交
106 107 108

def _main():
    """Run this module's test cases through unittest's command-line runner."""
    unittest.main()


if __name__ == '__main__':
    _main()