# test_summary.py (3.1 KB)
# blame: lines 1-3 committed by superjom
import summary
import numpy as np
import unittest
# blame: line 4 committed by superjom
import random
# blame: line 5 committed by superjom
import time
# blame: lines 6-11 committed by superjom

# NOTE(review): not referenced anywhere in this file; presumably a leftover
# flag — confirm no other module imports it before removing.
once_flag: bool = False


class ScalarTester(unittest.TestCase):
    """Tests for writing and reading scalar tablets through ``summary``.

    ``setUp`` writes ten two-value records (ids 0, 10, ..., 90) plus two
    captions to tablet ``scalar0`` under ``tmp/summary.test``. It keeps
    plain-Python copies of the expected data in ``py_records``, ``py_ids``
    and ``py_captions`` so each test can compare against them.
    """

    def setUp(self):
        """Open a write-mode IM and populate tablet ``scalar0``."""
        import os  # ``os`` is not imported at module level; needed below

        self.dir = "tmp/summary.test"
        # Best-effort cleanup of a previous run's directory. os.rmdir only
        # removes an *empty* directory; a missing or non-empty one raises
        # OSError, which is deliberately ignored. Catching OSError instead of
        # using a bare ``except`` keeps genuine errors (e.g. NameError) visible.
        try:
            os.rmdir(self.dir)
        except OSError:
            pass
        # NOTE(review): other calls in this class pass ``msecs=200`` by
        # keyword; confirm this positional 200 binds to the same parameter.
        self.im = summary.IM(self.dir, "write", 200)
        self.tablet_name = "scalar0"
        self.scalar = summary.scalar(self.im, self.tablet_name)

        self.py_captions = ["train cost", "test cost"]
        self.scalar.set_captions(self.py_captions)

        # Write ten records, mirroring each one in plain lists so the tests
        # have an independent copy of the expected data.
        self.py_records = []
        self.py_ids = []
        for i in range(10):
            record = [0.1 * i, 0.2 * i]
            step_id = i * 10
            self.py_records.append(record)
            self.py_ids.append(step_id)
            self.scalar.add(step_id, record)

    def test_records(self):
        """The writer-side tablet returns the records added in setUp."""
        self.assertEqual(self.scalar.size, len(self.py_records))
        for i, record in enumerate(self.scalar.records):
            self.assertTrue(np.isclose(record, self.py_records[i]).all())

    def test_ids(self):
        """The writer-side tablet returns the ids added in setUp, in order."""
        self.assertEqual(len(self.py_ids), self.scalar.size)
        for i, step_id in enumerate(self.scalar.ids):
            self.assertEqual(self.py_ids[i], step_id)

    def test_captions(self):
        """The writer-side tablet returns the captions set in setUp."""
        self.assertEqual(self.scalar.captions, self.py_captions)

    def test_read_records(self):
        """A separate read-mode IM sees the records written in setUp."""
        # The sleeps presumably give the writer/reader time to sync with
        # the on-disk storage — TODO confirm against summary.IM.
        time.sleep(1)
        # NOTE(review): positional 200 here vs ``msecs=200`` in the other
        # read tests; confirm both bind to the same parameter.
        im = summary.IM(self.dir, "read", 200)
        time.sleep(1)
        scalar = summary.scalar(im, self.tablet_name)
        records = scalar.records
        self.assertEqual(len(self.py_records), scalar.size)
        # Compare what the reader sees against the expected data directly,
        # rather than against the writer's in-memory view.
        for i, record in enumerate(records):
            self.assertTrue(np.isclose(record, self.py_records[i]).all())

    def test_read_ids(self):
        """A separate read-mode IM sees the ids written in setUp."""
        time.sleep(0.6)
        im = summary.IM(self.dir, "read", msecs=200)
        time.sleep(0.6)
        scalar = summary.scalar(im, self.tablet_name)
        self.assertEqual(len(self.py_ids), scalar.size)
        for i, step_id in enumerate(scalar.ids):
            self.assertEqual(self.py_ids[i], step_id)

    def test_read_captions(self):
        """A separate read-mode IM sees the captions set in setUp."""
        time.sleep(0.6)
        im = summary.IM(self.dir, "read", msecs=200)
        time.sleep(0.6)
        scalar = summary.scalar(im, self.tablet_name)
        self.assertEqual(scalar.captions, self.py_captions)

    def test_mix_read_write(self):
        """Smoke test: interleave many writes with reads on the same tablet.

        There are no assertions; the test passes as long as no call raises.
        The reader's view is only sampled (results discarded) between
        batches of writes. Note that ids 0-499 are written more than once.
        """
        # A second write-mode IM on the same directory as the one from setUp.
        write_im = summary.IM(self.dir, "write", msecs=200)
        time.sleep(0.6)
        read_im = summary.IM(self.dir, "read", msecs=200)

        scalar_writer = summary.scalar(write_im, self.tablet_name)
        scalar_reader = summary.scalar(read_im, self.tablet_name)

        scalar_writer.set_captions(["train cost", "test cost"])
        for i in range(1000):
            scalar_writer.add(i, [random.random(), random.random()])

        scalar_reader.records

        for i in range(500):
            scalar_writer.add(i, [random.random(), random.random()])

        scalar_reader.records

        for i in range(500):
            scalar_writer.add(i, [random.random(), random.random()])

        for _ in range(10):
            scalar_reader.records
            scalar_reader.captions

# blame: lines 99-101 committed by superjom

if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()