未验证 提交 6f02954f 编写于 作者: Y Yan Chunwei 提交者: GitHub

simplify image sample usage (#139)

上级 617c2dec
...@@ -118,7 +118,8 @@ PYBIND11_PLUGIN(core) { ...@@ -118,7 +118,8 @@ PYBIND11_PLUGIN(core) {
.def("start_sampling", &cp::Image::StartSampling) .def("start_sampling", &cp::Image::StartSampling)
.def("is_sample_taken", &cp::Image::IsSampleTaken) .def("is_sample_taken", &cp::Image::IsSampleTaken)
.def("finish_sampling", &cp::Image::FinishSampling) .def("finish_sampling", &cp::Image::FinishSampling)
.def("set_sample", &cp::Image::SetSample); .def("set_sample", &cp::Image::SetSample)
.def("add_sample", &cp::Image::AddSample);
py::class_<cp::ImageReader::ImageRecord>(m, "ImageRecord") py::class_<cp::ImageReader::ImageRecord>(m, "ImageRecord")
// TODO(ChunweiYan) make these copyless. // TODO(ChunweiYan) make these copyless.
......
...@@ -171,6 +171,14 @@ struct is_same_type<T, T> { ...@@ -171,6 +171,14 @@ struct is_same_type<T, T> {
static const bool value = true; static const bool value = true;
}; };
void Image::AddSample(const std::vector<shape_t>& shape,
const std::vector<value_t>& data) {
auto idx = IsSampleTaken();
if (idx >= 0) {
SetSample(idx, shape, data);
}
}
void Image::SetSample(int index, void Image::SetSample(int index,
const std::vector<shape_t>& shape, const std::vector<shape_t>& shape,
const std::vector<value_t>& data) { const std::vector<value_t>& data) {
......
...@@ -146,18 +146,29 @@ struct Image { ...@@ -146,18 +146,29 @@ struct Image {
} }
/* /*
* Start a sampling period. * Start a sampling period, this interface will start a new reservior sampling
* phase.
*/ */
void StartSampling(); void StartSampling();
/* /*
* Will this sample be taken. * End a sampling period, it will clear all states for reservior sampling.
*/ */
int IsSampleTaken(); void FinishSampling();
/* /*
* End a sampling period. * A combined interface for IsSampleTaken and SetSample, simpler but might be
* low effience.
*/ */
void FinishSampling(); void AddSample(const std::vector<shape_t>& shape,
const std::vector<value_t>& data);
/*
* Will this sample be taken, this interface is introduced to reduce the cost
* of copy image data, by testing whether this image will be sampled, and only
* copy data when it should be sampled. In that way, most of unsampled image
* data need not be copied or processed at all.
*/
int IsSampleTaken();
/* /*
* Just store a tensor with nothing to do with image format. * Just store a tensor with nothing to do with image format.
*/ */
......
...@@ -81,6 +81,40 @@ TEST(Image, test) { ...@@ -81,6 +81,40 @@ TEST(Image, test) {
CHECK_EQ(image2read.num_records(), num_steps); CHECK_EQ(image2read.num_records(), num_steps);
} }
TEST(Image, add_sample_test) {
const auto dir = "./tmp/sdk_test.image";
LogWriter writer__(dir, 4);
auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3, 1);
const int num_steps = 10;
LOG(INFO) << "write images";
image.SetCaption("this is an image");
for (int step = 0; step < num_steps; step++) {
image.StartSampling();
for (int i = 0; i < 7; i++) {
vector<int64_t> shape({5, 5, 3});
vector<float> data;
for (int j = 0; j < 3 * 5 * 5; j++) {
data.push_back(float(rand()) / RAND_MAX);
}
image.AddSample(shape, data);
}
image.FinishSampling();
}
LOG(INFO) << "read images";
// read it
LogReader reader__(dir);
auto reader = reader__.AsMode("train");
auto tablet2read = reader.tablet("image0");
components::ImageReader image2read("train", tablet2read);
CHECK_EQ(image2read.caption(), "this is an image");
CHECK_EQ(image2read.num_records(), num_steps);
}
TEST(Histogram, AddRecord) { TEST(Histogram, AddRecord) {
const auto dir = "./tmp/sdk_test.histogram"; const auto dir = "./tmp/sdk_test.histogram";
LogWriter writer__(dir, 1); LogWriter writer__(dir, 1);
......
...@@ -13,8 +13,7 @@ from visualdl import LogWriter, LogReader ...@@ -13,8 +13,7 @@ from visualdl import LogWriter, LogReader
class StorageTest(unittest.TestCase): class StorageTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = "./tmp/storage_test" self.dir = "./tmp/storage_test"
self.writer = LogWriter( self.writer = LogWriter(self.dir, sync_cycle=1).as_mode("train")
self.dir, sync_cycle=1).as_mode("train")
def test_scalar(self): def test_scalar(self):
print 'test write' print 'test write'
...@@ -30,7 +29,8 @@ class StorageTest(unittest.TestCase): ...@@ -30,7 +29,8 @@ class StorageTest(unittest.TestCase):
self.assertEqual(scalar.caption(), "train") self.assertEqual(scalar.caption(), "train")
records = scalar.records() records = scalar.records()
ids = scalar.ids() ids = scalar.ids()
self.assertTrue(np.equal(records, [float(i) for i in range(10)]).all()) self.assertTrue(
np.equal(records, [float(i) for i in range(10)]).all())
self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all()) self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all())
print 'records', records print 'records', records
print 'ids', ids print 'ids', ids
...@@ -45,11 +45,9 @@ class StorageTest(unittest.TestCase): ...@@ -45,11 +45,9 @@ class StorageTest(unittest.TestCase):
for pass_ in xrange(num_passes): for pass_ in xrange(num_passes):
image_writer.start_sampling() image_writer.start_sampling()
for ins in xrange(num_samples): for ins in xrange(num_samples):
index = image_writer.is_sample_taken() data = np.random.random(shape) * 256
if index != -1: data = np.ndarray.flatten(data)
data = np.random.random(shape) * 256 image_writer.add_sample(shape, list(data))
data = np.ndarray.flatten(data)
image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling() image_writer.finish_sampling()
self.reader = LogReader(self.dir) self.reader = LogReader(self.dir)
...@@ -88,7 +86,6 @@ class StorageTest(unittest.TestCase): ...@@ -88,7 +86,6 @@ class StorageTest(unittest.TestCase):
image_writer.finish_sampling() image_writer.finish_sampling()
# read and check whether the original image will be displayed # read and check whether the original image will be displayed
image_reader = reader.image(tag) image_reader = reader.image(tag)
image_record = image_reader.record(0, 0) image_record = image_reader.record(0, 0)
data = image_record.data() data = image_record.data()
...@@ -102,10 +99,6 @@ class StorageTest(unittest.TestCase): ...@@ -102,10 +99,6 @@ class StorageTest(unittest.TestCase):
# manully check the image and found that nothing wrong with the image storage. # manully check the image and found that nothing wrong with the image storage.
# image.show() # image.show()
# after scale, elements are changed.
# self.assertTrue(
# np.equal(origin_data.reshape(PIL_image_shape), data).all())
def test_with_syntax(self): def test_with_syntax(self):
with self.writer.mode("train") as writer: with self.writer.mode("train") as writer:
scalar = writer.scalar("model/scalar/average") scalar = writer.scalar("model/scalar/average")
...@@ -118,8 +111,7 @@ class StorageTest(unittest.TestCase): ...@@ -118,8 +111,7 @@ class StorageTest(unittest.TestCase):
self.assertEqual(scalar.caption(), "train") self.assertEqual(scalar.caption(), "train")
def test_modes(self): def test_modes(self):
store = LogWriter( store = LogWriter(self.dir, sync_cycle=1)
self.dir, sync_cycle=1)
scalars = [] scalars = []
...@@ -133,6 +125,5 @@ class StorageTest(unittest.TestCase): ...@@ -133,6 +125,5 @@ class StorageTest(unittest.TestCase):
scalar.add_record(i, float(i)) scalar.add_record(i, float(i))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -26,13 +26,9 @@ def add_image(writer, ...@@ -26,13 +26,9 @@ def add_image(writer,
for pass_ in xrange(num_passes): for pass_ in xrange(num_passes):
image_writer.start_sampling() image_writer.start_sampling()
for ins in xrange(2 * num_samples): for ins in xrange(2 * num_samples):
index = image_writer.is_sample_taken() data = np.random.random(shape) * 256
if index != -1: data = np.ndarray.flatten(data)
data = np.random.random(shape) * 256 image_writer.add_sample(shape, list(data))
data = np.ndarray.flatten(data)
assert shape
assert len(data) > 0
image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling() image_writer.finish_sampling()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册