提交 401e139b 编写于 作者: S superjom

refactor image get record interface

上级 926e65b5
...@@ -3,6 +3,8 @@ import re ...@@ -3,6 +3,8 @@ import re
import storage import storage
from PIL import Image from PIL import Image
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import pprint
import urllib
def get_scalar_tags(storage, mode): def get_scalar_tags(storage, mode):
...@@ -31,14 +33,14 @@ def get_image_tags(storage, mode): ...@@ -31,14 +33,14 @@ def get_image_tags(storage, mode):
image = reader.image(tag) image = reader.image(tag)
if image.num_samples() == 1: if image.num_samples() == 1:
result[mode][tag] = { result[mode][tag] = {
'displayName': reader.scalar(tag).caption(), 'displayName': mage.caption(),
'description': "", 'description': "",
'samples': 1, 'samples': 1,
} }
else: else:
for i in xrange(image.num_samples()): for i in xrange(image.num_samples()):
result[mode][tag + '/%d' % i] = { result[mode][tag + '/%d' % i] = {
'displayName': reader.scalar(tag).caption(), 'displayName': image.caption(),
'description': "", 'description': "",
'samples': 1, 'samples': 1,
} }
...@@ -48,32 +50,35 @@ def get_image_tags(storage, mode): ...@@ -48,32 +50,35 @@ def get_image_tags(storage, mode):
def get_image_tag_steps(storage, mode, tag): def get_image_tag_steps(storage, mode, tag):
# remove suffix '/x' # remove suffix '/x'
res = re.search(r".*/([0-9]+$)", tag) res = re.search(r".*/([0-9]+$)", tag)
step_index = 0
if res: if res:
tag = tag[:tag.rfind('/')] tag = tag[:tag.rfind('/')]
step_index = int(res.groups()[0])
reader = storage.as_mode(mode) reader = storage.as_mode(mode)
image = reader.image(tag) image = reader.image(tag)
# TODO(ChunweiYan) make max_steps a config
max_steps = 10
res = [] res = []
steps = []
if image.num_records() > max_steps:
span = int(image.num_records() / max_steps)
steps = [image.num_records() - i * span - 1 for i in xrange(max_steps)]
steps = [i for i in reversed(steps)]
steps[0] = max(steps[0], 0)
else:
steps = [i for i in xrange(image.num_records())]
for step in steps: for i in range(image.num_samples()):
record = image.record(step_index, i)
shape = record.shape()
query = urllib.urlencode({
'sample': 0,
'index': i,
'tag': tag,
'run': mode,
})
res.append({ res.append({
'wall_time': image.timestamp(step), 'height': shape[0],
'step': step, 'width': shape[1],
'step': record.step_id(),
'wall_time': image.timestamp(step_index),
'query': query,
}) })
return res return res
def get_invididual_image(storage, mode, tag, index): def get_invididual_image(storage, mode, tag, step_index):
reader = storage.as_mode(mode) reader = storage.as_mode(mode)
res = re.search(r".*/([0-9]+$)", tag) res = re.search(r".*/([0-9]+$)", tag)
# remove suffix '/x' # remove suffix '/x'
...@@ -82,11 +87,9 @@ def get_invididual_image(storage, mode, tag, index): ...@@ -82,11 +87,9 @@ def get_invididual_image(storage, mode, tag, index):
tag = tag[:tag.rfind('/')] tag = tag[:tag.rfind('/')]
image = reader.image(tag) image = reader.image(tag)
data = image.data(offset, index) record = image.record(step_index, offset)
shape = image.shape(offset, index)
# print data data = np.array(record.data(), dtype='uint8').reshape(record.shape())
# print shape
data = np.array(data, dtype='uint8').reshape(shape)
tempfile = NamedTemporaryFile(mode='w+b', suffix='.png') tempfile = NamedTemporaryFile(mode='w+b', suffix='.png')
with Image.fromarray(data) as im: with Image.fromarray(data) as im:
im.save(tempfile) im.save(tempfile)
...@@ -99,7 +102,8 @@ if __name__ == '__main__': ...@@ -99,7 +102,8 @@ if __name__ == '__main__':
tags = get_image_tags(reader, 'train') tags = get_image_tags(reader, 'train')
tags = get_image_tag_steps(reader, 'train', 'layer1/layer2/image0/0') tags = get_image_tag_steps(reader, 'train', 'layer1/layer2/image0/0')
print 'image step tags', tags print 'image step tags'
pprint.pprint(tags)
image = get_invididual_image(reader, "train", 'layer1/layer2/image0/0', 2) image = get_invididual_image(reader, "train", 'layer1/layer2/image0/0', 2)
print image print image
...@@ -27,7 +27,7 @@ for i in range(100): ...@@ -27,7 +27,7 @@ for i in range(100):
def add_image(mode): def add_image(mode):
writer_ = writer.as_mode(mode) writer_ = writer.as_mode(mode)
tag = "layer1/layer2/image0" tag = "layer1/layer2/image0"
image_writer = writer_.image(tag, 10) image_writer = writer_.image(tag, 10, 1)
num_passes = 25 num_passes = 25
num_samples = 100 num_samples = 100
shape = [10, 10, 3] shape = [10, 10, 3]
......
...@@ -53,9 +53,12 @@ PYBIND11_PLUGIN(core) { ...@@ -53,9 +53,12 @@ PYBIND11_PLUGIN(core) {
WRITER_ADD_SCALAR(int) WRITER_ADD_SCALAR(int)
// clang-format on // clang-format on
.def("new_image", .def("new_image",
[](vs::Writer& self, const std::string& tag, int num_samples) { [](vs::Writer& self,
const std::string& tag,
int num_samples,
int step_cycle) {
auto tablet = self.AddTablet(tag); auto tablet = self.AddTablet(tag);
return vs::components::Image(tablet, num_samples); return vs::components::Image(tablet, num_samples, step_cycle);
}); });
//------------------- components -------------------- //------------------- components --------------------
...@@ -88,12 +91,22 @@ PYBIND11_PLUGIN(core) { ...@@ -88,12 +91,22 @@ PYBIND11_PLUGIN(core) {
.def("finish_sampling", &cp::Image::FinishSampling) .def("finish_sampling", &cp::Image::FinishSampling)
.def("set_sample", &cp::Image::SetSample); .def("set_sample", &cp::Image::SetSample);
py::class_<cp::ImageReader::ImageRecord>(m, "ImageRecord")
// TODO(ChunweiYan) make these copyless.
.def("data", [](cp::ImageReader::ImageRecord& self) { return self.data; })
.def("shape",
[](cp::ImageReader::ImageRecord& self) { return self.shape; })
.def("step_id",
[](cp::ImageReader::ImageRecord& self) { return self.step_id; });
py::class_<cp::ImageReader>(m, "ImageReader") py::class_<cp::ImageReader>(m, "ImageReader")
.def("caption", &cp::ImageReader::caption) .def("caption", &cp::ImageReader::caption)
.def("num_records", &cp::ImageReader::num_records) .def("num_records", &cp::ImageReader::num_records)
.def("num_samples", &cp::ImageReader::num_samples) .def("num_samples", &cp::ImageReader::num_samples)
.def("timestamp", &cp::ImageReader::timestamp) .def("record", &cp::ImageReader::record)
.def("data", &cp::ImageReader::data) .def("timestamp", &cp::ImageReader::timestamp);
.def("shape", &cp::ImageReader::shape);
// .def("data", &cp::ImageReader::data)
// .def("shape", &cp::ImageReader::shape);
} // end pybind } // end pybind
...@@ -48,7 +48,11 @@ template class ScalarReader<float>; ...@@ -48,7 +48,11 @@ template class ScalarReader<float>;
template class ScalarReader<double>; template class ScalarReader<double>;
void Image::StartSampling() { void Image::StartSampling() {
// TODO(ChunweiYan) big bug here, every step will be stored in protobuf
// and that might result in explosion in some scenerios, Just sampling
// some steps should be better.
step_ = writer_.AddRecord(); step_ = writer_.AddRecord();
step_.SetId(step_id_);
time_t time = std::time(nullptr); time_t time = std::time(nullptr);
step_.SetTimeStamp(time); step_.SetTimeStamp(time);
...@@ -61,6 +65,7 @@ void Image::StartSampling() { ...@@ -61,6 +65,7 @@ void Image::StartSampling() {
} }
int Image::IsSampleTaken() { int Image::IsSampleTaken() {
if (!ToSampleThisStep()) return -1;
num_records_++; num_records_++;
if (num_records_ <= num_samples_) { if (num_records_ <= num_samples_) {
return num_records_ - 1; return num_records_ - 1;
...@@ -76,8 +81,11 @@ int Image::IsSampleTaken() { ...@@ -76,8 +81,11 @@ int Image::IsSampleTaken() {
} }
void Image::FinishSampling() { void Image::FinishSampling() {
step_id_++;
if (ToSampleThisStep()) {
// TODO(ChunweiYan) much optimizement here. // TODO(ChunweiYan) much optimizement here.
writer_.parent()->PersistToDisk(); writer_.parent()->PersistToDisk();
}
} }
template <typename T, typename U> template <typename T, typename U>
...@@ -124,16 +132,16 @@ std::string ImageReader::caption() { ...@@ -124,16 +132,16 @@ std::string ImageReader::caption() {
return caption; return caption;
} }
std::vector<ImageReader::value_t> ImageReader::data(int step, int index) { ImageReader::ImageRecord ImageReader::record(int offset, int index) {
auto record = reader_.record(step); ImageRecord res;
auto entry = record.data<value_t>(index); auto record = reader_.record(offset);
return entry.GetMulti(); auto data_entry = record.data<value_t>(index);
} auto shape_entry = record.data<shape_t>(index);
std::vector<ImageReader::shape_t> ImageReader::shape(int step, int index) { res.data = data_entry.GetMulti();
auto record = reader_.record(step); res.shape = shape_entry.GetMulti();
auto entry = record.data<shape_t>(index); res.step_id = record.id();
return entry.GetMulti(); return res;
} }
} // namespace components } // namespace components
......
...@@ -152,12 +152,19 @@ struct Image { ...@@ -152,12 +152,19 @@ struct Image {
using value_t = float; using value_t = float;
using shape_t = int64_t; using shape_t = int64_t;
Image(Tablet tablet, int num_samples) : writer_(tablet) { /*
* step_cycle: store every `step_cycle` as a record.
* num_samples: how many samples to take in a step.
*/
Image(Tablet tablet, int num_samples, int step_cycle)
: writer_(tablet), num_samples_(num_samples), step_cycle_(step_cycle) {
CHECK_GT(step_cycle, 0);
CHECK_GT(num_samples, 0);
writer_.SetType(Tablet::Type::kImage); writer_.SetType(Tablet::Type::kImage);
// make image's tag as the default caption. // make image's tag as the default caption.
writer_.SetNumSamples(num_samples); writer_.SetNumSamples(num_samples);
SetCaption(tablet.reader().tag()); SetCaption(tablet.reader().tag());
num_samples_ = num_samples;
} }
void SetCaption(const std::string& c) { void SetCaption(const std::string& c) {
writer_.SetCaptions(std::vector<std::string>({c})); writer_.SetCaptions(std::vector<std::string>({c}));
...@@ -182,11 +189,16 @@ struct Image { ...@@ -182,11 +189,16 @@ struct Image {
const std::vector<shape_t>& shape, const std::vector<shape_t>& shape,
const std::vector<value_t>& data); const std::vector<value_t>& data);
protected:
bool ToSampleThisStep() { return step_id_ % step_cycle_ == 0; }
private: private:
Tablet writer_; Tablet writer_;
Record step_; Record step_;
int num_records_{0}; int num_records_{0};
int num_samples_{0}; int num_samples_{0};
int step_id_{0};
int step_cycle_;
}; };
/* /*
...@@ -196,6 +208,12 @@ struct ImageReader { ...@@ -196,6 +208,12 @@ struct ImageReader {
using value_t = typename Image::value_t; using value_t = typename Image::value_t;
using shape_t = typename Image::shape_t; using shape_t = typename Image::shape_t;
struct ImageRecord {
int step_id;
std::vector<value_t> data;
std::vector<shape_t> shape;
};
ImageReader(const std::string& mode, TabletReader tablet) ImageReader(const std::string& mode, TabletReader tablet)
: reader_(tablet), mode_{mode} {} : reader_(tablet), mode_{mode} {}
...@@ -208,9 +226,25 @@ struct ImageReader { ...@@ -208,9 +226,25 @@ struct ImageReader {
int64_t timestamp(int step) { return reader_.record(step).timestamp(); } int64_t timestamp(int step) { return reader_.record(step).timestamp(); }
std::vector<value_t> data(int step, int index); /*
* offset: offset of a step.
* index: index of a sample.
*/
ImageRecord record(int offset, int index);
/*
* offset: offset of a step.
* index: index of a sample.
*/
std::vector<value_t> data(int offset, int index);
/*
* offset: offset of a step.
* index: index of a sample.
*/
std::vector<shape_t> shape(int offset, int index);
std::vector<shape_t> shape(int step, int index); int stepid(int offset, int index);
private: private:
TabletReader reader_; TabletReader reader_;
......
...@@ -48,7 +48,7 @@ TEST(Image, test) { ...@@ -48,7 +48,7 @@ TEST(Image, test) {
auto writer = writer__.AsMode("train"); auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0"); auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3); components::Image image(tablet, 3, 1);
const int num_steps = 10; const int num_steps = 10;
LOG(INFO) << "write images"; LOG(INFO) << "write images";
......
...@@ -54,5 +54,5 @@ class StorageWriter(object): ...@@ -54,5 +54,5 @@ class StorageWriter(object):
} }
return type2scalar[type](tag) return type2scalar[type](tag)
def image(self, tag, num_samples): def image(self, tag, num_samples, step_cycle):
return self.writer.new_image(tag, num_samples) return self.writer.new_image(tag, num_samples, step_cycle)
...@@ -31,7 +31,7 @@ class StorageTest(unittest.TestCase): ...@@ -31,7 +31,7 @@ class StorageTest(unittest.TestCase):
def test_image(self): def test_image(self):
tag = "layer1/layer2/image0" tag = "layer1/layer2/image0"
image_writer = self.writer.image(tag, 10) image_writer = self.writer.image(tag, 10, 1)
num_passes = 10 num_passes = 10
num_samples = 100 num_samples = 100
shape = [3, 10, 10] shape = [3, 10, 10]
...@@ -50,8 +50,10 @@ class StorageTest(unittest.TestCase): ...@@ -50,8 +50,10 @@ class StorageTest(unittest.TestCase):
image_reader = self.reader.image(tag) image_reader = self.reader.image(tag)
self.assertEqual(image_reader.caption(), tag) self.assertEqual(image_reader.caption(), tag)
self.assertEqual(image_reader.num_records(), num_passes) self.assertEqual(image_reader.num_records(), num_passes)
self.assertTrue(np.equal(image_reader.shape(0, 1), shape).all())
data = image_reader.data(0, 1) image_record = image_reader.record(0, 1)
self.assertTrue(np.equal(image_record.shape(), shape).all())
data = image_record.data()
self.assertEqual(len(data), np.prod(shape)) self.assertEqual(len(data), np.prod(shape))
image_tags = self.reader.tags("image") image_tags = self.reader.tags("image")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册