# storage.py (1.5 KB)
# Blame annotation: committed by superjom.
# Names exported by a star-import of this module: the two wrapper classes.
__all__ = ['StorageReader', 'StorageWriter']
import core

# Tuple of value-type name strings.
# NOTE(review): StorageReader.scalar / StorageWriter.scalar accept 'float',
# 'double' and 'int' -- not 'int32'/'int64' -- confirm which set is intended.
dtypes = (
    "float",
    "double",
    "int32",
    "int64",
)


class StorageReader(object):
    """Read-side wrapper that delegates to a ``core.Reader``-like object.

    Every query method forwards to ``self.reader``; ``as_mode`` returns a new
    ``StorageReader`` bound to the same directory.
    """

    def __init__(self, dir, reader=None):
        """Wrap ``reader``, or create ``core.Reader(dir)`` when none is given.

        Args:
            dir: Directory this reader is associated with; passed to
                ``core.Reader`` when ``reader`` is ``None``.
            reader: An existing reader object to wrap (``as_mode`` passes the
                result of ``self.reader.as_mode(mode)`` here).
        """
        self.dir = dir
        # Only construct a fresh core.Reader when the caller did not supply one.
        self.reader = reader if reader is not None else core.Reader(dir)

    def as_mode(self, mode):
        """Return a new ``StorageReader`` wrapping ``self.reader.as_mode(mode)``.

        The returned wrapper keeps this instance's ``dir``.
        """
        # Must be ``self.dir``: the bare name ``dir`` resolves to the builtin
        # function, which would leave the child reader with a bogus directory.
        return StorageReader(self.dir, self.reader.as_mode(mode))

    def modes(self):
        """Return ``self.reader.modes()``."""
        return self.reader.modes()

    def tags(self, kind):
        """Return ``self.reader.tags(kind)``."""
        return self.reader.tags(kind)

    def scalar(self, tag, type='float'):
        """Return the reader's ``get_scalar_<type>(tag)`` result.

        Args:
            tag: Tag name forwarded to the reader.
            type: One of ``'float'``, ``'double'`` or ``'int'``; selects which
                ``get_scalar_*`` method of the reader is called.

        Raises:
            KeyError: if ``type`` is not one of the supported names.
        """
        type2scalar = {
            'float': self.reader.get_scalar_float,
            'double': self.reader.get_scalar_double,
            'int': self.reader.get_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag):
        """Return ``self.reader.get_image(tag)``."""
        return self.reader.get_image(tag)

# Blame annotation: committed by superjom.

class StorageWriter(object):
    """Write-side wrapper that delegates to a ``core.Writer``-like object.

    Every factory method forwards to ``self.writer``; ``as_mode`` returns a new
    ``StorageWriter`` with the same directory and sync cycle.
    """

    def __init__(self, dir, sync_cycle, writer=None):
        """Wrap ``writer``, or create ``core.Writer(dir, sync_cycle)``.

        Args:
            dir: Directory this writer is associated with; passed to
                ``core.Writer`` when ``writer`` is ``None``.
            sync_cycle: Forwarded unchanged to ``core.Writer``.
                NOTE(review): presumably controls how often data is flushed --
                confirm against core.Writer.
            writer: An existing writer object to wrap (``as_mode`` passes the
                result of ``self.writer.as_mode(mode)`` here).
        """
        self.dir = dir
        self.sync_cycle = sync_cycle
        # Only construct a fresh core.Writer when the caller did not supply one.
        self.writer = (writer if writer is not None
                       else core.Writer(dir, sync_cycle))

    def as_mode(self, mode):
        """Return a new ``StorageWriter`` wrapping ``self.writer.as_mode(mode)``.

        The returned wrapper keeps this instance's ``dir`` and ``sync_cycle``.
        """
        return StorageWriter(self.dir, self.sync_cycle,
                             self.writer.as_mode(mode))

    def scalar(self, tag, type='float'):
        """Return the writer's ``new_scalar_<type>(tag)`` result.

        Args:
            tag: Tag name forwarded to the writer.
            type: One of ``'float'``, ``'double'`` or ``'int'``; selects which
                ``new_scalar_*`` method of the writer is called.

        Raises:
            KeyError: if ``type`` is not one of the supported names.
        """
        type2scalar = {
            'float': self.writer.new_scalar_float,
            'double': self.writer.new_scalar_double,
            'int': self.writer.new_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag, num_samples, step_cycle):
        """Return ``self.writer.new_image(tag, num_samples, step_cycle)``."""
        return self.writer.new_image(tag, num_samples, step_cycle)