提交 41c22081 编写于 作者: S superjom

add image py test

上级 194e3792
...@@ -43,7 +43,7 @@ add_executable(vl_test ...@@ -43,7 +43,7 @@ add_executable(vl_test
${PROJECT_SOURCE_DIR}/visualdl/utils/concurrency.h ${PROJECT_SOURCE_DIR}/visualdl/utils/concurrency.h
${PROJECT_SOURCE_DIR}/visualdl/utils/filesystem.h ${PROJECT_SOURCE_DIR}/visualdl/utils/filesystem.h
) )
target_link_libraries(vl_test sdk storage entry im gtest glog protobuf gflags pthread) target_link_libraries(vl_test sdk storage entry tablet im gtest glog protobuf gflags pthread)
enable_testing () enable_testing ()
......
#add_library(sdk ${PROJECT_SOURCE_DIR}/visualdl/logic/sdk.cc)
add_library(im ${PROJECT_SOURCE_DIR}/visualdl/logic/im.cc) add_library(im ${PROJECT_SOURCE_DIR}/visualdl/logic/im.cc)
add_library(sdk ${PROJECT_SOURCE_DIR}/visualdl/logic/sdk.cc) add_library(sdk ${PROJECT_SOURCE_DIR}/visualdl/logic/sdk.cc)
add_dependencies(im storage_proto) add_dependencies(im storage_proto)
...@@ -6,6 +5,6 @@ add_dependencies(sdk entry storage storage_proto) ...@@ -6,6 +5,6 @@ add_dependencies(sdk entry storage storage_proto)
## pybind ## pybind
add_library(core SHARED ${PROJECT_SOURCE_DIR}/visualdl/logic/pybind.cc) add_library(core SHARED ${PROJECT_SOURCE_DIR}/visualdl/logic/pybind.cc)
add_dependencies(core pybind python im entry storage sdk protobuf glog) add_dependencies(core pybind python im entry tablet storage sdk protobuf glog)
target_link_libraries(core PRIVATE pybind entry python im storage sdk protobuf glog) target_link_libraries(core PRIVATE pybind entry python im tablet storage sdk protobuf glog)
set_target_properties(core PROPERTIES PREFIX "" SUFFIX ".so") set_target_properties(core PROPERTIES PREFIX "" SUFFIX ".so")
...@@ -11,28 +11,7 @@ namespace cp = visualdl::components; ...@@ -11,28 +11,7 @@ namespace cp = visualdl::components;
PYBIND11_PLUGIN(core) { PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of VisualDL"); py::module m("core", "C++ core of VisualDL");
#define ADD_SCALAR(T) \ #define READER_ADD_SCALAR(T) \
py::class_<cp::ScalarReader<T>>(m, "ScalarReader__" #T) \
.def("records", &cp::ScalarReader<T>::records) \
.def("timestamps", &cp::ScalarReader<T>::timestamps) \
.def("ids", &cp::ScalarReader<T>::ids) \
.def("caption", &cp::ScalarReader<T>::caption);
ADD_SCALAR(int);
ADD_SCALAR(float);
ADD_SCALAR(double);
ADD_SCALAR(int64_t);
#undef ADD_SCALAR
#define ADD_SCALAR_WRITER(T) \
py::class_<cp::Scalar<T>>(m, "ScalarWriter__" #T) \
.def("set_caption", &cp::Scalar<T>::SetCaption) \
.def("add_record", &cp::Scalar<T>::AddRecord);
ADD_SCALAR_WRITER(int);
ADD_SCALAR_WRITER(float);
ADD_SCALAR_WRITER(double);
#undef ADD_SCALAR_WRITER
#define ADD_SCALAR(T) \
.def("get_scalar_" #T, [](vs::Reader& self, const std::string& tag) { \ .def("get_scalar_" #T, [](vs::Reader& self, const std::string& tag) { \
auto tablet = self.tablet(tag); \ auto tablet = self.tablet(tag); \
return vs::components::ScalarReader<T>(std::move(tablet)); \ return vs::components::ScalarReader<T>(std::move(tablet)); \
...@@ -46,13 +25,17 @@ PYBIND11_PLUGIN(core) { ...@@ -46,13 +25,17 @@ PYBIND11_PLUGIN(core) {
.def("modes", [](vs::Reader& self) { return self.storage().modes(); }) .def("modes", [](vs::Reader& self) { return self.storage().modes(); })
.def("tags", &vs::Reader::tags) .def("tags", &vs::Reader::tags)
// clang-format off // clang-format off
ADD_SCALAR(float) READER_ADD_SCALAR(float)
ADD_SCALAR(double) READER_ADD_SCALAR(double)
ADD_SCALAR(int); READER_ADD_SCALAR(int)
// clang-format on // clang-format on
#undef ADD_SCALAR .def("get_image", [](vs::Reader& self, const std::string& tag) {
auto tablet = self.tablet(tag);
return vs::components::ImageReader(self.mode(), tablet);
});
#undef READER_ADD_SCALAR
#define ADD_SCALAR(T) \ #define WRITER_ADD_SCALAR(T) \
.def("new_scalar_" #T, [](vs::Writer& self, const std::string& tag) { \ .def("new_scalar_" #T, [](vs::Writer& self, const std::string& tag) { \
auto tablet = self.AddTablet(tag); \ auto tablet = self.AddTablet(tag); \
return cp::Scalar<T>(tablet); \ return cp::Scalar<T>(tablet); \
...@@ -65,10 +48,50 @@ PYBIND11_PLUGIN(core) { ...@@ -65,10 +48,50 @@ PYBIND11_PLUGIN(core) {
}) })
.def("as_mode", &vs::Writer::AsMode) .def("as_mode", &vs::Writer::AsMode)
// clang-format off // clang-format off
ADD_SCALAR(float) WRITER_ADD_SCALAR(float)
ADD_SCALAR(double) WRITER_ADD_SCALAR(double)
ADD_SCALAR(int); WRITER_ADD_SCALAR(int)
// clang-format on // clang-format on
#undef ADD_SCALAR .def("new_image",
[](vs::Writer& self, const std::string& tag, int num_samples) {
auto tablet = self.AddTablet(tag);
return vs::components::Image(tablet, num_samples);
});
//------------------- components --------------------
#define ADD_SCALAR_READER(T) \
py::class_<cp::ScalarReader<T>>(m, "ScalarReader__" #T) \
.def("records", &cp::ScalarReader<T>::records) \
.def("timestamps", &cp::ScalarReader<T>::timestamps) \
.def("ids", &cp::ScalarReader<T>::ids) \
.def("caption", &cp::ScalarReader<T>::caption);
ADD_SCALAR_READER(int);
ADD_SCALAR_READER(float);
ADD_SCALAR_READER(double);
ADD_SCALAR_READER(int64_t);
#undef ADD_SCALAR_READER
#define ADD_SCALAR_WRITER(T) \
py::class_<cp::Scalar<T>>(m, "ScalarWriter__" #T) \
.def("set_caption", &cp::Scalar<T>::SetCaption) \
.def("add_record", &cp::Scalar<T>::AddRecord);
ADD_SCALAR_WRITER(int);
ADD_SCALAR_WRITER(float);
ADD_SCALAR_WRITER(double);
#undef ADD_SCALAR_WRITER
// clang-format on
py::class_<cp::Image>(m, "ImageWriter")
.def("set_caption", &cp::Image::SetCaption)
.def("start_sampling", &cp::Image::StartSampling)
.def("is_sample_taken", &cp::Image::IsSampleTaken)
.def("finish_sampling", &cp::Image::FinishSampling)
.def("set_sample", &cp::Image::SetSample);
py::class_<cp::ImageReader>(m, "ImageReader")
.def("caption", &cp::ImageReader::caption)
.def("num_records", &cp::ImageReader::num_records)
.def("data", &cp::ImageReader::data)
.def("shape", &cp::ImageReader::shape);
} // end pybind } // end pybind
...@@ -32,6 +32,7 @@ public: ...@@ -32,6 +32,7 @@ public:
string::TagEncode(tmp); string::TagEncode(tmp);
auto res = storage_.AddTablet(tmp); auto res = storage_.AddTablet(tmp);
res.SetCaptions(std::vector<std::string>({mode_})); res.SetCaptions(std::vector<std::string>({mode_}));
res.SetTag(mode_, tag);
return res; return res;
} }
...@@ -52,6 +53,8 @@ public: ...@@ -52,6 +53,8 @@ public:
return tmp; return tmp;
} }
const std::string& mode() { return mode_; }
TabletReader tablet(const std::string& tag) { TabletReader tablet(const std::string& tag) {
auto tmp = mode_ + "/" + tag; auto tmp = mode_ + "/" + tag;
string::TagEncode(tmp); string::TagEncode(tmp);
...@@ -62,7 +65,7 @@ public: ...@@ -62,7 +65,7 @@ public:
auto tags = reader_.all_tags(); auto tags = reader_.all_tags();
auto it = auto it =
std::remove_if(tags.begin(), tags.end(), [&](const std::string& tag) { std::remove_if(tags.begin(), tags.end(), [&](const std::string& tag) {
return !TagMatchMode(tag); return !TagMatchMode(tag, mode_);
}); });
tags.erase(it + 1); tags.erase(it + 1);
return tags; return tags;
...@@ -74,8 +77,8 @@ public: ...@@ -74,8 +77,8 @@ public:
CHECK(!tags.empty()); CHECK(!tags.empty());
std::vector<std::string> res; std::vector<std::string> res;
for (const auto& tag : tags) { for (const auto& tag : tags) {
if (TagMatchMode(tag)) { if (TagMatchMode(tag, mode_)) {
res.push_back(GenReadableTag(tag)); res.push_back(GenReadableTag(mode_, tag));
} }
} }
return res; return res;
...@@ -83,17 +86,19 @@ public: ...@@ -83,17 +86,19 @@ public:
StorageReader& storage() { return reader_; } StorageReader& storage() { return reader_; }
protected: static std::string GenReadableTag(const std::string& mode,
bool TagMatchMode(const std::string& tag) { const std::string& tag) {
if (tag.size() <= mode_.size()) return false;
return tag.substr(0, mode_.size()) == mode_;
}
std::string GenReadableTag(const std::string& tag) {
auto tmp = tag; auto tmp = tag;
string::TagDecode(tmp); string::TagDecode(tmp);
return tmp.substr(mode_.size() + 1); // including `/` return tmp.substr(mode.size() + 1); // including `/`
}
static bool TagMatchMode(const std::string& tag, const std::string& mode) {
if (tag.size() <= mode.size()) return false;
return tag.substr(0, mode.size()) == mode;
} }
protected:
private: private:
StorageReader reader_; StorageReader reader_;
std::string mode_{kDefaultMode}; std::string mode_{kDefaultMode};
...@@ -149,7 +154,9 @@ struct Image { ...@@ -149,7 +154,9 @@ struct Image {
Image(Tablet tablet, int num_samples) : writer_(tablet) { Image(Tablet tablet, int num_samples) : writer_(tablet) {
writer_.SetType(Tablet::Type::kImage); writer_.SetType(Tablet::Type::kImage);
// make image's tag as the default caption.
writer_.SetNumSamples(num_samples); writer_.SetNumSamples(num_samples);
SetCaption(tablet.reader().tag());
num_samples_ = num_samples; num_samples_ = num_samples;
} }
void SetCaption(const std::string& c) { void SetCaption(const std::string& c) {
...@@ -168,6 +175,9 @@ struct Image { ...@@ -168,6 +175,9 @@ struct Image {
*/ */
void FinishSampling(); void FinishSampling();
/*
* Just store a tensor with nothing to do with image format.
*/
void SetSample(int index, void SetSample(int index,
const std::vector<shape_t>& shape, const std::vector<shape_t>& shape,
const std::vector<value_t>& data); const std::vector<value_t>& data);
...@@ -186,11 +196,17 @@ struct ImageReader { ...@@ -186,11 +196,17 @@ struct ImageReader {
using value_t = typename Image::value_t; using value_t = typename Image::value_t;
using shape_t = typename Image::shape_t; using shape_t = typename Image::shape_t;
ImageReader(TabletReader tablet) : reader_(tablet) {} ImageReader(const std::string& mode, TabletReader tablet)
: reader_(tablet), mode_{mode} {}
std::string caption() { std::string caption() {
CHECK_EQ(reader_.captions().size(), 1); CHECK_EQ(reader_.captions().size(), 1);
return reader_.captions().front(); auto caption = reader_.captions().front();
if (Reader::TagMatchMode(caption, mode_)) {
return Reader::GenReadableTag(mode_, caption);
}
string::TagDecode(caption);
return caption;
} }
// number of steps. // number of steps.
...@@ -202,6 +218,7 @@ struct ImageReader { ...@@ -202,6 +218,7 @@ struct ImageReader {
private: private:
TabletReader reader_; TabletReader reader_;
std::string mode_;
}; };
} // namespace components } // namespace components
......
...@@ -74,7 +74,7 @@ TEST(Image, test) { ...@@ -74,7 +74,7 @@ TEST(Image, test) {
Reader reader__(dir); Reader reader__(dir);
auto reader = reader__.AsMode("train"); auto reader = reader__.AsMode("train");
auto tablet2read = reader.tablet("image0"); auto tablet2read = reader.tablet("image0");
components::ImageReader image2read(tablet2read); components::ImageReader image2read("train", tablet2read);
CHECK_EQ(image2read.caption(), "this is an image"); CHECK_EQ(image2read.caption(), "this is an image");
CHECK_EQ(image2read.num_records(), num_steps); CHECK_EQ(image2read.num_records(), num_steps);
} }
......
...@@ -31,6 +31,9 @@ class StorageReader(object): ...@@ -31,6 +31,9 @@ class StorageReader(object):
} }
return type2scalar[type](tag) return type2scalar[type](tag)
def image(self, tag):
return self.reader.get_image(tag)
class StorageWriter(object): class StorageWriter(object):
...@@ -50,3 +53,6 @@ class StorageWriter(object): ...@@ -50,3 +53,6 @@ class StorageWriter(object):
'int': self.writer.new_scalar_int, 'int': self.writer.new_scalar_int,
} }
return type2scalar[type](tag) return type2scalar[type](tag)
def image(self, tag, num_samples):
return self.writer.new_image(tag, num_samples)
...@@ -8,11 +8,11 @@ import time ...@@ -8,11 +8,11 @@ import time
class StorageTest(unittest.TestCase): class StorageTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = "./tmp/storage_test" self.dir = "./tmp/storage_test"
def test_read(self):
print 'test write'
self.writer = storage.StorageWriter( self.writer = storage.StorageWriter(
self.dir, sync_cycle=1).as_mode("train") self.dir, sync_cycle=1).as_mode("train")
def test_scalar(self):
print 'test write'
scalar = self.writer.scalar("model/scalar/min") scalar = self.writer.scalar("model/scalar/min")
# scalar.set_caption("model/scalar/min") # scalar.set_caption("model/scalar/min")
for i in range(10): for i in range(10):
...@@ -29,6 +29,27 @@ class StorageTest(unittest.TestCase): ...@@ -29,6 +29,27 @@ class StorageTest(unittest.TestCase):
print 'records', records print 'records', records
print 'ids', ids print 'ids', ids
def test_image(self):
tag = "layer1/layer2/image0"
image_writer = self.writer.image(tag, 10)
num_passes = 10
num_samples = 100
for pass_ in xrange(num_passes):
image_writer.start_sampling()
for ins in xrange(num_samples):
index = image_writer.is_sample_taken()
if index != -1:
shape = [3, 10, 10]
data = np.random.random(shape) * 256
data = np.ndarray.flatten(data)
image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling()
self.reader = storage.StorageReader(self.dir).as_mode("train")
image_reader = self.reader.image(tag)
self.assertEqual(image_reader.caption(), tag)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -12,4 +12,4 @@ add_library(storage storage.cc storage.h ${PROTO_SRCS} ${PROTO_HDRS}) ...@@ -12,4 +12,4 @@ add_library(storage storage.cc storage.h ${PROTO_SRCS} ${PROTO_HDRS})
add_dependencies(entry storage_proto im) add_dependencies(entry storage_proto im)
add_dependencies(record storage_proto entry) add_dependencies(record storage_proto entry)
add_dependencies(tablet storage_proto) add_dependencies(tablet storage_proto)
add_dependencies(storage storage_proto) add_dependencies(storage storage_proto record tablet entry)
...@@ -78,7 +78,7 @@ struct Storage { ...@@ -78,7 +78,7 @@ struct Storage {
* Save memory to disk. * Save memory to disk.
*/ */
void PersistToDisk(const std::string& dir) { void PersistToDisk(const std::string& dir) {
// LOG(INFO) << "persist to disk " << dir; LOG(INFO) << "persist to disk " << dir;
CHECK(!dir.empty()) << "dir should be set."; CHECK(!dir.empty()) << "dir should be set.";
fs::TryRecurMkdir(dir); fs::TryRecurMkdir(dir);
......
#include "visualdl/storage/tablet.h"
namespace visualdl {
TabletReader Tablet::reader() { return TabletReader(*data_); }
} // namespace visualdl
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
namespace visualdl { namespace visualdl {
struct TabletReader;
/* /*
* Tablet is a helper for operations on storage::Tablet. * Tablet is a helper for operations on storage::Tablet.
*/ */
...@@ -80,6 +82,8 @@ struct Tablet { ...@@ -80,6 +82,8 @@ struct Tablet {
WRITE_GUARD WRITE_GUARD
} }
TabletReader reader();
Storage* parent() const { return x_; } Storage* parent() const { return x_; }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册