storage.py 2.1 KB
Newer Older
S
superjom 已提交
1 2 3 4 5 6 7 8 9 10 11
# Public API of this module: the names exported by a star-import.
__all__ = [
    'StorageReader',
    'StorageWriter',
]
import core

# Data type names declared at module level.
# NOTE(review): the scalar() methods visible in this file dispatch on
# 'float' / 'double' / 'int' rather than on these names -- confirm how this
# tuple is meant to be used.
dtypes = (
    "float",
    "double",
    "int32",
    "int64",
)


class StorageReader(object):
    """Thin read-side wrapper that delegates to a backend reader object.

    The backend is either supplied by the caller or created with
    ``core.Reader(dir)``. Every query method simply forwards to the
    backend; this class adds mode selection and ``with``-statement support.
    """

    # Reader returned by the most recent ``mode()`` call. It is stored on the
    # class, so it is shared by every StorageReader instance, and it is what
    # ``__enter__`` hands to the ``with`` block.
    cur_mode = None

    def __init__(self, dir, reader=None):
        """Create a reader for ``dir``.

        Args:
            dir: Storage location; kept on the instance and passed to
                ``core.Reader`` when no backend is supplied.
            reader: Optional backend reader to wrap. When omitted (or
                falsy) a new ``core.Reader(dir)`` is created.
        """
        self.dir = dir
        self.reader = reader if reader else core.Reader(dir)

    def mode(self, mode):
        """Select ``mode`` as the current mode and return its reader.

        The new reader is recorded in the class attribute ``cur_mode``.
        """
        StorageReader.cur_mode = self.as_mode(mode)
        return StorageReader.cur_mode

    def as_mode(self, mode):
        """Return a new StorageReader bound to ``mode``.

        Unlike ``mode()``, this does not modify ``cur_mode``.
        """
        # Use self.dir: inside this method the bare name ``dir`` would
        # resolve to the builtin function, not to this reader's directory.
        return StorageReader(self.dir, self.reader.as_mode(mode))

    def modes(self):
        """Return the backend's ``modes()`` result."""
        return self.reader.modes()

    def tags(self, kind):
        """Return the backend's ``tags(kind)`` result."""
        return self.reader.tags(kind)

    def scalar(self, tag, type='float'):
        """Fetch the scalar record for ``tag`` from the backend.

        Args:
            tag: Tag forwarded to the backend getter.
            type: One of ``'float'``, ``'double'`` or ``'int'``; selects
                which backend getter is called.

        Raises:
            KeyError: If ``type`` is not one of the supported names.
        """
        type2scalar = {
            'float': self.reader.get_scalar_float,
            'double': self.reader.get_scalar_double,
            'int': self.reader.get_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag):
        """Return the backend's ``get_image(tag)`` result."""
        return self.reader.get_image(tag)

    def __enter__(self):
        """Enter a ``with`` block, yielding the class-level ``cur_mode``."""
        return StorageReader.cur_mode

    def __exit__(self, type, value, traceback):
        """Leave the ``with`` block; nothing to release, errors propagate."""
        pass

S
superjom 已提交
49 50 51

class StorageWriter(object):
    """Thin write-side wrapper that delegates to a backend writer object.

    The backend is either supplied by the caller or created with
    ``core.Writer(dir, sync_cycle)``. Record-creating methods forward to
    the backend; this class adds mode selection and ``with``-statement
    support.
    """

    # Writer produced by the most recent ``mode()`` / ``as_mode()`` call. It
    # is stored on the class, so it is shared by every StorageWriter
    # instance, and it is what ``__enter__`` hands to the ``with`` block.
    cur_mode = None

    def __init__(self, dir, sync_cycle, writer=None):
        """Create a writer for ``dir``.

        Args:
            dir: Storage location; kept on the instance and passed to
                ``core.Writer`` when no backend is supplied.
            sync_cycle: Kept on the instance and passed to ``core.Writer``
                as its second argument.
            writer: Optional backend writer to wrap. When omitted (or
                falsy) a new ``core.Writer(dir, sync_cycle)`` is created.
        """
        self.dir = dir
        self.sync_cycle = sync_cycle
        self.writer = writer if writer else core.Writer(dir, sync_cycle)

    def mode(self, mode):
        """Select ``mode`` as the current mode and return its writer."""
        StorageWriter.cur_mode = self.as_mode(mode)
        return StorageWriter.cur_mode

    def as_mode(self, mode):
        """Return a new StorageWriter bound to ``mode``.

        The new writer shares this writer's ``dir`` and ``sync_cycle``.
        Note that it is also recorded in the class attribute ``cur_mode``.
        """
        scoped = StorageWriter(self.dir, self.sync_cycle,
                               self.writer.as_mode(mode))
        StorageWriter.cur_mode = scoped
        return scoped

    def scalar(self, tag, type='float'):
        """Create a scalar record for ``tag`` through the backend.

        Args:
            tag: Tag forwarded to the backend factory.
            type: One of ``'float'``, ``'double'`` or ``'int'``; selects
                which backend factory is called.

        Raises:
            KeyError: If ``type`` is not one of the supported names.
        """
        type2scalar = {
            'float': self.writer.new_scalar_float,
            'double': self.writer.new_scalar_double,
            'int': self.writer.new_scalar_int,
        }
        return type2scalar[type](tag)

    def image(self, tag, num_samples, step_cycle):
        """Return the backend's ``new_image(tag, num_samples, step_cycle)``."""
        return self.writer.new_image(tag, num_samples, step_cycle)

    def __enter__(self):
        """Enter a ``with`` block, yielding the class-level ``cur_mode``."""
        return StorageWriter.cur_mode

    def __exit__(self, type, value, traceback):
        """Leave the ``with`` block; nothing to release, errors propagate."""
        pass