提交 f0b6db3c 编写于 作者: S superjom

change all `as_mode` to `with mode`

上级 13086db4
...@@ -16,7 +16,7 @@ def get_modes(storage): ...@@ -16,7 +16,7 @@ def get_modes(storage):
def get_scalar_tags(storage, mode): def get_scalar_tags(storage, mode):
result = {} result = {}
for mode in storage.modes(): for mode in storage.modes():
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
tags = reader.tags('scalar') tags = reader.tags('scalar')
if tags: if tags:
result[mode] = {} result[mode] = {}
...@@ -29,7 +29,7 @@ def get_scalar_tags(storage, mode): ...@@ -29,7 +29,7 @@ def get_scalar_tags(storage, mode):
def get_scalar(storage, mode, tag): def get_scalar(storage, mode, tag):
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
scalar = reader.scalar(tag) scalar = reader.scalar(tag)
records = scalar.records() records = scalar.records()
...@@ -44,7 +44,7 @@ def get_image_tags(storage): ...@@ -44,7 +44,7 @@ def get_image_tags(storage):
result = {} result = {}
for mode in storage.modes(): for mode in storage.modes():
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
tags = reader.tags('image') tags = reader.tags('image')
if tags: if tags:
result[mode] = {} result[mode] = {}
...@@ -70,7 +70,7 @@ def get_image_tag_steps(storage, mode, tag): ...@@ -70,7 +70,7 @@ def get_image_tag_steps(storage, mode, tag):
tag = tag[:tag.rfind('/')] tag = tag[:tag.rfind('/')]
sample_index = int(res.groups()[0]) sample_index = int(res.groups()[0])
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
image = reader.image(tag) image = reader.image(tag)
res = [] res = []
...@@ -95,7 +95,7 @@ def get_image_tag_steps(storage, mode, tag): ...@@ -95,7 +95,7 @@ def get_image_tag_steps(storage, mode, tag):
def get_invididual_image(storage, mode, tag, step_index): def get_invididual_image(storage, mode, tag, step_index):
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
res = re.search(r".*/([0-9]+$)", tag) res = re.search(r".*/([0-9]+$)", tag)
# remove suffix '/x' # remove suffix '/x'
if res: if res:
......
...@@ -6,7 +6,7 @@ import numpy as np ...@@ -6,7 +6,7 @@ import numpy as np
def add_scalar(writer, mode, tag, num_steps, skip): def add_scalar(writer, mode, tag, num_steps, skip):
my_writer = writer.as_mode(mode) with writer.mode(mode) as my_writer:
scalar = my_writer.scalar(tag) scalar = my_writer.scalar(tag)
for i in range(num_steps): for i in range(num_steps):
if i % skip == 0: if i % skip == 0:
...@@ -20,7 +20,7 @@ def add_image(writer, ...@@ -20,7 +20,7 @@ def add_image(writer,
num_passes, num_passes,
step_cycle, step_cycle,
shape=[50, 50, 3]): shape=[50, 50, 3]):
writer_ = writer.as_mode(mode) with writer.mode(mode) as writer_:
image_writer = writer_.image(tag, num_samples, step_cycle) image_writer = writer_.image(tag, num_samples, step_cycle)
for pass_ in xrange(num_passes): for pass_ in xrange(num_passes):
......
...@@ -22,8 +22,9 @@ class StorageTest(unittest.TestCase): ...@@ -22,8 +22,9 @@ class StorageTest(unittest.TestCase):
scalar.add_record(i, float(i)) scalar.add_record(i, float(i))
print 'test read' print 'test read'
self.reader = storage.StorageReader(self.dir).as_mode("train") self.reader = storage.StorageReader(self.dir)
scalar = self.reader.scalar("model/scalar/min") with self.reader.mode("train") as reader:
scalar = reader.scalar("model/scalar/min")
self.assertEqual(scalar.caption(), "train") self.assertEqual(scalar.caption(), "train")
records = scalar.records() records = scalar.records()
ids = scalar.ids() ids = scalar.ids()
...@@ -49,8 +50,9 @@ class StorageTest(unittest.TestCase): ...@@ -49,8 +50,9 @@ class StorageTest(unittest.TestCase):
image_writer.set_sample(index, shape, list(data)) image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling() image_writer.finish_sampling()
self.reader = storage.StorageReader(self.dir).as_mode("train") self.reader = storage.StorageReader(self.dir)
image_reader = self.reader.image(tag) with self.reader.mode("train") as reader:
image_reader = reader.image(tag)
self.assertEqual(image_reader.caption(), tag) self.assertEqual(image_reader.caption(), tag)
self.assertEqual(image_reader.num_records(), num_passes) self.assertEqual(image_reader.num_records(), num_passes)
...@@ -59,7 +61,7 @@ class StorageTest(unittest.TestCase): ...@@ -59,7 +61,7 @@ class StorageTest(unittest.TestCase):
data = image_record.data() data = image_record.data()
self.assertEqual(len(data), np.prod(shape)) self.assertEqual(len(data), np.prod(shape))
image_tags = self.reader.tags("image") image_tags = reader.tags("image")
self.assertTrue(image_tags) self.assertTrue(image_tags)
self.assertEqual(len(image_tags), 1) self.assertEqual(len(image_tags), 1)
...@@ -75,7 +77,8 @@ class StorageTest(unittest.TestCase): ...@@ -75,7 +77,8 @@ class StorageTest(unittest.TestCase):
shape = [image.size[1], image.size[0], 3] shape = [image.size[1], image.size[0], 3]
origin_data = np.array(image.getdata()).flatten() origin_data = np.array(image.getdata()).flatten()
self.reader = storage.StorageReader(self.dir).as_mode("train") self.reader = storage.StorageReader(self.dir)
with self.reader.mode("train") as reader:
image_writer.start_sampling() image_writer.start_sampling()
index = image_writer.is_sample_taken() index = image_writer.is_sample_taken()
...@@ -84,7 +87,7 @@ class StorageTest(unittest.TestCase): ...@@ -84,7 +87,7 @@ class StorageTest(unittest.TestCase):
# read and check whether the original image will be displayed # read and check whether the original image will be displayed
image_reader = self.reader.image(tag) image_reader = reader.image(tag)
image_record = image_reader.record(0, 0) image_record = image_reader.record(0, 0)
data = image_record.data() data = image_record.data()
shape = image_record.shape() shape = image_record.shape()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册