# test_summary.py: unit tests for the scalar component of the summary module.
import summary
import numpy as np
import unittest
import time

# Module-level flag initialised to False. NOTE(review): it is not read or
# written anywhere else in this file; confirm whether anything still uses it.
once_flag = False


class ScalarTester(unittest.TestCase):
    """Round-trip tests for ``summary.scalar``.

    ``setUp`` writes ten known records through the summary API; each test
    then checks that the scalar reports back exactly what was written.
    """

    def setUp(self):
        """Create an IM in "write" mode, a scalar "scalar0", and add records."""
        # NOTE(review): this targets a path relative to the current working
        # directory and nothing removes it afterwards; consider a temporary
        # directory once IM's on-disk behaviour is confirmed.
        log_dir = "tmp/3.test"
        self.im = summary.IM("write", log_dir)
        self.scalar = summary.scalar(self.im, "scalar0")
        self.py_captions = ["train cost", "test cost"]
        self.scalar.set_captions(self.py_captions)

        # Python-side copies of everything written; the tests use them as
        # the expected values.
        self.py_records = []
        self.py_ids = []
        for i in range(10):
            record = [0.1 * i, 0.2 * i]
            step_id = i * 10  # named to avoid shadowing the builtin ``id``
            self.py_records.append(record)
            self.py_ids.append(step_id)
            self.scalar.add(step_id, record)

    def test_records(self):
        """Stored records match the written ones within float tolerance."""
        self.assertEqual(self.scalar.size, len(self.py_records))
        # Materialize the records so their count is checked independently of
        # ``size``: a short sequence must fail rather than silently pass, and
        # a long one must fail cleanly rather than raise IndexError.
        records = list(self.scalar.records)
        self.assertEqual(len(records), len(self.py_records))
        for i, (got, expected) in enumerate(zip(records, self.py_records)):
            self.assertTrue(np.allclose(got, expected),
                            "record %d: %r != %r" % (i, got, expected))

    def test_ids(self):
        """Stored ids match the written ids, in the same order."""
        self.assertEqual(len(self.py_ids), self.scalar.size)
        # Compare whole sequences so missing or extra ids are reported too.
        self.assertEqual(list(self.scalar.ids), self.py_ids)

    def test_captions(self):
        """Captions read back equal the captions that were set."""
        self.assertEqual(self.scalar.captions, self.py_captions)


if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line
    # runner, which executes the TestCase classes defined in this module.
    unittest.main()