# storage.py (2.8 KB; last committed by superjom)
# Public names exported by ``from <module> import *``. These must match the
# classes actually defined in this module; the previous entries
# ('StorageReader', 'StorageWriter') do not exist, which made star-imports
# raise AttributeError.
__all__ = [
    'LogReader',
    'LogWriter',
]
import core

# Element-type names declared by this module. Note: the typed accessors in the
# visible code dispatch on 'float' / 'double' / 'int', not on these names.
dtypes = (
    'float',
    'double',
    'int32',
    'int64',
)


class LogReader(object):
    '''
    Read-side wrapper around a backend reader (``core.LogReader`` by default).

    Exposes mode switching plus typed accessors for scalar, image and
    histogram records. Each accessor forwards to the corresponding method
    of the wrapped backend reader.
    '''

    # Declared for symmetry with LogWriter; nothing in this class assigns it.
    cur_mode = None

    def __init__(self, dir, reader=None):
        '''
        Args:
            dir: log directory. It is kept on ``self.dir`` and passed to
                ``core.LogReader`` when no ``reader`` is given.
            reader: optional pre-built backend reader to wrap. When None, a
                new ``core.LogReader(dir)`` is created.
        '''
        self.dir = dir
        # Compare against None rather than truth-testing. A backend object
        # that defines __bool__/__len__ must not be silently replaced.
        self.reader = reader if reader is not None else core.LogReader(dir)

    def mode(self, mode):
        '''
        Switch the wrapped reader to ``mode`` in place.

        Returns ``self``, so the call can be chained or used as
        ``with reader.mode(...):``. __exit__ then resets the mode to "default".
        '''
        self.reader.set_mode(mode)
        return self

    def as_mode(self, mode):
        '''
        Return a new LogReader wrapping ``self.reader.as_mode(mode)``.

        The new wrapper keeps this reader's ``dir``. Unlike :meth:`mode`,
        this does not call ``set_mode`` on this instance's reader.
        '''
        # Must be self.dir: there is no local named ``dir`` here, so the bare
        # name resolves to the builtin function dir().
        return LogReader(self.dir, self.reader.as_mode(mode))

    def modes(self):
        '''Return the result of the backend reader's ``modes()``.'''
        return self.reader.modes()

    def tags(self, kind):
        '''Return the backend reader's tags for component ``kind``.'''
        return self.reader.tags(kind)

    def scalar(self, tag, type='float'):
        '''
        Return the scalar record stored under ``tag``.

        Args:
            tag: record tag.
            type: 'float', 'double' or 'int'. Selects the typed backend
                getter (get_scalar_float / _double / _int).

        Raises:
            KeyError: if ``type`` is not one of the names above.
        '''
        type2scalar = {
            'float': self.reader.get_scalar_float,
            'double': self.reader.get_scalar_double,
            'int': self.reader.get_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag):
        '''Return the image record stored under ``tag``.'''
        return self.reader.get_image(tag)

    def histogram(self, tag, type='float'):
        '''
        Return the histogram record stored under ``tag``.

        Args:
            tag: record tag.
            type: 'float', 'double' or 'int'. Selects the typed backend
                getter (get_histogram_float / _double / _int).

        Raises:
            KeyError: if ``type`` is not one of the names above.
        '''
        type2histogram = {
            'float': self.reader.get_histogram_float,
            'double': self.reader.get_histogram_double,
            'int': self.reader.get_histogram_int,
        }
        return type2histogram[type](tag)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always fall back to the "default" mode on exit. Returning None lets
        # any exception raised inside the with-body propagate.
        self.reader.set_mode("default")


class LogWriter(object):
    '''
    Write-side wrapper around a backend writer (``core.LogWriter`` by default).

    Provides mode switching and factory methods that create scalar, image
    and histogram components through the wrapped backend writer.
    '''

    # Class-level attribute shared by every LogWriter: as_mode() stores the
    # most recently created child writer here.
    cur_mode = None

    def __init__(self, dir, sync_cycle, writer=None):
        '''
        Args:
            dir: log directory, passed to ``core.LogWriter``.
            sync_cycle: passed to ``core.LogWriter``. It is also reused by
                as_mode() when building child writers.
            writer: optional pre-built backend writer to wrap. When None, a
                new ``core.LogWriter(dir, sync_cycle)`` is created.
        '''
        self.dir = dir
        self.sync_cycle = sync_cycle
        # Compare against None rather than truth-testing. A backend object
        # that defines __bool__/__len__ must not be silently replaced.
        self.writer = (writer if writer is not None
                       else core.LogWriter(dir, sync_cycle))

    def mode(self, mode):
        '''
        Switch the wrapped writer to ``mode`` in place.

        Returns ``self``, so the call can be chained or used as
        ``with writer.mode(...):``. __exit__ then resets the mode to "default".
        '''
        self.writer.set_mode(mode)
        return self

    def as_mode(self, mode):
        '''
        Return a new LogWriter wrapping ``self.writer.as_mode(mode)``.

        The child keeps this writer's ``dir`` and ``sync_cycle``. It is also
        recorded on the class attribute ``LogWriter.cur_mode``.
        '''
        LogWriter.cur_mode = LogWriter(self.dir, self.sync_cycle,
                                       self.writer.as_mode(mode))
        return LogWriter.cur_mode

    def scalar(self, tag, type='float'):
        '''
        Create a scalar component.

        Args:
            tag: component tag.
            type: 'float', 'double' or 'int'. Selects the typed backend
                factory (new_scalar_float / _double / _int).

        Raises:
            KeyError: if ``type`` is not one of the names above.
        '''
        type2scalar = {
            'float': self.writer.new_scalar_float,
            'double': self.writer.new_scalar_double,
            'int': self.writer.new_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag, num_samples, step_cycle):
        '''
        Create an image component.

        All arguments are forwarded unchanged to the backend ``new_image``.
        '''
        return self.writer.new_image(tag, num_samples, step_cycle)

    def histogram(self, tag, num_buckets, type='float'):
        '''
        Create a histogram component.

        Args:
            tag: component tag.
            num_buckets: forwarded to the backend factory.
            type: 'float', 'double' or 'int'. Selects the typed backend
                factory (new_histogram_float / _double / _int).

        Raises:
            KeyError: if ``type`` is not one of the names above.
        '''
        type2histogram = {
            'float': self.writer.new_histogram_float,
            'double': self.writer.new_histogram_double,
            'int': self.writer.new_histogram_int,
        }
        return type2histogram[type](tag, num_buckets)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always fall back to the "default" mode on exit. Returning None lets
        # any exception raised inside the with-body propagate.
        self.writer.set_mode("default")