提交 23a4be30 编写于 作者: S superjom

add image consistent

上级 93abe951
......@@ -37,7 +37,7 @@ template class SimpleWriteSyncGuard<Entry<double>>;
template class SimpleWriteSyncGuard<Entry<bool>>;
template class SimpleWriteSyncGuard<Entry<long>>;
template class SimpleWriteSyncGuard<Entry<std::string>>;
template class SimpleWriteSyncGuard<Entry<std::vector<char>>>;
template class SimpleWriteSyncGuard<Entry<std::vector<byte_t>>>;
template class SimpleWriteSyncGuard<Entry<int>>;
} // namespace visualdl
......@@ -93,7 +93,10 @@ PYBIND11_PLUGIN(core) {
py::class_<cp::ImageReader::ImageRecord>(m, "ImageRecord")
// TODO(ChunweiYan) make these copyless.
.def("data", [](cp::ImageReader::ImageRecord& self) { return self.data; })
.def("data",
[](cp::ImageReader::ImageRecord& self) {
return self.data;
})
.def("shape",
[](cp::ImageReader::ImageRecord& self) { return self.shape; })
.def("step_id",
......
......@@ -113,13 +113,13 @@ void Image::SetSample(int index,
CHECK_LT(index, num_samples_);
CHECK_LE(index, num_records_);
auto entry = step_.MutableData<std::vector<char>>(index);
auto entry = step_.MutableData<std::vector<byte_t>>(index);
// trick to store int8 to protobuf
std::vector<char> data_str(data.size());
std::vector<byte_t> data_str(data.size());
for (int i = 0; i < data.size(); i++) {
data_str[i] = data[i];
}
entry.Set(data_str);
entry.SetRaw(std::string(data_str.begin(), data_str.end()));
static_assert(
!is_same_type<value_t, shape_t>::value,
......@@ -145,13 +145,13 @@ std::string ImageReader::caption() {
ImageReader::ImageRecord ImageReader::record(int offset, int index) {
ImageRecord res;
auto record = reader_.record(offset);
auto data_entry = record.data<std::vector<char>>(index);
auto data_entry = record.data<std::vector<byte_t>>(index);
auto shape_entry = record.data<shape_t>(index);
auto data_str = data_entry.Get();
auto data_str = data_entry.GetRaw();
std::transform(data_str.begin(),
data_str.end(),
std::back_inserter(res.data),
[](char i) { return (int)((unsigned char)i); });
[](byte_t i) { return (int)(i); });
res.shape = shape_entry.GetMulti();
res.step_id = record.id();
return res;
......
import random
import time
import unittest
from PIL import Image
import numpy as np
......@@ -62,6 +63,46 @@ class StorageTest(unittest.TestCase):
self.assertTrue(image_tags)
self.assertEqual(len(image_tags), 1)
def test_check_image(self):
'''
check whether the storage will keep image data consistent
'''
print 'check image'
tag = "layer1/check/image1"
image_writer = self.writer.image(tag, 10, 1)
image = Image.open("./dog.jpg")
shape = [image.size[1], image.size[0], 3]
origin_data = np.array(image.getdata()).flatten()
self.reader = storage.StorageReader(self.dir).as_mode("train")
image_writer.start_sampling()
index = image_writer.is_sample_taken()
image_writer.set_sample(index, shape, list(origin_data))
image_writer.finish_sampling()
# read and check whether the original image will be displayed
image_reader = self.reader.image(tag)
image_record = image_reader.record(0, 0)
data = image_record.data()
shape = image_record.shape()
PIL_image_shape = (shape[0]*shape[1], shape[2])
data = np.array(data, dtype='uint8').reshape(PIL_image_shape)
print 'origin', origin_data.flatten()
print 'data', data.flatten()
image = Image.fromarray(data.reshape(shape))
self.assertTrue(np.equal(origin_data.reshape(PIL_image_shape), data).all())
if __name__ == '__main__':
unittest.main()
......@@ -22,14 +22,14 @@ namespace visualdl {
}
template <>
void Entry<std::vector<char>>::Set(std::vector<char> v) {
void Entry<std::vector<byte_t>>::Set(std::vector<byte_t> v) {
entry->set_dtype(storage::DataType::kBytes);
entry->set_y(std::string(v.begin(), v.end()));
WRITE_GUARD
}
template <>
void Entry<std::vector<char>>::Add(std::vector<char> v) {
void Entry<std::vector<byte_t>>::Add(std::vector<byte_t> v) {
entry->set_dtype(storage::DataType::kBytess);
*entry->add_ys() = std::string(v.begin(), v.end());
WRITE_GUARD
......@@ -68,9 +68,9 @@ IMPL_ENTRY_GET(std::string, s);
IMPL_ENTRY_GET(bool, b);
template <>
std::vector<char> EntryReader<std::vector<char>>::Get() const {
std::vector<uint8_t> EntryReader<std::vector<byte_t>>::Get() const {
const auto& y = data_.y();
return std::vector<char>(y.begin(), y.end());
return std::vector<byte_t>(y.begin(), y.end());
}
#define IMPL_ENTRY_GET_MULTI(T, fieldname__) \
......@@ -91,12 +91,12 @@ template class Entry<int>;
template class Entry<float>;
template class Entry<double>;
template class Entry<bool>;
template class Entry<std::vector<char>>;
template class Entry<std::vector<byte_t>>;
template class EntryReader<int>;
template class EntryReader<float>;
template class EntryReader<double>;
template class EntryReader<bool>;
template class EntryReader<std::vector<char>>;
template class EntryReader<std::vector<byte_t>>;
} // namespace visualdl
......@@ -9,6 +9,8 @@ namespace visualdl {
struct Storage;
using byte_t = unsigned char;
/*
* Utility helper for storage::Entry.
*/
......@@ -30,6 +32,8 @@ struct Entry {
// Set a single value.
void Set(T v);
void SetRaw(const std::string& bytes) { entry->set_y(bytes); }
// Add a value to repeated message field.
void Add(T v);
......@@ -50,6 +54,8 @@ struct EntryReader {
// Get repeated field.
std::vector<T> GetMulti() const;
std::string GetRaw() { return data_.y(); }
private:
storage::Entry data_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册