提交 401e139b 编写于 作者: S superjom

refactor image get record interface

上级 926e65b5
......@@ -3,6 +3,8 @@ import re
import storage
from PIL import Image
from tempfile import NamedTemporaryFile
import pprint
import urllib
def get_scalar_tags(storage, mode):
......@@ -31,14 +33,14 @@ def get_image_tags(storage, mode):
image = reader.image(tag)
if image.num_samples() == 1:
result[mode][tag] = {
'displayName': reader.scalar(tag).caption(),
'displayName': image.caption(),
'description': "",
'samples': 1,
}
else:
for i in xrange(image.num_samples()):
result[mode][tag + '/%d' % i] = {
'displayName': reader.scalar(tag).caption(),
'displayName': image.caption(),
'description': "",
'samples': 1,
}
......@@ -48,32 +50,35 @@ def get_image_tags(storage, mode):
def get_image_tag_steps(storage, mode, tag):
# remove suffix '/x'
res = re.search(r".*/([0-9]+$)", tag)
step_index = 0
if res:
tag = tag[:tag.rfind('/')]
step_index = int(res.groups()[0])
reader = storage.as_mode(mode)
image = reader.image(tag)
# TODO(ChunweiYan) make max_steps a config
max_steps = 10
res = []
steps = []
if image.num_records() > max_steps:
span = int(image.num_records() / max_steps)
steps = [image.num_records() - i * span - 1 for i in xrange(max_steps)]
steps = [i for i in reversed(steps)]
steps[0] = max(steps[0], 0)
else:
steps = [i for i in xrange(image.num_records())]
for step in steps:
for i in range(image.num_samples()):
record = image.record(step_index, i)
shape = record.shape()
query = urllib.urlencode({
'sample': 0,
'index': i,
'tag': tag,
'run': mode,
})
res.append({
'wall_time': image.timestamp(step),
'step': step,
'height': shape[0],
'width': shape[1],
'step': record.step_id(),
'wall_time': image.timestamp(step_index),
'query': query,
})
return res
def get_invididual_image(storage, mode, tag, index):
def get_invididual_image(storage, mode, tag, step_index):
reader = storage.as_mode(mode)
res = re.search(r".*/([0-9]+$)", tag)
# remove suffix '/x'
......@@ -82,11 +87,9 @@ def get_invididual_image(storage, mode, tag, index):
tag = tag[:tag.rfind('/')]
image = reader.image(tag)
data = image.data(offset, index)
shape = image.shape(offset, index)
# print data
# print shape
data = np.array(data, dtype='uint8').reshape(shape)
record = image.record(step_index, offset)
data = np.array(record.data(), dtype='uint8').reshape(record.shape())
tempfile = NamedTemporaryFile(mode='w+b', suffix='.png')
with Image.fromarray(data) as im:
im.save(tempfile)
......@@ -99,7 +102,8 @@ if __name__ == '__main__':
tags = get_image_tags(reader, 'train')
tags = get_image_tag_steps(reader, 'train', 'layer1/layer2/image0/0')
print 'image step tags', tags
print 'image step tags'
pprint.pprint(tags)
image = get_invididual_image(reader, "train", 'layer1/layer2/image0/0', 2)
print image
......@@ -27,7 +27,7 @@ for i in range(100):
def add_image(mode):
writer_ = writer.as_mode(mode)
tag = "layer1/layer2/image0"
image_writer = writer_.image(tag, 10)
image_writer = writer_.image(tag, 10, 1)
num_passes = 25
num_samples = 100
shape = [10, 10, 3]
......
......@@ -53,9 +53,12 @@ PYBIND11_PLUGIN(core) {
WRITER_ADD_SCALAR(int)
// clang-format on
.def("new_image",
[](vs::Writer& self, const std::string& tag, int num_samples) {
[](vs::Writer& self,
const std::string& tag,
int num_samples,
int step_cycle) {
auto tablet = self.AddTablet(tag);
return vs::components::Image(tablet, num_samples);
return vs::components::Image(tablet, num_samples, step_cycle);
});
//------------------- components --------------------
......@@ -88,12 +91,22 @@ PYBIND11_PLUGIN(core) {
.def("finish_sampling", &cp::Image::FinishSampling)
.def("set_sample", &cp::Image::SetSample);
py::class_<cp::ImageReader::ImageRecord>(m, "ImageRecord")
// TODO(ChunweiYan) make these copyless.
.def("data", [](cp::ImageReader::ImageRecord& self) { return self.data; })
.def("shape",
[](cp::ImageReader::ImageRecord& self) { return self.shape; })
.def("step_id",
[](cp::ImageReader::ImageRecord& self) { return self.step_id; });
py::class_<cp::ImageReader>(m, "ImageReader")
.def("caption", &cp::ImageReader::caption)
.def("num_records", &cp::ImageReader::num_records)
.def("num_samples", &cp::ImageReader::num_samples)
.def("timestamp", &cp::ImageReader::timestamp)
.def("data", &cp::ImageReader::data)
.def("shape", &cp::ImageReader::shape);
.def("record", &cp::ImageReader::record)
.def("timestamp", &cp::ImageReader::timestamp);
// .def("data", &cp::ImageReader::data)
// .def("shape", &cp::ImageReader::shape);
} // end pybind
......@@ -48,7 +48,11 @@ template class ScalarReader<float>;
template class ScalarReader<double>;
void Image::StartSampling() {
// TODO(ChunweiYan) serious bug here: every step is stored in protobuf,
// which can make the storage explode in some scenarios. Sampling only
// some of the steps would be better.
step_ = writer_.AddRecord();
step_.SetId(step_id_);
time_t time = std::time(nullptr);
step_.SetTimeStamp(time);
......@@ -61,6 +65,7 @@ void Image::StartSampling() {
}
int Image::IsSampleTaken() {
if (!ToSampleThisStep()) return -1;
num_records_++;
if (num_records_ <= num_samples_) {
return num_records_ - 1;
......@@ -76,8 +81,11 @@ int Image::IsSampleTaken() {
}
void Image::FinishSampling() {
// TODO(ChunweiYan) this needs a lot of optimization.
writer_.parent()->PersistToDisk();
step_id_++;
if (ToSampleThisStep()) {
// TODO(ChunweiYan) this needs a lot of optimization.
writer_.parent()->PersistToDisk();
}
}
template <typename T, typename U>
......@@ -124,16 +132,16 @@ std::string ImageReader::caption() {
return caption;
}
std::vector<ImageReader::value_t> ImageReader::data(int step, int index) {
auto record = reader_.record(step);
auto entry = record.data<value_t>(index);
return entry.GetMulti();
}
ImageReader::ImageRecord ImageReader::record(int offset, int index) {
ImageRecord res;
auto record = reader_.record(offset);
auto data_entry = record.data<value_t>(index);
auto shape_entry = record.data<shape_t>(index);
std::vector<ImageReader::shape_t> ImageReader::shape(int step, int index) {
auto record = reader_.record(step);
auto entry = record.data<shape_t>(index);
return entry.GetMulti();
res.data = data_entry.GetMulti();
res.shape = shape_entry.GetMulti();
res.step_id = record.id();
return res;
}
} // namespace components
......
......@@ -152,12 +152,19 @@ struct Image {
using value_t = float;
using shape_t = int64_t;
Image(Tablet tablet, int num_samples) : writer_(tablet) {
/*
* step_cycle: store one record every `step_cycle` steps.
* num_samples: number of samples to take in each step.
*/
Image(Tablet tablet, int num_samples, int step_cycle)
: writer_(tablet), num_samples_(num_samples), step_cycle_(step_cycle) {
CHECK_GT(step_cycle, 0);
CHECK_GT(num_samples, 0);
writer_.SetType(Tablet::Type::kImage);
// use the image's tag as the default caption.
writer_.SetNumSamples(num_samples);
SetCaption(tablet.reader().tag());
num_samples_ = num_samples;
}
void SetCaption(const std::string& c) {
writer_.SetCaptions(std::vector<std::string>({c}));
......@@ -182,11 +189,16 @@ struct Image {
const std::vector<shape_t>& shape,
const std::vector<value_t>& data);
protected:
// Returns true when the current step id is a multiple of step_cycle_, i.e.
// when this step should be stored as a record. The constructor checks
// step_cycle > 0, so the modulo never divides by zero.
bool ToSampleThisStep() { return step_id_ % step_cycle_ == 0; }
private:
Tablet writer_;
Record step_;
int num_records_{0};
int num_samples_{0};
int step_id_{0};
int step_cycle_;
};
/*
......@@ -196,6 +208,12 @@ struct ImageReader {
using value_t = typename Image::value_t;
using shape_t = typename Image::shape_t;
// Value bundle returned by ImageReader::record(): one sample of one stored
// step, together with the id of that step.
struct ImageRecord {
// Id of the step the sample belongs to; record() fills it from the
// underlying record's id().
int step_id;
// Sample payload as a flat vector of value_t, obtained via GetMulti().
std::vector<value_t> data;
// Dimensions of the sample; the Python side reshapes `data` with it.
// NOTE(review): the axis order is not fixed here - confirm with the writer.
std::vector<shape_t> shape;
};
ImageReader(const std::string& mode, TabletReader tablet)
: reader_(tablet), mode_{mode} {}
......@@ -208,9 +226,25 @@ struct ImageReader {
int64_t timestamp(int step) { return reader_.record(step).timestamp(); }
std::vector<value_t> data(int step, int index);
/*
* offset: offset of a step.
* index: index of a sample.
*/
ImageRecord record(int offset, int index);
/*
* offset: offset of a step.
* index: index of a sample.
*/
std::vector<value_t> data(int offset, int index);
/*
* offset: offset of a step.
* index: index of a sample.
*/
std::vector<shape_t> shape(int offset, int index);
std::vector<shape_t> shape(int step, int index);
int stepid(int offset, int index);
private:
TabletReader reader_;
......
......@@ -48,7 +48,7 @@ TEST(Image, test) {
auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3);
components::Image image(tablet, 3, 1);
const int num_steps = 10;
LOG(INFO) << "write images";
......
......@@ -54,5 +54,5 @@ class StorageWriter(object):
}
return type2scalar[type](tag)
def image(self, tag, num_samples):
return self.writer.new_image(tag, num_samples)
def image(self, tag, num_samples, step_cycle):
return self.writer.new_image(tag, num_samples, step_cycle)
......@@ -31,7 +31,7 @@ class StorageTest(unittest.TestCase):
def test_image(self):
tag = "layer1/layer2/image0"
image_writer = self.writer.image(tag, 10)
image_writer = self.writer.image(tag, 10, 1)
num_passes = 10
num_samples = 100
shape = [3, 10, 10]
......@@ -50,8 +50,10 @@ class StorageTest(unittest.TestCase):
image_reader = self.reader.image(tag)
self.assertEqual(image_reader.caption(), tag)
self.assertEqual(image_reader.num_records(), num_passes)
self.assertTrue(np.equal(image_reader.shape(0, 1), shape).all())
data = image_reader.data(0, 1)
image_record = image_reader.record(0, 1)
self.assertTrue(np.equal(image_record.shape(), shape).all())
data = image_record.data()
self.assertEqual(len(data), np.prod(shape))
image_tags = self.reader.tags("image")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册