test_storage.py 1.0 KB
Newer Older
S
superjom 已提交
1
import storage
S
superjom 已提交
2 3
import numpy as np
import unittest
S
superjom 已提交
4
import random
S
superjom 已提交
5
import time
S
superjom 已提交
6

S
superjom 已提交
7

S
superjom 已提交
8
class StorageTest(unittest.TestCase):
    """Round-trip test for the ``storage`` module: write scalars, read them back."""

    def setUp(self):
        """Give each test a fresh, private storage directory.

        A fixed path such as ``./tmp/storage_test`` lets data left behind by
        an earlier (or concurrent) run leak into the reader and is never
        removed. A per-test temporary directory isolates the test, and
        ``addCleanup`` deletes it even when the test fails.
        """
        import shutil
        import tempfile

        self.dir = tempfile.mkdtemp(prefix="storage_test_")
        self.addCleanup(shutil.rmtree, self.dir, ignore_errors=True)

    def test_read(self):
        """Write ten scalar records in "train" mode, then read them back.

        Despite its name, this exercises both ``StorageWriter`` and
        ``StorageReader``. The reader must return the same ids and values the
        writer recorded, and the test expects the scalar's caption to equal
        the mode name ``"train"``.
        """
        tag = "model/scalar/min"
        expected_ids = list(range(10))
        expected_values = [float(i) for i in expected_ids]

        # sync_cycle=1 presumably syncs to disk after every record so the
        # reader created below sees all of them -- TODO confirm in storage docs.
        writer = storage.StorageWriter(
            self.dir, sync_cycle=1).as_mode("train")
        scalar = writer.scalar(tag)
        for step, value in zip(expected_ids, expected_values):
            scalar.add_record(step, value)

        reader = storage.StorageReader(self.dir).as_mode("train")
        scalar = reader.scalar(tag)
        self.assertEqual(scalar.caption(), "train")

        # assert_array_equal checks shape as well as values and prints both
        # arrays on mismatch, unlike assertTrue(np.equal(...).all()).
        np.testing.assert_array_equal(scalar.records(), expected_values)
        np.testing.assert_array_equal(scalar.ids(), expected_ids)
S
superjom 已提交
31

S
superjom 已提交
32 33 34

if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's tests.
    unittest.main()