From 151afc74548dbee1e5537dbbb768aba58a9e6f95 Mon Sep 17 00:00:00 2001 From: superjom Date: Tue, 21 Nov 2017 17:34:33 +0800 Subject: [PATCH] add scalar interface support --- visualdl/backend/logic/im.cc | 25 ++++- visualdl/backend/logic/im.h | 8 +- visualdl/backend/logic/pybind.cc | 51 ++++++++--- visualdl/backend/logic/sdk.cc | 122 +++++++++++++++++++++++-- visualdl/backend/logic/sdk.h | 90 ++++++++++++++++-- visualdl/backend/logic/sdk.hpp | 33 ------- visualdl/backend/storage/storage.proto | 39 +++++++- visualdl/backend/test.py | 36 +++++++- 8 files changed, 326 insertions(+), 78 deletions(-) delete mode 100644 visualdl/backend/logic/sdk.hpp diff --git a/visualdl/backend/logic/im.cc b/visualdl/backend/logic/im.cc index 241ff469..e53f3a24 100644 --- a/visualdl/backend/logic/im.cc +++ b/visualdl/backend/logic/im.cc @@ -47,9 +47,16 @@ void InformationMaintainer::AddRecord(const std::string &tag, auto num_records = tablet->num_records(); const auto num_samples = tablet->num_samples(); - const auto offset = ReserviorSample(num_samples, num_records + 1); - if (offset < 0) - return; + + int offset; + // use reservoir sampling or not + if (num_samples > 0) { + offset = ReserviorSample(num_samples, num_records + 1); + if (offset < 0) + return; + } else { + offset = num_records; + } storage::Record *record; if (offset >= num_records) { @@ -59,13 +66,21 @@ void InformationMaintainer::AddRecord(const std::string &tag, } *record = data; - tablet->set_num_records(num_records + 1); } +void InformationMaintainer::Clear() { + auto* data = storage().mutable_data(); + data->clear_tablets(); + data->clear_dir(); + data->clear_timestamp(); +} + void InformationMaintainer::PersistToDisk() { CHECK(!storage_.data().dir().empty()) << "path of storage should be set"; - storage_.Save(storage_.data().dir()); + // TODO make dir first + //MakeDir(storage_.data().dir()); + storage_.Save(storage_.data().dir() + "/storage.pb"); } } // namespace visualdl diff --git a/visualdl/backend/logic/im.h b/visualdl/backend/logic/im.h index b94102db..dfb73a83 100644 --- a/visualdl/backend/logic/im.h +++ b/visualdl/backend/logic/im.h @@ -29,14 +29,18 @@ public: /* * @tag: tag of the target Tablet. - * @type: type of target Tablet. - * @data: storage Record. + * @record: a record * * NOTE pass in the serialized protobuf message will trigger copying, but * simpler to support different Tablet data formats. */ void AddRecord(const std::string &tag, const storage::Record &record); + /* + * delete all the information. + */ + void Clear(); + /* * Save the Storage Protobuf to disk. */ diff --git a/visualdl/backend/logic/pybind.cc b/visualdl/backend/logic/pybind.cc index 242c25d3..2256542d 100644 --- a/visualdl/backend/logic/pybind.cc +++ b/visualdl/backend/logic/pybind.cc @@ -1,5 +1,6 @@ #include #include +#include #include "visualdl/backend/logic/sdk.h" @@ -10,19 +11,6 @@ PYBIND11_MODULE(core, m) { m.doc() = "visualdl python core API"; py::class_(m, "Tablet") - // interfaces for components - .def("add_scalar_int32", - &vs::TabletHelper::AddScalarRecord, - "add a scalar int32 record", - pybind11::arg("id"), - pybind11::arg("value")) - .def("add_scalar_int64", - &vs::TabletHelper::AddScalarRecord, - "add a scalar int64 record", - pybind11::arg("id"), - pybind11::arg("value")) - .def("add_scalar_float", &vs::TabletHelper::AddScalarRecord) - .def("add_scalar_double", &vs::TabletHelper::AddScalarRecord) // other member setter and getter .def("record_buffer", &vs::TabletHelper::record_buffer) .def("records_size", &vs::TabletHelper::records_size) @@ -30,7 +18,23 @@ PYBIND11_MODULE(core, m) { .def("human_readable_buffer", &vs::TabletHelper::human_readable_buffer) .def("set_buffer", (void (vs::TabletHelper::*)(const std::string&)) & - vs::TabletHelper::SetBuffer); + vs::TabletHelper::SetBuffer) + // scalar interface + .def("as_int32_scalar", + [](const vs::TabletHelper& self) { + return vs::components::ScalarHelper(&self.data()); + }) + .def("as_int64_scalar", + [](const vs::TabletHelper& self) { + return vs::components::ScalarHelper(&self.data()); + }) + .def("as_float_scalar", + [](const vs::TabletHelper& self) { + return vs::components::ScalarHelper(&self.data()); + }) + .def("as_double_scalar", [](const vs::TabletHelper& self) { + return vs::components::ScalarHelper(&self.data()); + }); py::class_(m, "Storage") .def("timestamp", &vs::StorageHelper::timestamp) @@ -46,7 +50,24 @@ PYBIND11_MODULE(core, m) { py::class_(m, "Im") .def("storage", &vs::ImHelper::storage) .def("tablet", &vs::ImHelper::tablet) - .def("add_tablet", &vs::ImHelper::AddTablet); + .def("add_tablet", &vs::ImHelper::AddTablet) + .def("persist_to_disk", &vs::ImHelper::PersistToDisk) + .def("clear_tablets", &vs::ImHelper::ClearTablets); m.def("im", &vs::get_im, "global information-maintainer object."); + +// interfaces for components +#define ADD_SCALAR_TYPED_INTERFACE(T, name__) \ + py::class_>(m, #name__) \ + .def("add_record", &vs::components::ScalarHelper::AddRecord) \ + .def("set_captions", &vs::components::ScalarHelper::SetCaptions) \ + .def("get_records", &vs::components::ScalarHelper::GetRecords) \ + .def("get_captions", &vs::components::ScalarHelper::GetCaptions) \ + .def("get_ids", &vs::components::ScalarHelper::GetIds) \ + .def("get_timestamps", &vs::components::ScalarHelper::GetTimestamps); + ADD_SCALAR_TYPED_INTERFACE(int32_t, ScalarInt32); + ADD_SCALAR_TYPED_INTERFACE(int64_t, ScalarInt64); + ADD_SCALAR_TYPED_INTERFACE(float, ScalarFloat); + ADD_SCALAR_TYPED_INTERFACE(double, ScalarDouble); +#undef ADD_SCALAR_TYPED_INTERFACE } diff --git a/visualdl/backend/logic/sdk.cc b/visualdl/backend/logic/sdk.cc index 10406da6..a2c7d638 100644 --- a/visualdl/backend/logic/sdk.cc +++ b/visualdl/backend/logic/sdk.cc @@ -1,12 +1,13 @@ -#include #include "visualdl/backend/logic/sdk.h" +#include namespace visualdl { -#define IMPL_ENTRY_SET_OR_ADD(method__, ctype__, dtype__, opr__) \ - template <> void Entry::method__(ctype__ v) { \ - entry->set_dtype(storage::DataType::dtype__); \ - entry->opr__(v); \ +#define IMPL_ENTRY_SET_OR_ADD(method__, ctype__, dtype__, opr__) \ + template <> \ + void EntryHelper::method__(ctype__ v) { \ + entry->set_dtype(storage::DataType::dtype__); \ + entry->opr__(v); \ } IMPL_ENTRY_SET_OR_ADD(Set, int32_t, kInt32, set_i32); @@ -16,9 +17,36 @@ IMPL_ENTRY_SET_OR_ADD(Set, float, kFloat, set_f); IMPL_ENTRY_SET_OR_ADD(Set, double, kDouble, set_d); IMPL_ENTRY_SET_OR_ADD(Add, int32_t, kInt32s, add_i32s); IMPL_ENTRY_SET_OR_ADD(Add, int64_t, kInt64s, add_i64s); -IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs); IMPL_ENTRY_SET_OR_ADD(Add, float, kFloats, add_fs); IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds); +IMPL_ENTRY_SET_OR_ADD(Add, std::string, kStrings, add_ss); +IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs); + +#define IMPL_ENTRY_GET(T, fieldname__) \ + template <> \ + T EntryHelper::Get() const { \ + return entry->fieldname__(); \ + } +IMPL_ENTRY_GET(int32_t, i32); +IMPL_ENTRY_GET(int64_t, i64); +IMPL_ENTRY_GET(float, f); +IMPL_ENTRY_GET(double, d); +IMPL_ENTRY_GET(std::string, s); +IMPL_ENTRY_GET(bool, b); + +#define IMPL_ENTRY_GET_MULTI(T, fieldname__) \ + template <> \ + std::vector EntryHelper::GetMulti() const { \ + return std::vector(entry->fieldname__().begin(), \ + entry->fieldname__().end()); \ + } + +IMPL_ENTRY_GET_MULTI(int32_t, i32s); +IMPL_ENTRY_GET_MULTI(int64_t, i64s); +IMPL_ENTRY_GET_MULTI(float, fs); +IMPL_ENTRY_GET_MULTI(double, ds); +IMPL_ENTRY_GET_MULTI(std::string, ss); +IMPL_ENTRY_GET_MULTI(bool, bs); std::string StorageHelper::human_readable_buffer() const { std::string buffer; @@ -32,4 +60,84 @@ std::string TabletHelper::human_readable_buffer() const { return buffer; } -} // namespace visualdl +void ImHelper::PersistToDisk() const { + InformationMaintainer::Global().PersistToDisk(); +} + +// implementations for components +namespace components { + +template +void ScalarHelper::SetCaptions(const std::vector &captions) { + for (int i = 0; i < captions.size(); i++) { + data_->add_captions(captions[i]); + } +} + +template +void ScalarHelper::AddRecord(int id, const std::vector &values) { + CHECK_NOTNULL(data_); + CHECK_GT(data_->captions_size(), 0UL) << "captions should be set first"; + CHECK_EQ(data_->captions_size(), values.size()) + << "number of values in a record should be compatible with the " + "captions"; + // add record data + auto *record = data_->add_records(); + auto *data = record->add_data(); + EntryHelper entry_helper(data); + for (auto v : values) { + entry_helper.Add(v); + } + // set record id + record->set_id(id); + // set record timestamp + record->set_timestamp(time(NULL)); +} + +template +std::vector> ScalarHelper::GetRecords() const { + std::vector> result; + EntryHelper entry_helper; + for (int i = 0; i < data_->records_size(); i++) { + auto *entry = data_->mutable_records(i)->mutable_data(0); + entry_helper(entry); + auto datas = entry_helper.GetMulti(); + result.push_back(std::move(datas)); + } + return result; +} + +template +std::vector ScalarHelper::GetIds() const { + CHECK_NOTNULL(data_); + std::vector result; + for (int i = 0; i < data_->records_size(); i++) { + result.push_back(data_->records(i).id()); + } + return result; +} + +template +std::vector ScalarHelper::GetTimestamps() const { + CHECK_NOTNULL(data_); + std::vector result; + for (int i = 0; i < data_->records_size(); i++) { + result.push_back(data_->records(i).timestamp()); + } + return result; +} + +template +std::vector ScalarHelper::GetCaptions() const { + return std::vector(data_->captions().begin(), + data_->captions().end()); +} + +template class ScalarHelper; +template class ScalarHelper; +template class ScalarHelper; +template class ScalarHelper; + +} // namespace components + +} // namespace visualdl diff --git a/visualdl/backend/logic/sdk.h b/visualdl/backend/logic/sdk.h index c99958a7..8392d596 100644 --- a/visualdl/backend/logic/sdk.h +++ b/visualdl/backend/logic/sdk.h @@ -1,29 +1,67 @@ #ifndef VISUALDL_BACKEND_LOGIC_SDK_H #define VISUALDL_BACKEND_LOGIC_SDK_H -#include "visualdl/backend/logic/im.h" +#include +#include #include +#include "visualdl/backend/logic/im.h" + namespace visualdl { +/* + * Utility helper for storage::Entry. + */ +template +struct EntryHelper { + // use pointer to avoid copy + storage::Entry *entry{nullptr}; + + EntryHelper() {} + explicit EntryHelper(storage::Entry *entry) : entry(entry) {} + void operator()(storage::Entry *entry) { this->entry = entry; } + + /* + * Set a single value. + */ + void Set(T v); + + /* + * Add a value to repeated message field. + */ + void Add(T v); + + /* + * Get a single value. + */ + T Get() const; + + /* + * Get repeated field. + */ + std::vector GetMulti() const; +}; + class TabletHelper { public: - // method for each components - template - void AddScalarRecord(int id, T value); - // basic member getter and setter - std::string record_buffer(int idx) const { return data_->records(idx).SerializeAsString(); } + std::string record_buffer(int idx) const { + return data_->records(idx).SerializeAsString(); + } size_t records_size() const { return data_->records_size(); } std::string buffer() const { return data_->SerializeAsString(); } std::string human_readable_buffer() const; void SetBuffer(const storage::Tablet &t) { *data_ = t; } void SetBuffer(const std::string &b) { data_->ParseFromString(b); } + storage::Tablet &data() const { return *data_; } // constructor that enable concurrency. TabletHelper(storage::Tablet *t) : data_(t) {} // data updater that resuage of one instance. - TabletHelper &operator()(storage::Tablet *t) { data_ = t; return *this; } + TabletHelper &operator()(storage::Tablet *t) { + data_ = t; + return *this; + } private: storage::Tablet *data_; @@ -66,14 +104,46 @@ public: return TabletHelper( InformationMaintainer::Global().AddTablet(tag, num_samples)); } + void ClearTablets() { + InformationMaintainer::Global().storage().mutable_data()->clear_tablets(); + } + + void PersistToDisk() const; }; -static ImHelper& get_im() { +namespace components { + +/* + * Read and write support for a Scalar component. + */ +template +class ScalarHelper { +public: + ScalarHelper(storage::Tablet *tablet) : data_(tablet) {} + + void SetCaptions(const std::vector &captions); + + void AddRecord(int id, const std::vector &values); + + std::vector> GetRecords() const; + + std::vector GetIds() const; + + std::vector GetTimestamps() const; + + std::vector GetCaptions() const; + +private: + storage::Tablet *data_; +}; + +} // namespace components + +static ImHelper &get_im() { static ImHelper im; return im; } -} // namespace visualdl +} // namespace visualdl -#include "visualdl/backend/logic/sdk.hpp" #endif // VISUALDL_BACKEND_LOGIC_SDK_H diff --git a/visualdl/backend/logic/sdk.hpp b/visualdl/backend/logic/sdk.hpp deleted file mode 100644 index 517abe3f..00000000 --- a/visualdl/backend/logic/sdk.hpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "visualdl/backend/logic/im.h" - -namespace visualdl { - -/* - * Utility helper for storage::Entry. - */ -template struct Entry { - // use pointer to avoid copy - storage::Entry *entry{nullptr}; - - Entry(storage::Entry *entry) : entry(entry) {} - - /* - * Set a single value. - */ - void Set(T v); - - /* - * Add a value to repeated message field. - */ - void Add(T v); -}; - -template -void TabletHelper::AddScalarRecord(int id, T value) { - auto* record = data_->add_records(); - record->set_id(id); - Entry entry_helper(record->mutable_data()); - entry_helper.Set(value); -} - -} // namespace visualdl diff --git a/visualdl/backend/storage/storage.proto b/visualdl/backend/storage/storage.proto index 2b5729f8..9c0f9d5c 100644 --- a/visualdl/backend/storage/storage.proto +++ b/visualdl/backend/storage/storage.proto @@ -37,11 +37,43 @@ message Entry { repeated float fs = 9; repeated double ds = 10; repeated int32 i32s = 11; - repeated bool bs = 12; + repeated string ss = 12; + repeated bool bs = 13; } +/* +The Record proto is designed to represent any data structure for any component, for +example, to store a record of Scalar component + +Record { + // training error is 0.1, testing error is 0.2 + data = [0.1, 0.2], + timestamp = xxxx, + // step id + id = xxxx +} + +to store a record of Image component + +Record { + // RBG pixel weights of a image + data = [[0.1, 0.2, ...], [...], [...]], + timestamp = xxxx, + // image shape + meta = [100, 200] +} + +for other complex structure which have more fields than `timestamp`, `id` and meta, the additional fields +can store in the `data` field, for it can store a list of values in different basic data types. + +A component handlers in logic layer should know how to write or load a record for the corresponding component. +*/ message Record { - Entry data = 1; + // one record might have multiple fields, one specific component should know how + // to parse the records. + repeated Entry data = 1; + // NOTE the timestamp, id, dtype might be useless for that all the meta info can + // be stored in `data` field. int64 timestamp = 2; // store the count of writing operations to the tablet. int64 id = 3; @@ -76,6 +108,9 @@ message Tablet { Entry meta = 5; // the unique identification for this `Tablet`. string tag = 6; + // one tablet might have multiple captions, for example, a scalar component might have + // two plots labeled "train" and "test". + repeated string captions = 7; } /* diff --git a/visualdl/backend/test.py b/visualdl/backend/test.py index 8755ae7f..e7c1add3 100644 --- a/visualdl/backend/test.py +++ b/visualdl/backend/test.py @@ -1,5 +1,6 @@ import sys import unittest +import numpy as np sys.path.append('../../build') import core @@ -9,6 +10,7 @@ im = core.im() class StorageTester(unittest.TestCase): def setUp(self): + im.clear_tablets() self.storage = im.storage() def test_size(self): @@ -34,15 +36,41 @@ class StorageTester(unittest.TestCase): class TabletTester(unittest.TestCase): def setUp(self): + im.clear_tablets() self.tablet = im.add_tablet("tag101", 20) - def test_add_scalar(self): - self.tablet.add_scalar_float(1, 0.3) - self.assertEqual(self.tablet.records_size(), 1) - def test_human_readable_buffer(self): print self.tablet.human_readable_buffer() + def test_scalar(self): + scalar = self.tablet.as_float_scalar() + py_captions = ["train", "test"] + step_ids = [10, 20, 30] + py_records = [ + [0.1, 0.2], + [0.2, 0.3], + [0.3, 0.4] + ] + + scalar.set_captions(py_captions) + for i in range(len(py_records)): + scalar.add_record(step_ids[i], py_records[i]) + + records = scalar.get_records() + ids = scalar.get_ids() + for i in range(len(py_records)): + self.assertTrue(np.isclose(py_records[i], records[i]).all()) + self.assertEqual(step_ids[i], ids[i]) + + +class ImTester(unittest.TestCase): + def test_persist(self): + im.clear_tablets() + tablet = im.add_tablet("tab0", 111) + self.assertEqual(im.storage().tablets_size(), 1) + im.storage().set_dir("./1") + im.persist_to_disk() + if __name__ == '__main__': unittest.main() -- GitLab