提交 f0b6db3c 编写于 作者: S superjom

change all `as_mode` to `with mode`

上级 13086db4
...@@ -16,47 +16,47 @@ def get_modes(storage): ...@@ -16,47 +16,47 @@ def get_modes(storage):
def get_scalar_tags(storage, mode): def get_scalar_tags(storage, mode):
result = {} result = {}
for mode in storage.modes(): for mode in storage.modes():
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
tags = reader.tags('scalar') tags = reader.tags('scalar')
if tags: if tags:
result[mode] = {} result[mode] = {}
for tag in tags: for tag in tags:
result[mode][tag] = { result[mode][tag] = {
'displayName': reader.scalar(tag).caption(), 'displayName': reader.scalar(tag).caption(),
'description': "", 'description': "",
} }
return result return result
def get_scalar(storage, mode, tag): def get_scalar(storage, mode, tag):
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
scalar = reader.scalar(tag) scalar = reader.scalar(tag)
records = scalar.records() records = scalar.records()
ids = scalar.ids() ids = scalar.ids()
timestamps = scalar.timestamps() timestamps = scalar.timestamps()
result = zip(timestamps, ids, records) result = zip(timestamps, ids, records)
return result return result
def get_image_tags(storage): def get_image_tags(storage):
result = {} result = {}
for mode in storage.modes(): for mode in storage.modes():
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
tags = reader.tags('image') tags = reader.tags('image')
if tags: if tags:
result[mode] = {} result[mode] = {}
for tag in tags: for tag in tags:
image = reader.image(tag) image = reader.image(tag)
for i in xrange(max(1, image.num_samples())): for i in xrange(max(1, image.num_samples())):
caption = tag if image.num_samples() <= 1 else '%s/%d'%(tag, i) caption = tag if image.num_samples() <= 1 else '%s/%d'%(tag, i)
result[mode][caption] = { result[mode][caption] = {
'displayName': caption, 'displayName': caption,
'description': "", 'description': "",
'samples': 1, 'samples': 1,
} }
return result return result
...@@ -70,9 +70,9 @@ def get_image_tag_steps(storage, mode, tag): ...@@ -70,9 +70,9 @@ def get_image_tag_steps(storage, mode, tag):
tag = tag[:tag.rfind('/')] tag = tag[:tag.rfind('/')]
sample_index = int(res.groups()[0]) sample_index = int(res.groups()[0])
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
image = reader.image(tag) image = reader.image(tag)
res = [] res = []
for step_index in range(image.num_records()): for step_index in range(image.num_records()):
record = image.record(step_index, sample_index) record = image.record(step_index, sample_index)
...@@ -95,22 +95,22 @@ def get_image_tag_steps(storage, mode, tag): ...@@ -95,22 +95,22 @@ def get_image_tag_steps(storage, mode, tag):
def get_invididual_image(storage, mode, tag, step_index): def get_invididual_image(storage, mode, tag, step_index):
reader = storage.as_mode(mode) with storage.mode(mode) as reader:
res = re.search(r".*/([0-9]+$)", tag) res = re.search(r".*/([0-9]+$)", tag)
# remove suffix '/x' # remove suffix '/x'
if res: if res:
offset = int(res.groups()[0]) offset = int(res.groups()[0])
tag = tag[:tag.rfind('/')] tag = tag[:tag.rfind('/')]
image = reader.image(tag) image = reader.image(tag)
record = image.record(step_index, offset) record = image.record(step_index, offset)
data = np.array(record.data(), dtype='uint8').reshape(record.shape()) data = np.array(record.data(), dtype='uint8').reshape(record.shape())
tempfile = NamedTemporaryFile(mode='w+b', suffix='.png') tempfile = NamedTemporaryFile(mode='w+b', suffix='.png')
with Image.fromarray(data) as im: with Image.fromarray(data) as im:
im.save(tempfile) im.save(tempfile)
tempfile.seek(0, 0) tempfile.seek(0, 0)
return tempfile return tempfile
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -6,11 +6,11 @@ import numpy as np ...@@ -6,11 +6,11 @@ import numpy as np
def add_scalar(writer, mode, tag, num_steps, skip): def add_scalar(writer, mode, tag, num_steps, skip):
my_writer = writer.as_mode(mode) with writer.mode(mode) as my_writer:
scalar = my_writer.scalar(tag) scalar = my_writer.scalar(tag)
for i in range(num_steps): for i in range(num_steps):
if i % skip == 0: if i % skip == 0:
scalar.add_record(i, random.random()) scalar.add_record(i, random.random())
def add_image(writer, def add_image(writer,
...@@ -20,20 +20,20 @@ def add_image(writer, ...@@ -20,20 +20,20 @@ def add_image(writer,
num_passes, num_passes,
step_cycle, step_cycle,
shape=[50, 50, 3]): shape=[50, 50, 3]):
writer_ = writer.as_mode(mode) with writer.mode(mode) as writer_:
image_writer = writer_.image(tag, num_samples, step_cycle) image_writer = writer_.image(tag, num_samples, step_cycle)
for pass_ in xrange(num_passes): for pass_ in xrange(num_passes):
image_writer.start_sampling() image_writer.start_sampling()
for ins in xrange(2 * num_samples): for ins in xrange(2 * num_samples):
index = image_writer.is_sample_taken() index = image_writer.is_sample_taken()
if index != -1: if index != -1:
data = np.random.random(shape) * 256 data = np.random.random(shape) * 256
data = np.ndarray.flatten(data) data = np.ndarray.flatten(data)
assert shape assert shape
assert len(data) > 0 assert len(data) > 0
image_writer.set_sample(index, shape, list(data)) image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling() image_writer.finish_sampling()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -22,15 +22,16 @@ class StorageTest(unittest.TestCase): ...@@ -22,15 +22,16 @@ class StorageTest(unittest.TestCase):
scalar.add_record(i, float(i)) scalar.add_record(i, float(i))
print 'test read' print 'test read'
self.reader = storage.StorageReader(self.dir).as_mode("train") self.reader = storage.StorageReader(self.dir)
scalar = self.reader.scalar("model/scalar/min") with self.reader.mode("train") as reader:
self.assertEqual(scalar.caption(), "train") scalar = reader.scalar("model/scalar/min")
records = scalar.records() self.assertEqual(scalar.caption(), "train")
ids = scalar.ids() records = scalar.records()
self.assertTrue(np.equal(records, [float(i) for i in range(10)]).all()) ids = scalar.ids()
self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all()) self.assertTrue(np.equal(records, [float(i) for i in range(10)]).all())
print 'records', records self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all())
print 'ids', ids print 'records', records
print 'ids', ids
def test_image(self): def test_image(self):
tag = "layer1/layer2/image0" tag = "layer1/layer2/image0"
...@@ -49,19 +50,20 @@ class StorageTest(unittest.TestCase): ...@@ -49,19 +50,20 @@ class StorageTest(unittest.TestCase):
image_writer.set_sample(index, shape, list(data)) image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling() image_writer.finish_sampling()
self.reader = storage.StorageReader(self.dir).as_mode("train") self.reader = storage.StorageReader(self.dir)
image_reader = self.reader.image(tag) with self.reader.mode("train") as reader:
self.assertEqual(image_reader.caption(), tag) image_reader = reader.image(tag)
self.assertEqual(image_reader.num_records(), num_passes) self.assertEqual(image_reader.caption(), tag)
self.assertEqual(image_reader.num_records(), num_passes)
image_record = image_reader.record(0, 1) image_record = image_reader.record(0, 1)
self.assertTrue(np.equal(image_record.shape(), shape).all()) self.assertTrue(np.equal(image_record.shape(), shape).all())
data = image_record.data() data = image_record.data()
self.assertEqual(len(data), np.prod(shape)) self.assertEqual(len(data), np.prod(shape))
image_tags = self.reader.tags("image") image_tags = reader.tags("image")
self.assertTrue(image_tags) self.assertTrue(image_tags)
self.assertEqual(len(image_tags), 1) self.assertEqual(len(image_tags), 1)
def test_check_image(self): def test_check_image(self):
''' '''
...@@ -75,31 +77,32 @@ class StorageTest(unittest.TestCase): ...@@ -75,31 +77,32 @@ class StorageTest(unittest.TestCase):
shape = [image.size[1], image.size[0], 3] shape = [image.size[1], image.size[0], 3]
origin_data = np.array(image.getdata()).flatten() origin_data = np.array(image.getdata()).flatten()
self.reader = storage.StorageReader(self.dir).as_mode("train") self.reader = storage.StorageReader(self.dir)
with self.reader.mode("train") as reader:
image_writer.start_sampling() image_writer.start_sampling()
index = image_writer.is_sample_taken() index = image_writer.is_sample_taken()
image_writer.set_sample(index, shape, list(origin_data)) image_writer.set_sample(index, shape, list(origin_data))
image_writer.finish_sampling() image_writer.finish_sampling()
# read and check whether the original image will be displayed # read and check whether the original image will be displayed
image_reader = self.reader.image(tag) image_reader = reader.image(tag)
image_record = image_reader.record(0, 0) image_record = image_reader.record(0, 0)
data = image_record.data() data = image_record.data()
shape = image_record.shape() shape = image_record.shape()
PIL_image_shape = (shape[0] * shape[1], shape[2]) PIL_image_shape = (shape[0] * shape[1], shape[2])
data = np.array(data, dtype='uint8').reshape(PIL_image_shape) data = np.array(data, dtype='uint8').reshape(PIL_image_shape)
print 'origin', origin_data.flatten() print 'origin', origin_data.flatten()
print 'data', data.flatten() print 'data', data.flatten()
image = Image.fromarray(data.reshape(shape)) image = Image.fromarray(data.reshape(shape))
# manully check the image and found that nothing wrong with the image storage. # manully check the image and found that nothing wrong with the image storage.
# image.show() # image.show()
# after scale, elements are changed. # after scale, elements are changed.
# self.assertTrue( # self.assertTrue(
# np.equal(origin_data.reshape(PIL_image_shape), data).all()) # np.equal(origin_data.reshape(PIL_image_shape), data).all())
def test_with_syntax(self): def test_with_syntax(self):
with self.writer.mode("train") as writer: with self.writer.mode("train") as writer:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册