提交 13086db4 编写于 作者: S superjom

add with syntax

上级 ed5e6596
......@@ -9,10 +9,16 @@ dtypes = ("float", "double", "int32", "int64")
class StorageReader(object):
cur_mode = None
def __init__(self, dir, reader=None):
self.dir = dir
self.reader = reader if reader else core.Reader(dir)
def mode(self, mode):
StorageReader.cur_mode = self.as_mode(mode)
return StorageReader.cur_mode
def as_mode(self, mode):
tmp = StorageReader(dir, self.reader.as_mode(mode))
return tmp
......@@ -34,17 +40,29 @@ class StorageReader(object):
def image(self, tag):
return self.reader.get_image(tag)
def __enter__(self):
return StorageReader.cur_mode
def __exit__(self, type, value, traceback):
pass
class StorageWriter(object):
cur_mode = None
def __init__(self, dir, sync_cycle, writer=None):
self.dir = dir
self.sync_cycle = sync_cycle
self.writer = writer if writer else core.Writer(dir, sync_cycle)
def mode(self, mode):
StorageWriter.cur_mode = self.as_mode(mode)
return StorageWriter.cur_mode
def as_mode(self, mode):
tmp = StorageWriter(self.dir, self.sync_cycle, self.writer.as_mode(mode))
return tmp
StorageWriter.cur_mode = StorageWriter(self.dir, self.sync_cycle, self.writer.as_mode(mode))
return StorageWriter.cur_mode
def scalar(self, tag, type='float'):
type2scalar = {
......@@ -56,3 +74,9 @@ class StorageWriter(object):
def image(self, tag, num_samples, step_cycle):
return self.writer.new_image(tag, num_samples, step_cycle)
def __enter__(self):
return StorageWriter.cur_mode
def __exit__(self, type, value, traceback):
pass
......@@ -101,6 +101,18 @@ class StorageTest(unittest.TestCase):
# self.assertTrue(
# np.equal(origin_data.reshape(PIL_image_shape), data).all())
def test_with_syntax(self):
with self.writer.mode("train") as writer:
scalar = writer.scalar("model/scalar/average")
for i in range(10):
scalar.add_record(i, float(i))
self.reader = storage.StorageReader(self.dir)
with self.reader.mode("train") as reader:
scalar = reader.scalar("model/scalar/average")
self.assertEqual(scalar.caption(), "train")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册