storage.py 2.0 KB
Newer Older
S
superjom 已提交
1 2 3 4 5 6 7 8 9
# Names exported by a star-import of this module.
# NOTE(review): neither 'StorageReader' nor 'StorageWriter' is defined in the
# visible part of this file (the classes below are LogReader and LogWriter),
# so a star-import would presumably fail with AttributeError - confirm against
# the rest of the package and update this list if the names are stale.
__all__ = [
    'StorageReader',
    'StorageWriter',
]
import core

# Data-type names kept as a module-level constant.
# NOTE(review): the scalar() helpers in this file accept 'float', 'double'
# and 'int', not 'int32'/'int64' - confirm how this tuple is meant to be used.
dtypes = (
    "float",
    "double",
    "int32",
    "int64",
)


10
class LogReader(object):
    """Python-side wrapper that forwards read requests to a backend reader.

    The backend is either the ``reader`` object handed to the constructor or,
    when none is given, a ``core.LogReader`` created for ``dir``.
    """

    # Most recent wrapper produced by ``mode()``. It lives on the class, so it
    # is shared by every instance; kept for callers that read it directly.
    cur_mode = None

    def __init__(self, dir, reader=None):
        """Store ``dir`` and select the backend reader.

        Args:
            dir: Value passed to ``core.LogReader`` when no ``reader`` is
                supplied. The name shadows the builtin but is kept because
                callers may pass it by keyword.
            reader: Optional pre-built backend. Any falsy value means
                "create a ``core.LogReader`` for ``dir``".
        """
        self.dir = dir
        self.reader = reader if reader else core.LogReader(dir)

    def mode(self, mode):
        """Return a new wrapper for ``mode`` and record it in ``cur_mode``.

        Args:
            mode: Mode identifier forwarded to ``as_mode``.

        Returns:
            The new ``LogReader``; the same object is stored in
            ``LogReader.cur_mode``.
        """
        LogReader.cur_mode = self.as_mode(mode)
        return LogReader.cur_mode

    def as_mode(self, mode):
        """Return a new ``LogReader`` wrapping ``self.reader.as_mode(mode)``.

        Args:
            mode: Mode identifier forwarded to the backend's ``as_mode``.
        """
        # Use this instance's directory. The bare name ``dir`` here would
        # resolve to the builtin function, not to the constructor argument.
        return LogReader(self.dir, self.reader.as_mode(mode))

    def modes(self):
        """Return whatever the backend's ``modes()`` returns."""
        return self.reader.modes()

    def tags(self, kind):
        """Return whatever the backend's ``tags(kind)`` returns."""
        return self.reader.tags(kind)

    def scalar(self, tag, type='float'):
        """Fetch the scalar record for ``tag`` from the backend.

        Args:
            tag: Identifier forwarded to the backend getter.
            type: One of ``'float'``, ``'double'`` or ``'int'``; selects
                which ``get_scalar_*`` backend method is called.

        Raises:
            KeyError: If ``type`` is not one of the supported names.
        """
        type2scalar = {
            'float': self.reader.get_scalar_float,
            'double': self.reader.get_scalar_double,
            'int': self.reader.get_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag):
        """Return the backend's ``get_image(tag)`` result."""
        return self.reader.get_image(tag)

    def __enter__(self):
        """Enter a ``with`` block and yield this wrapper.

        Returning ``self`` (rather than the class-wide ``cur_mode``) keeps the
        ``with reader.mode(x) as r`` pattern working while guaranteeing that
        the block receives the object it was entered on, even if ``mode()``
        has since been called again elsewhere.
        """
        return self

    def __exit__(self, type, value, traceback):
        """Leave the ``with`` block; nothing to release, exceptions propagate."""
        pass

S
superjom 已提交
49

50
class LogWriter(object):
    """Python-side wrapper that forwards write requests to a backend writer.

    The backend is either the ``writer`` object handed to the constructor or,
    when none is given, a ``core.LogWriter`` created from ``dir`` and
    ``sync_cycle``.
    """

    # Most recent wrapper produced by ``as_mode()``. It lives on the class, so
    # it is shared by every instance.
    cur_mode = None

    def __init__(self, dir, sync_cycle, writer=None):
        """Store the construction arguments and select the backend writer.

        Args:
            dir: Value passed to ``core.LogWriter`` when no ``writer`` is
                supplied. The name shadows the builtin but is kept because
                callers may pass it by keyword.
            sync_cycle: Second argument passed to ``core.LogWriter``; also
                reused when ``as_mode`` builds a new wrapper.
            writer: Optional pre-built backend. Any falsy value means
                "create a ``core.LogWriter``".
        """
        self.dir = dir
        self.sync_cycle = sync_cycle
        self.writer = writer if writer else core.LogWriter(dir, sync_cycle)

    def mode(self, mode):
        """Call ``set_mode(mode)`` on the backend and return this wrapper.

        Returning ``self`` allows ``with writer.mode(x) as w:`` usage.
        """
        self.writer.set_mode(mode)
        return self

    def as_mode(self, mode):
        """Return a new ``LogWriter`` wrapping ``self.writer.as_mode(mode)``.

        The new wrapper reuses this instance's ``dir`` and ``sync_cycle`` and
        is also recorded in ``LogWriter.cur_mode``.
        """
        LogWriter.cur_mode = LogWriter(
            self.dir, self.sync_cycle, self.writer.as_mode(mode))
        return LogWriter.cur_mode

    def scalar(self, tag, type='float'):
        """Create a scalar record for ``tag`` through the backend.

        Args:
            tag: Identifier forwarded to the backend factory.
            type: One of ``'float'``, ``'double'`` or ``'int'``; selects
                which ``new_scalar_*`` backend method is called.

        Raises:
            KeyError: If ``type`` is not one of the supported names.
        """
        type2scalar = {
            'float': self.writer.new_scalar_float,
            'double': self.writer.new_scalar_double,
            'int': self.writer.new_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag, num_samples, step_cycle):
        """Return the backend's ``new_image(tag, num_samples, step_cycle)``."""
        return self.writer.new_image(tag, num_samples, step_cycle)

    def __enter__(self):
        """Enter a ``with`` block and yield this wrapper."""
        return self

    def __exit__(self, type, value, traceback):
        """Leave the ``with`` block by setting the backend mode to "default".

        Returns ``None``, so an exception raised inside the block propagates.
        """
        self.writer.set_mode("default")