import pprint
import random
import sys
import time
import unittest

import numpy as np
from PIL import Image

from visualdl import LogWriter, LogReader


class StorageTest(unittest.TestCase):
    """Round-trip tests: data written with LogWriter must be readable with LogReader."""

    def setUp(self):
        # All tests except test_modes write into this directory in "train" mode.
        # sync_cycle=1 presumably flushes after every record, so a reader opened
        # right after the writes sees them -- TODO confirm against LogWriter docs.
        self.dir = "./tmp/storage_test"
        self.writer = LogWriter(self.dir, sync_cycle=1).as_mode("train")

    def test_scalar(self):
        """Scalar values and their step ids survive a write/read round trip."""
        expected = [float(i) for i in range(10)]

        scalar = self.writer.scalar("model/scalar/min")
        for i in range(10):
            scalar.add_record(i, float(i))

        log_reader = LogReader(self.dir)
        with log_reader.mode("train") as reader:
            scalar = reader.scalar("model/scalar/min")
            self.assertEqual(scalar.caption(), "train")
            # assert_array_equal checks shapes as well as values and reports
            # the mismatching elements on failure.
            np.testing.assert_array_equal(scalar.records(), expected)
            np.testing.assert_array_equal(scalar.ids(), expected)

    def test_image(self):
        """Sampled random images keep their shape and size after a round trip."""
        tag = "layer1/layer2/image0"
        image_writer = self.writer.image(tag, 10, 1)
        num_passes = 10
        num_samples = 100
        shape = [10, 10, 3]

        for _ in range(num_passes):
            image_writer.start_sampling()
            for _ in range(num_samples):
                index = image_writer.is_sample_taken()
                if index == -1:
                    # No slot assigned to this instance (presumably not sampled).
                    continue
                data = np.random.random(shape) * 256
                image_writer.set_sample(index, shape, list(data.flatten()))
            image_writer.finish_sampling()

        log_reader = LogReader(self.dir)
        with log_reader.mode("train") as reader:
            image_reader = reader.image(tag)
            self.assertEqual(image_reader.caption(), tag)
            self.assertEqual(image_reader.num_records(), num_passes)

            image_record = image_reader.record(0, 1)
            np.testing.assert_array_equal(image_record.shape(), shape)
            self.assertEqual(len(image_record.data()), np.prod(shape))

            image_tags = reader.tags("image")
            self.assertTrue(image_tags)
            self.assertEqual(len(image_tags), 1)

    def test_check_image(self):
        """Check that the storage keeps a real image's geometry consistent.

        Pixel values are not compared: the storage rescales them, so the
        stored elements differ from the source image.
        """
        tag = "layer1/check/image1"
        image_writer = self.writer.image(tag, 10, 1)

        # NOTE: dog.jpg is resolved against the current working directory.
        # convert("RGB") guarantees 3 channels, matching the declared shape.
        image = Image.open("./dog.jpg").convert("RGB")
        shape = [image.size[1], image.size[0], 3]
        origin_data = np.array(image.getdata()).flatten()

        # Write first, then open the reader, as the other tests do.
        image_writer.start_sampling()
        index = image_writer.is_sample_taken()
        image_writer.set_sample(index, shape, list(origin_data))
        image_writer.finish_sampling()

        log_reader = LogReader(self.dir)
        with log_reader.mode("train") as reader:
            image_record = reader.image(tag).record(0, 0)
            data = image_record.data()
            record_shape = image_record.shape()

            np.testing.assert_array_equal(record_shape, shape)
            self.assertEqual(len(data), np.prod(record_shape))

            # The stored data must still form a valid image; call .show() on
            # the result to inspect it manually.
            pixels = np.array(data, dtype='uint8').reshape(record_shape)
            Image.fromarray(pixels)

    def test_with_syntax(self):
        """Scalars written inside ``with writer.mode(...)`` are readable."""
        with self.writer.mode("train") as writer:
            scalar = writer.scalar("model/scalar/average")
            for i in range(10):
                scalar.add_record(i, float(i))

        log_reader = LogReader(self.dir)
        with log_reader.mode("train") as reader:
            scalar = reader.scalar("model/scalar/average")
            self.assertEqual(scalar.caption(), "train")
            np.testing.assert_array_equal(scalar.records(),
                                          [float(i) for i in range(10)])

    def test_modes(self):
        """Scalars in many modes can be created and written in one log dir."""
        # A dedicated directory, so this second writer does not share
        # self.dir with the LogWriter opened in setUp.
        modes_dir = "./tmp/storagetest0"
        store = LogWriter(modes_dir, sync_cycle=1)

        scalars = []
        for i in range(10):
            with store.mode("mode-%d" % i) as writer:
                scalars.append(writer.scalar("add/scalar0"))

        # Scalars stay usable after their mode context exits; the last one is
        # left without records (presumably to cover an empty scalar).
        for scalar in scalars[:-1]:
            for i in range(10):
                scalar.add_record(i, float(i))


if __name__ == '__main__':
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()