未验证 提交 6f02954f 编写于 作者: Y Yan Chunwei 提交者: GitHub

simplify image sample usage (#139)

上级 617c2dec
......@@ -118,7 +118,8 @@ PYBIND11_PLUGIN(core) {
.def("start_sampling", &cp::Image::StartSampling)
.def("is_sample_taken", &cp::Image::IsSampleTaken)
.def("finish_sampling", &cp::Image::FinishSampling)
.def("set_sample", &cp::Image::SetSample);
.def("set_sample", &cp::Image::SetSample)
.def("add_sample", &cp::Image::AddSample);
py::class_<cp::ImageReader::ImageRecord>(m, "ImageRecord")
// TODO(ChunweiYan) make these copyless.
......
......@@ -171,6 +171,14 @@ struct is_same_type<T, T> {
static const bool value = true;
};
void Image::AddSample(const std::vector<shape_t>& shape,
const std::vector<value_t>& data) {
auto idx = IsSampleTaken();
if (idx >= 0) {
SetSample(idx, shape, data);
}
}
void Image::SetSample(int index,
const std::vector<shape_t>& shape,
const std::vector<value_t>& data) {
......
......@@ -146,18 +146,29 @@ struct Image {
}
/*
* Start a sampling period.
* Start a sampling period, this interface will start a new reservior sampling
* phase.
*/
void StartSampling();
/*
* Will this sample be taken.
* End a sampling period, it will clear all states for reservior sampling.
*/
int IsSampleTaken();
void FinishSampling();
/*
* End a sampling period.
* A combined interface for IsSampleTaken and SetSample, simpler but might be
* low effience.
*/
void FinishSampling();
void AddSample(const std::vector<shape_t>& shape,
const std::vector<value_t>& data);
/*
* Will this sample be taken, this interface is introduced to reduce the cost
* of copy image data, by testing whether this image will be sampled, and only
* copy data when it should be sampled. In that way, most of unsampled image
* data need not be copied or processed at all.
*/
int IsSampleTaken();
/*
* Just store a tensor with nothing to do with image format.
*/
......
......@@ -81,6 +81,40 @@ TEST(Image, test) {
CHECK_EQ(image2read.num_records(), num_steps);
}
TEST(Image, add_sample_test) {
const auto dir = "./tmp/sdk_test.image";
LogWriter writer__(dir, 4);
auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3, 1);
const int num_steps = 10;
LOG(INFO) << "write images";
image.SetCaption("this is an image");
for (int step = 0; step < num_steps; step++) {
image.StartSampling();
for (int i = 0; i < 7; i++) {
vector<int64_t> shape({5, 5, 3});
vector<float> data;
for (int j = 0; j < 3 * 5 * 5; j++) {
data.push_back(float(rand()) / RAND_MAX);
}
image.AddSample(shape, data);
}
image.FinishSampling();
}
LOG(INFO) << "read images";
// read it
LogReader reader__(dir);
auto reader = reader__.AsMode("train");
auto tablet2read = reader.tablet("image0");
components::ImageReader image2read("train", tablet2read);
CHECK_EQ(image2read.caption(), "this is an image");
CHECK_EQ(image2read.num_records(), num_steps);
}
TEST(Histogram, AddRecord) {
const auto dir = "./tmp/sdk_test.histogram";
LogWriter writer__(dir, 1);
......
......@@ -13,8 +13,7 @@ from visualdl import LogWriter, LogReader
class StorageTest(unittest.TestCase):
def setUp(self):
self.dir = "./tmp/storage_test"
self.writer = LogWriter(
self.dir, sync_cycle=1).as_mode("train")
self.writer = LogWriter(self.dir, sync_cycle=1).as_mode("train")
def test_scalar(self):
print 'test write'
......@@ -30,7 +29,8 @@ class StorageTest(unittest.TestCase):
self.assertEqual(scalar.caption(), "train")
records = scalar.records()
ids = scalar.ids()
self.assertTrue(np.equal(records, [float(i) for i in range(10)]).all())
self.assertTrue(
np.equal(records, [float(i) for i in range(10)]).all())
self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all())
print 'records', records
print 'ids', ids
......@@ -45,11 +45,9 @@ class StorageTest(unittest.TestCase):
for pass_ in xrange(num_passes):
image_writer.start_sampling()
for ins in xrange(num_samples):
index = image_writer.is_sample_taken()
if index != -1:
data = np.random.random(shape) * 256
data = np.ndarray.flatten(data)
image_writer.set_sample(index, shape, list(data))
data = np.random.random(shape) * 256
data = np.ndarray.flatten(data)
image_writer.add_sample(shape, list(data))
image_writer.finish_sampling()
self.reader = LogReader(self.dir)
......@@ -88,7 +86,6 @@ class StorageTest(unittest.TestCase):
image_writer.finish_sampling()
# read and check whether the original image will be displayed
image_reader = reader.image(tag)
image_record = image_reader.record(0, 0)
data = image_record.data()
......@@ -102,10 +99,6 @@ class StorageTest(unittest.TestCase):
# manully check the image and found that nothing wrong with the image storage.
# image.show()
# after scale, elements are changed.
# self.assertTrue(
# np.equal(origin_data.reshape(PIL_image_shape), data).all())
def test_with_syntax(self):
with self.writer.mode("train") as writer:
scalar = writer.scalar("model/scalar/average")
......@@ -118,8 +111,7 @@ class StorageTest(unittest.TestCase):
self.assertEqual(scalar.caption(), "train")
def test_modes(self):
store = LogWriter(
self.dir, sync_cycle=1)
store = LogWriter(self.dir, sync_cycle=1)
scalars = []
......@@ -133,6 +125,5 @@ class StorageTest(unittest.TestCase):
scalar.add_record(i, float(i))
if __name__ == '__main__':
unittest.main()
......@@ -26,13 +26,9 @@ def add_image(writer,
for pass_ in xrange(num_passes):
image_writer.start_sampling()
for ins in xrange(2 * num_samples):
index = image_writer.is_sample_taken()
if index != -1:
data = np.random.random(shape) * 256
data = np.ndarray.flatten(data)
assert shape
assert len(data) > 0
image_writer.set_sample(index, shape, list(data))
data = np.random.random(shape) * 256
data = np.ndarray.flatten(data)
image_writer.add_sample(shape, list(data))
image_writer.finish_sampling()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册