test_storage.py 2.2 KB
Newer Older
S
superjom 已提交
1
import random
import shutil
import tempfile
import time
import unittest

import numpy as np

import storage
S
superjom 已提交
8

S
superjom 已提交
9

S
superjom 已提交
10
class StorageTest(unittest.TestCase):
    """Round-trip tests for ``storage.StorageWriter`` / ``storage.StorageReader``.

    Each test writes through the writer created in ``setUp`` (mode "train")
    and reads the data back through a ``StorageReader`` opened on the same
    directory in the same mode.
    """

    def setUp(self):
        # A fresh directory per test: the previous fixed "./tmp/storage_test"
        # path was shared by every test and every run and never removed, so
        # stale records could leak into num_records()/tags() assertions.
        self.dir = tempfile.mkdtemp(prefix="storage_test_")
        # sync_cycle=1 is kept from the original; presumably it makes the
        # writer persist after every write so a reader opened later in the
        # same test sees the data -- TODO confirm against storage.StorageWriter.
        self.writer = storage.StorageWriter(
            self.dir, sync_cycle=1).as_mode("train")

    def tearDown(self):
        # Best-effort cleanup; a failure here must not mask the test result.
        shutil.rmtree(self.dir, ignore_errors=True)

    def test_scalar(self):
        """Write ten (id, value) scalar records and read them back unchanged."""
        tag = "model/scalar/min"
        num_records = 10

        scalar = self.writer.scalar(tag)
        for step in range(num_records):
            scalar.add_record(step, float(step))

        self.reader = storage.StorageReader(self.dir).as_mode("train")
        scalar = self.reader.scalar(tag)
        # For scalars this test expects the caption to be the mode name passed
        # to as_mode(), not the tag.
        self.assertEqual(scalar.caption(), "train")

        # assert_array_equal also checks shape (np.equal(...).all() would
        # broadcast) and prints both arrays on failure.
        np.testing.assert_array_equal(
            scalar.records(), [float(i) for i in range(num_records)])
        np.testing.assert_array_equal(scalar.ids(), list(range(num_records)))

    def test_image(self):
        """Sample random images over several passes and read one record back."""
        tag = "layer1/layer2/image0"
        num_passes = 10
        num_samples = 100
        shape = [10, 10, 3]

        # The extra positional arguments (10, 1) are forwarded unchanged; see
        # storage.StorageWriter.image for their meaning.
        image_writer = self.writer.image(tag, 10, 1)

        for _ in range(num_passes):
            image_writer.start_sampling()
            for _ in range(num_samples):
                # -1 is treated as "this instance was not selected"; any other
                # value is the slot index to fill.
                index = image_writer.is_sample_taken()
                if index != -1:
                    pixels = np.random.random(shape) * 256
                    image_writer.set_sample(
                        index, shape, pixels.flatten().tolist())
            image_writer.finish_sampling()

        self.reader = storage.StorageReader(self.dir).as_mode("train")
        image_reader = self.reader.image(tag)
        self.assertEqual(image_reader.caption(), tag)
        # One record is expected per start_sampling/finish_sampling pass.
        self.assertEqual(image_reader.num_records(), num_passes)

        image_record = image_reader.record(0, 1)
        np.testing.assert_array_equal(image_record.shape(), shape)
        # Image data is stored flattened, so its length is the product of shape.
        self.assertEqual(len(image_record.data()), np.prod(shape))

        image_tags = self.reader.tags("image")
        self.assertEqual(len(image_tags), 1)
S
superjom 已提交
64

S
superjom 已提交
65 66 67

if __name__ == "__main__":
    # Executed as a script (not imported): hand control to the unittest runner.
    unittest.main()