提交 41c22081 编写于 作者: S superjom

add image py test

上级 194e3792
......@@ -43,7 +43,7 @@ add_executable(vl_test
${PROJECT_SOURCE_DIR}/visualdl/utils/concurrency.h
${PROJECT_SOURCE_DIR}/visualdl/utils/filesystem.h
)
target_link_libraries(vl_test sdk storage entry im gtest glog protobuf gflags pthread)
target_link_libraries(vl_test sdk storage entry tablet im gtest glog protobuf gflags pthread)
enable_testing ()
......
#add_library(sdk ${PROJECT_SOURCE_DIR}/visualdl/logic/sdk.cc)
add_library(im ${PROJECT_SOURCE_DIR}/visualdl/logic/im.cc)
add_library(sdk ${PROJECT_SOURCE_DIR}/visualdl/logic/sdk.cc)
add_dependencies(im storage_proto)
......@@ -6,6 +5,6 @@ add_dependencies(sdk entry storage storage_proto)
## pybind
add_library(core SHARED ${PROJECT_SOURCE_DIR}/visualdl/logic/pybind.cc)
add_dependencies(core pybind python im entry storage sdk protobuf glog)
target_link_libraries(core PRIVATE pybind entry python im storage sdk protobuf glog)
add_dependencies(core pybind python im entry tablet storage sdk protobuf glog)
target_link_libraries(core PRIVATE pybind entry python im tablet storage sdk protobuf glog)
set_target_properties(core PROPERTIES PREFIX "" SUFFIX ".so")
......@@ -11,28 +11,7 @@ namespace cp = visualdl::components;
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of VisualDL");
#define ADD_SCALAR(T) \
py::class_<cp::ScalarReader<T>>(m, "ScalarReader__" #T) \
.def("records", &cp::ScalarReader<T>::records) \
.def("timestamps", &cp::ScalarReader<T>::timestamps) \
.def("ids", &cp::ScalarReader<T>::ids) \
.def("caption", &cp::ScalarReader<T>::caption);
ADD_SCALAR(int);
ADD_SCALAR(float);
ADD_SCALAR(double);
ADD_SCALAR(int64_t);
#undef ADD_SCALAR
#define ADD_SCALAR_WRITER(T) \
py::class_<cp::Scalar<T>>(m, "ScalarWriter__" #T) \
.def("set_caption", &cp::Scalar<T>::SetCaption) \
.def("add_record", &cp::Scalar<T>::AddRecord);
ADD_SCALAR_WRITER(int);
ADD_SCALAR_WRITER(float);
ADD_SCALAR_WRITER(double);
#undef ADD_SCALAR_WRITER
#define ADD_SCALAR(T) \
#define READER_ADD_SCALAR(T) \
.def("get_scalar_" #T, [](vs::Reader& self, const std::string& tag) { \
auto tablet = self.tablet(tag); \
return vs::components::ScalarReader<T>(std::move(tablet)); \
......@@ -46,13 +25,17 @@ PYBIND11_PLUGIN(core) {
.def("modes", [](vs::Reader& self) { return self.storage().modes(); })
.def("tags", &vs::Reader::tags)
// clang-format off
ADD_SCALAR(float)
ADD_SCALAR(double)
ADD_SCALAR(int);
// clang-format on
#undef ADD_SCALAR
READER_ADD_SCALAR(float)
READER_ADD_SCALAR(double)
READER_ADD_SCALAR(int)
// clang-format on
.def("get_image", [](vs::Reader& self, const std::string& tag) {
auto tablet = self.tablet(tag);
return vs::components::ImageReader(self.mode(), tablet);
});
#undef READER_ADD_SCALAR
#define ADD_SCALAR(T) \
#define WRITER_ADD_SCALAR(T) \
.def("new_scalar_" #T, [](vs::Writer& self, const std::string& tag) { \
auto tablet = self.AddTablet(tag); \
return cp::Scalar<T>(tablet); \
......@@ -65,10 +48,50 @@ PYBIND11_PLUGIN(core) {
})
.def("as_mode", &vs::Writer::AsMode)
// clang-format off
ADD_SCALAR(float)
ADD_SCALAR(double)
ADD_SCALAR(int);
// clang-format on
#undef ADD_SCALAR
WRITER_ADD_SCALAR(float)
WRITER_ADD_SCALAR(double)
WRITER_ADD_SCALAR(int)
// clang-format on
.def("new_image",
[](vs::Writer& self, const std::string& tag, int num_samples) {
auto tablet = self.AddTablet(tag);
return vs::components::Image(tablet, num_samples);
});
//------------------- components --------------------
#define ADD_SCALAR_READER(T) \
py::class_<cp::ScalarReader<T>>(m, "ScalarReader__" #T) \
.def("records", &cp::ScalarReader<T>::records) \
.def("timestamps", &cp::ScalarReader<T>::timestamps) \
.def("ids", &cp::ScalarReader<T>::ids) \
.def("caption", &cp::ScalarReader<T>::caption);
ADD_SCALAR_READER(int);
ADD_SCALAR_READER(float);
ADD_SCALAR_READER(double);
ADD_SCALAR_READER(int64_t);
#undef ADD_SCALAR_READER
#define ADD_SCALAR_WRITER(T) \
py::class_<cp::Scalar<T>>(m, "ScalarWriter__" #T) \
.def("set_caption", &cp::Scalar<T>::SetCaption) \
.def("add_record", &cp::Scalar<T>::AddRecord);
ADD_SCALAR_WRITER(int);
ADD_SCALAR_WRITER(float);
ADD_SCALAR_WRITER(double);
#undef ADD_SCALAR_WRITER
// clang-format on
py::class_<cp::Image>(m, "ImageWriter")
.def("set_caption", &cp::Image::SetCaption)
.def("start_sampling", &cp::Image::StartSampling)
.def("is_sample_taken", &cp::Image::IsSampleTaken)
.def("finish_sampling", &cp::Image::FinishSampling)
.def("set_sample", &cp::Image::SetSample);
py::class_<cp::ImageReader>(m, "ImageReader")
.def("caption", &cp::ImageReader::caption)
.def("num_records", &cp::ImageReader::num_records)
.def("data", &cp::ImageReader::data)
.def("shape", &cp::ImageReader::shape);
} // end pybind
......@@ -32,6 +32,7 @@ public:
string::TagEncode(tmp);
auto res = storage_.AddTablet(tmp);
res.SetCaptions(std::vector<std::string>({mode_}));
res.SetTag(mode_, tag);
return res;
}
......@@ -52,6 +53,8 @@ public:
return tmp;
}
const std::string& mode() { return mode_; }
TabletReader tablet(const std::string& tag) {
auto tmp = mode_ + "/" + tag;
string::TagEncode(tmp);
......@@ -62,7 +65,7 @@ public:
auto tags = reader_.all_tags();
auto it =
std::remove_if(tags.begin(), tags.end(), [&](const std::string& tag) {
return !TagMatchMode(tag);
return !TagMatchMode(tag, mode_);
});
tags.erase(it + 1);
return tags;
......@@ -74,8 +77,8 @@ public:
CHECK(!tags.empty());
std::vector<std::string> res;
for (const auto& tag : tags) {
if (TagMatchMode(tag)) {
res.push_back(GenReadableTag(tag));
if (TagMatchMode(tag, mode_)) {
res.push_back(GenReadableTag(mode_, tag));
}
}
return res;
......@@ -83,17 +86,19 @@ public:
StorageReader& storage() { return reader_; }
protected:
bool TagMatchMode(const std::string& tag) {
if (tag.size() <= mode_.size()) return false;
return tag.substr(0, mode_.size()) == mode_;
}
std::string GenReadableTag(const std::string& tag) {
static std::string GenReadableTag(const std::string& mode,
const std::string& tag) {
auto tmp = tag;
string::TagDecode(tmp);
return tmp.substr(mode_.size() + 1); // including `/`
return tmp.substr(mode.size() + 1); // including `/`
}
static bool TagMatchMode(const std::string& tag, const std::string& mode) {
if (tag.size() <= mode.size()) return false;
return tag.substr(0, mode.size()) == mode;
}
protected:
private:
StorageReader reader_;
std::string mode_{kDefaultMode};
......@@ -149,7 +154,9 @@ struct Image {
Image(Tablet tablet, int num_samples) : writer_(tablet) {
writer_.SetType(Tablet::Type::kImage);
// make image's tag as the default caption.
writer_.SetNumSamples(num_samples);
SetCaption(tablet.reader().tag());
num_samples_ = num_samples;
}
void SetCaption(const std::string& c) {
......@@ -168,6 +175,9 @@ struct Image {
*/
void FinishSampling();
/*
* Just store a tensor with nothing to do with image format.
*/
void SetSample(int index,
const std::vector<shape_t>& shape,
const std::vector<value_t>& data);
......@@ -186,11 +196,17 @@ struct ImageReader {
using value_t = typename Image::value_t;
using shape_t = typename Image::shape_t;
ImageReader(TabletReader tablet) : reader_(tablet) {}
ImageReader(const std::string& mode, TabletReader tablet)
: reader_(tablet), mode_{mode} {}
std::string caption() {
CHECK_EQ(reader_.captions().size(), 1);
return reader_.captions().front();
auto caption = reader_.captions().front();
if (Reader::TagMatchMode(caption, mode_)) {
return Reader::GenReadableTag(mode_, caption);
}
string::TagDecode(caption);
return caption;
}
// number of steps.
......@@ -202,6 +218,7 @@ struct ImageReader {
private:
TabletReader reader_;
std::string mode_;
};
} // namespace components
......
......@@ -74,7 +74,7 @@ TEST(Image, test) {
Reader reader__(dir);
auto reader = reader__.AsMode("train");
auto tablet2read = reader.tablet("image0");
components::ImageReader image2read(tablet2read);
components::ImageReader image2read("train", tablet2read);
CHECK_EQ(image2read.caption(), "this is an image");
CHECK_EQ(image2read.num_records(), num_steps);
}
......
......@@ -31,6 +31,9 @@ class StorageReader(object):
}
return type2scalar[type](tag)
def image(self, tag):
return self.reader.get_image(tag)
class StorageWriter(object):
......@@ -50,3 +53,6 @@ class StorageWriter(object):
'int': self.writer.new_scalar_int,
}
return type2scalar[type](tag)
def image(self, tag, num_samples):
return self.writer.new_image(tag, num_samples)
......@@ -8,11 +8,11 @@ import time
class StorageTest(unittest.TestCase):
def setUp(self):
self.dir = "./tmp/storage_test"
def test_read(self):
print 'test write'
self.writer = storage.StorageWriter(
self.dir, sync_cycle=1).as_mode("train")
def test_scalar(self):
print 'test write'
scalar = self.writer.scalar("model/scalar/min")
# scalar.set_caption("model/scalar/min")
for i in range(10):
......@@ -29,6 +29,27 @@ class StorageTest(unittest.TestCase):
print 'records', records
print 'ids', ids
def test_image(self):
tag = "layer1/layer2/image0"
image_writer = self.writer.image(tag, 10)
num_passes = 10
num_samples = 100
for pass_ in xrange(num_passes):
image_writer.start_sampling()
for ins in xrange(num_samples):
index = image_writer.is_sample_taken()
if index != -1:
shape = [3, 10, 10]
data = np.random.random(shape) * 256
data = np.ndarray.flatten(data)
image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling()
self.reader = storage.StorageReader(self.dir).as_mode("train")
image_reader = self.reader.image(tag)
self.assertEqual(image_reader.caption(), tag)
if __name__ == '__main__':
unittest.main()
......@@ -12,4 +12,4 @@ add_library(storage storage.cc storage.h ${PROTO_SRCS} ${PROTO_HDRS})
add_dependencies(entry storage_proto im)
add_dependencies(record storage_proto entry)
add_dependencies(tablet storage_proto)
add_dependencies(storage storage_proto)
add_dependencies(storage storage_proto record tablet entry)
......@@ -78,7 +78,7 @@ struct Storage {
* Save memory to disk.
*/
void PersistToDisk(const std::string& dir) {
// LOG(INFO) << "persist to disk " << dir;
LOG(INFO) << "persist to disk " << dir;
CHECK(!dir.empty()) << "dir should be set.";
fs::TryRecurMkdir(dir);
......
#include "visualdl/storage/tablet.h"
namespace visualdl {
TabletReader Tablet::reader() { return TabletReader(*data_); }
} // namespace visualdl
......@@ -10,6 +10,8 @@
namespace visualdl {
struct TabletReader;
/*
* Tablet is a helper for operations on storage::Tablet.
*/
......@@ -80,6 +82,8 @@ struct Tablet {
WRITE_GUARD
}
TabletReader reader();
Storage* parent() const { return x_; }
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册