提交 151afc74 编写于 作者: S superjom

add scalar interface support

上级 8087f404
...@@ -47,9 +47,16 @@ void InformationMaintainer::AddRecord(const std::string &tag, ...@@ -47,9 +47,16 @@ void InformationMaintainer::AddRecord(const std::string &tag,
auto num_records = tablet->num_records(); auto num_records = tablet->num_records();
const auto num_samples = tablet->num_samples(); const auto num_samples = tablet->num_samples();
const auto offset = ReserviorSample(num_samples, num_records + 1);
if (offset < 0) int offset;
return; // use reservoir sampling or not
if (num_samples > 0) {
offset = ReserviorSample(num_samples, num_records + 1);
if (offset < 0)
return;
} else {
offset = num_records;
}
storage::Record *record; storage::Record *record;
if (offset >= num_records) { if (offset >= num_records) {
...@@ -59,13 +66,21 @@ void InformationMaintainer::AddRecord(const std::string &tag, ...@@ -59,13 +66,21 @@ void InformationMaintainer::AddRecord(const std::string &tag,
} }
*record = data; *record = data;
tablet->set_num_records(num_records + 1); tablet->set_num_records(num_records + 1);
} }
// Wipe the maintained storage message: clears the timestamp, the storage
// directory and the list of tablets.
void InformationMaintainer::Clear() {
  auto* proto = storage().mutable_data();
  proto->clear_timestamp();
  proto->clear_dir();
  proto->clear_tablets();
}
void InformationMaintainer::PersistToDisk() { void InformationMaintainer::PersistToDisk() {
CHECK(!storage_.data().dir().empty()) << "path of storage should be set"; CHECK(!storage_.data().dir().empty()) << "path of storage should be set";
storage_.Save(storage_.data().dir()); // TODO make dir first
//MakeDir(storage_.data().dir());
storage_.Save(storage_.data().dir() + "/storage.pb");
} }
} // namespace visualdl } // namespace visualdl
...@@ -29,14 +29,18 @@ public: ...@@ -29,14 +29,18 @@ public:
/* /*
* @tag: tag of the target Tablet. * @tag: tag of the target Tablet.
* @type: type of target Tablet. * @record: a record
* @data: storage Record.
* *
* NOTE pass in the serialized protobuf message will trigger copying, but * NOTE pass in the serialized protobuf message will trigger copying, but
* simpler to support different Tablet data formats. * simpler to support different Tablet data formats.
*/ */
void AddRecord(const std::string &tag, const storage::Record &record); void AddRecord(const std::string &tag, const storage::Record &record);
/*
* delete all the information.
*/
void Clear();
/* /*
* Save the Storage Protobuf to disk. * Save the Storage Protobuf to disk.
*/ */
......
#include <ctype.h> #include <ctype.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "visualdl/backend/logic/sdk.h" #include "visualdl/backend/logic/sdk.h"
...@@ -10,19 +11,6 @@ PYBIND11_MODULE(core, m) { ...@@ -10,19 +11,6 @@ PYBIND11_MODULE(core, m) {
m.doc() = "visualdl python core API"; m.doc() = "visualdl python core API";
py::class_<vs::TabletHelper>(m, "Tablet") py::class_<vs::TabletHelper>(m, "Tablet")
// interfaces for components
.def("add_scalar_int32",
&vs::TabletHelper::AddScalarRecord<int32_t>,
"add a scalar int32 record",
pybind11::arg("id"),
pybind11::arg("value"))
.def("add_scalar_int64",
&vs::TabletHelper::AddScalarRecord<int64_t>,
"add a scalar int64 record",
pybind11::arg("id"),
pybind11::arg("value"))
.def("add_scalar_float", &vs::TabletHelper::AddScalarRecord<float>)
.def("add_scalar_double", &vs::TabletHelper::AddScalarRecord<double>)
// other member setter and getter // other member setter and getter
.def("record_buffer", &vs::TabletHelper::record_buffer) .def("record_buffer", &vs::TabletHelper::record_buffer)
.def("records_size", &vs::TabletHelper::records_size) .def("records_size", &vs::TabletHelper::records_size)
...@@ -30,7 +18,23 @@ PYBIND11_MODULE(core, m) { ...@@ -30,7 +18,23 @@ PYBIND11_MODULE(core, m) {
.def("human_readable_buffer", &vs::TabletHelper::human_readable_buffer) .def("human_readable_buffer", &vs::TabletHelper::human_readable_buffer)
.def("set_buffer", .def("set_buffer",
(void (vs::TabletHelper::*)(const std::string&)) & (void (vs::TabletHelper::*)(const std::string&)) &
vs::TabletHelper::SetBuffer); vs::TabletHelper::SetBuffer)
// scalar interface
.def("as_int32_scalar",
[](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<int32_t>(&self.data());
})
.def("as_int64_scalar",
[](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<int64_t>(&self.data());
})
.def("as_float_scalar",
[](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<float>(&self.data());
})
.def("as_double_scalar", [](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<double>(&self.data());
});
py::class_<vs::StorageHelper>(m, "Storage") py::class_<vs::StorageHelper>(m, "Storage")
.def("timestamp", &vs::StorageHelper::timestamp) .def("timestamp", &vs::StorageHelper::timestamp)
...@@ -46,7 +50,24 @@ PYBIND11_MODULE(core, m) { ...@@ -46,7 +50,24 @@ PYBIND11_MODULE(core, m) {
py::class_<vs::ImHelper>(m, "Im") py::class_<vs::ImHelper>(m, "Im")
.def("storage", &vs::ImHelper::storage) .def("storage", &vs::ImHelper::storage)
.def("tablet", &vs::ImHelper::tablet) .def("tablet", &vs::ImHelper::tablet)
.def("add_tablet", &vs::ImHelper::AddTablet); .def("add_tablet", &vs::ImHelper::AddTablet)
.def("persist_to_disk", &vs::ImHelper::PersistToDisk)
.def("clear_tablets", &vs::ImHelper::ClearTablets);
m.def("im", &vs::get_im, "global information-maintainer object."); m.def("im", &vs::get_im, "global information-maintainer object.");
// Interfaces for components.
//
// ADD_SCALAR_TYPED_INTERFACE(T, name__) registers the class
// vs::components::ScalarHelper<T> on module `m` under the Python class name
// `name__` (stringified with #name__) and binds its six member functions:
// add_record, set_captions, get_records, get_captions, get_ids and
// get_timestamps. The macro body already ends with ';', so the ';' written
// after each use below is an extra empty statement.
#define ADD_SCALAR_TYPED_INTERFACE(T, name__) \
py::class_<vs::components::ScalarHelper<T>>(m, #name__) \
.def("add_record", &vs::components::ScalarHelper<T>::AddRecord) \
.def("set_captions", &vs::components::ScalarHelper<T>::SetCaptions) \
.def("get_records", &vs::components::ScalarHelper<T>::GetRecords) \
.def("get_captions", &vs::components::ScalarHelper<T>::GetCaptions) \
.def("get_ids", &vs::components::ScalarHelper<T>::GetIds) \
.def("get_timestamps", &vs::components::ScalarHelper<T>::GetTimestamps);
// One Python class per scalar value type.
ADD_SCALAR_TYPED_INTERFACE(int32_t, ScalarInt32);
ADD_SCALAR_TYPED_INTERFACE(int64_t, ScalarInt64);
ADD_SCALAR_TYPED_INTERFACE(float, ScalarFloat);
ADD_SCALAR_TYPED_INTERFACE(double, ScalarDouble);
// The macro is only needed for the registrations above.
#undef ADD_SCALAR_TYPED_INTERFACE
} }
#include <google/protobuf/text_format.h>
#include "visualdl/backend/logic/sdk.h" #include "visualdl/backend/logic/sdk.h"
#include <google/protobuf/text_format.h>
namespace visualdl { namespace visualdl {
#define IMPL_ENTRY_SET_OR_ADD(method__, ctype__, dtype__, opr__) \ #define IMPL_ENTRY_SET_OR_ADD(method__, ctype__, dtype__, opr__) \
template <> void Entry<ctype__>::method__(ctype__ v) { \ template <> \
entry->set_dtype(storage::DataType::dtype__); \ void EntryHelper<ctype__>::method__(ctype__ v) { \
entry->opr__(v); \ entry->set_dtype(storage::DataType::dtype__); \
entry->opr__(v); \
} }
IMPL_ENTRY_SET_OR_ADD(Set, int32_t, kInt32, set_i32); IMPL_ENTRY_SET_OR_ADD(Set, int32_t, kInt32, set_i32);
...@@ -16,9 +17,36 @@ IMPL_ENTRY_SET_OR_ADD(Set, float, kFloat, set_f); ...@@ -16,9 +17,36 @@ IMPL_ENTRY_SET_OR_ADD(Set, float, kFloat, set_f);
IMPL_ENTRY_SET_OR_ADD(Set, double, kDouble, set_d); IMPL_ENTRY_SET_OR_ADD(Set, double, kDouble, set_d);
IMPL_ENTRY_SET_OR_ADD(Add, int32_t, kInt32s, add_i32s); IMPL_ENTRY_SET_OR_ADD(Add, int32_t, kInt32s, add_i32s);
IMPL_ENTRY_SET_OR_ADD(Add, int64_t, kInt64s, add_i64s); IMPL_ENTRY_SET_OR_ADD(Add, int64_t, kInt64s, add_i64s);
IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
IMPL_ENTRY_SET_OR_ADD(Add, float, kFloats, add_fs); IMPL_ENTRY_SET_OR_ADD(Add, float, kFloats, add_fs);
IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds); IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds);
IMPL_ENTRY_SET_OR_ADD(Add, std::string, kStrings, add_ss);
IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
// Generates the explicit specialization EntryHelper<ctype__>::Get(), which
// returns the value read from the wrapped entry through entry->getter__().
#define IMPL_ENTRY_GET(ctype__, getter__)       \
  template <>                                   \
  ctype__ EntryHelper<ctype__>::Get() const {   \
    return entry->getter__();                   \
  }

IMPL_ENTRY_GET(int32_t, i32);
IMPL_ENTRY_GET(int64_t, i64);
IMPL_ENTRY_GET(float, f);
IMPL_ENTRY_GET(double, d);
IMPL_ENTRY_GET(std::string, s);
IMPL_ENTRY_GET(bool, b);
// Generates the explicit specialization EntryHelper<ctype__>::GetMulti(),
// which copies the repeated field read through entry->getter__() into a
// std::vector<ctype__>.
#define IMPL_ENTRY_GET_MULTI(ctype__, getter__)                   \
  template <>                                                     \
  std::vector<ctype__> EntryHelper<ctype__>::GetMulti() const {   \
    const auto &field = entry->getter__();                        \
    return std::vector<ctype__>(field.begin(), field.end());      \
  }

IMPL_ENTRY_GET_MULTI(int32_t, i32s);
IMPL_ENTRY_GET_MULTI(int64_t, i64s);
IMPL_ENTRY_GET_MULTI(float, fs);
IMPL_ENTRY_GET_MULTI(double, ds);
IMPL_ENTRY_GET_MULTI(std::string, ss);
IMPL_ENTRY_GET_MULTI(bool, bs);
std::string StorageHelper::human_readable_buffer() const { std::string StorageHelper::human_readable_buffer() const {
std::string buffer; std::string buffer;
...@@ -32,4 +60,84 @@ std::string TabletHelper::human_readable_buffer() const { ...@@ -32,4 +60,84 @@ std::string TabletHelper::human_readable_buffer() const {
return buffer; return buffer;
} }
} // namespace visualdl void ImHelper::PersistToDisk() const {
InformationMaintainer::Global().PersistToDisk();
}
// implementations for components
namespace components {
/*
 * Append every string in `captions` to the tablet's `captions` field.
 *
 * NOTE the captions are appended, not replaced: calling this twice on the same
 * tablet leaves both sets of captions in the tablet.
 *
 * @captions: caption strings, added in the given order.
 */
template <typename T>
void ScalarHelper<T>::SetCaptions(const std::vector<std::string> &captions) {
  // same guard as the other accessors of this helper
  CHECK_NOTNULL(data_);
  // range-for avoids the signed/unsigned comparison of an int index with size()
  for (const auto &caption : captions) {
    data_->add_captions(caption);
  }
}
/*
 * Append one record to the tablet.
 *
 * The record gets a single data entry that holds all of `values`, the given
 * `id`, and the current time (time(nullptr)) as its timestamp.
 *
 * @id: value stored in the record's `id` field.
 * @values: one value per caption; the count must equal the number of captions
 *          already stored in the tablet, otherwise the CHECK fails.
 */
template <typename T>
void ScalarHelper<T>::AddRecord(int id, const std::vector<T> &values) {
  CHECK_NOTNULL(data_);
  // captions_size() is a signed int; convert once so both checks compare
  // unsigned values with unsigned values.
  const auto num_captions = static_cast<size_t>(data_->captions_size());
  CHECK_GT(num_captions, static_cast<size_t>(0))
      << "captions should be set first";
  CHECK_EQ(num_captions, values.size())
      << "number of values in a record should be compatible with the "
         "captions";
  // add record data: all values go into one repeated-field entry
  auto *record = data_->add_records();
  EntryHelper<T> entry_helper(record->add_data());
  for (const T v : values) {
    entry_helper.Add(v);
  }
  // set record id
  record->set_id(id);
  // set record timestamp
  record->set_timestamp(time(nullptr));
}
/*
 * Read back the values of every record in the tablet.
 *
 * @return: one inner vector per record, in record order, holding the repeated
 *          values of that record's first data entry.
 *
 * Every record must have at least one data entry, otherwise the CHECK fails.
 */
template <typename T>
std::vector<std::vector<T>> ScalarHelper<T>::GetRecords() const {
  // same guard as GetIds/GetTimestamps
  CHECK_NOTNULL(data_);
  const int num_records = data_->records_size();
  std::vector<std::vector<T>> result;
  result.reserve(static_cast<size_t>(num_records));
  EntryHelper<T> entry_helper;
  for (int i = 0; i < num_records; i++) {
    auto *record = data_->mutable_records(i);
    // indexing entry 0 of an empty repeated field is out of bounds
    CHECK_GT(record->data_size(), 0) << "record " << i << " has no data entry";
    // rebind the helper to this record's first entry
    entry_helper(record->mutable_data(0));
    result.push_back(entry_helper.GetMulti());
  }
  return result;
}
/*
 * Collect the `id` field of every record in the tablet, in record order.
 *
 * NOTE the proto stores `id` as int64 while this interface returns int, so
 * each id is narrowed to int.
 */
template <typename T>
std::vector<int> ScalarHelper<T>::GetIds() const {
  CHECK_NOTNULL(data_);
  const int num_records = data_->records_size();
  std::vector<int> result;
  result.reserve(static_cast<size_t>(num_records));
  for (int i = 0; i < num_records; i++) {
    // explicit cast documents the int64 -> int narrowing
    result.push_back(static_cast<int>(data_->records(i).id()));
  }
  return result;
}
/*
 * Collect the `timestamp` field of every record in the tablet, in record order.
 *
 * NOTE the proto stores `timestamp` as int64 while this interface returns int,
 * so each timestamp is narrowed to int.
 */
template <typename T>
std::vector<int> ScalarHelper<T>::GetTimestamps() const {
  CHECK_NOTNULL(data_);
  const int num_records = data_->records_size();
  std::vector<int> result;
  result.reserve(static_cast<size_t>(num_records));
  for (int i = 0; i < num_records; i++) {
    // explicit cast documents the int64 -> int narrowing
    result.push_back(static_cast<int>(data_->records(i).timestamp()));
  }
  return result;
}
/*
 * Return a copy of the tablet's captions, in stored order.
 */
template <typename T>
std::vector<std::string> ScalarHelper<T>::GetCaptions() const {
  // same guard as GetIds/GetTimestamps
  CHECK_NOTNULL(data_);
  const auto &captions = data_->captions();
  return std::vector<std::string>(captions.begin(), captions.end());
}
// Explicit instantiations of ScalarHelper for the four scalar value types, so
// the member definitions above are compiled for each of them.
template class ScalarHelper<int32_t>;
template class ScalarHelper<int64_t>;
template class ScalarHelper<float>;
template class ScalarHelper<double>;
} // namespace components
} // namespace visualdl
#ifndef VISUALDL_BACKEND_LOGIC_SDK_H #ifndef VISUALDL_BACKEND_LOGIC_SDK_H
#define VISUALDL_BACKEND_LOGIC_SDK_H #define VISUALDL_BACKEND_LOGIC_SDK_H
#include "visualdl/backend/logic/im.h"
#include <glog/logging.h>
#include <time.h>
#include <map> #include <map>
#include "visualdl/backend/logic/im.h"
namespace visualdl { namespace visualdl {
/*
* Utility helper for storage::Entry.
*/
template <typename T>
struct EntryHelper {
// use pointer to avoid copy
storage::Entry *entry{nullptr};
EntryHelper() {}
explicit EntryHelper(storage::Entry *entry) : entry(entry) {}
void operator()(storage::Entry *entry) { this->entry = entry; }
/*
* Set a single value.
*/
void Set(T v);
/*
* Add a value to repeated message field.
*/
void Add(T v);
/*
* Get a single value.
*/
T Get() const;
/*
* Get repeated field.
*/
std::vector<T> GetMulti() const;
};
class TabletHelper { class TabletHelper {
public: public:
// method for each components
template <typename T>
void AddScalarRecord(int id, T value);
// basic member getter and setter // basic member getter and setter
std::string record_buffer(int idx) const { return data_->records(idx).SerializeAsString(); } std::string record_buffer(int idx) const {
return data_->records(idx).SerializeAsString();
}
size_t records_size() const { return data_->records_size(); } size_t records_size() const { return data_->records_size(); }
std::string buffer() const { return data_->SerializeAsString(); } std::string buffer() const { return data_->SerializeAsString(); }
std::string human_readable_buffer() const; std::string human_readable_buffer() const;
void SetBuffer(const storage::Tablet &t) { *data_ = t; } void SetBuffer(const storage::Tablet &t) { *data_ = t; }
void SetBuffer(const std::string &b) { data_->ParseFromString(b); } void SetBuffer(const std::string &b) { data_->ParseFromString(b); }
storage::Tablet &data() const { return *data_; }
// constructor that enable concurrency. // constructor that enable concurrency.
TabletHelper(storage::Tablet *t) : data_(t) {} TabletHelper(storage::Tablet *t) : data_(t) {}
// data updater that resuage of one instance. // data updater that resuage of one instance.
TabletHelper &operator()(storage::Tablet *t) { data_ = t; return *this; } TabletHelper &operator()(storage::Tablet *t) {
data_ = t;
return *this;
}
private: private:
storage::Tablet *data_; storage::Tablet *data_;
...@@ -66,14 +104,46 @@ public: ...@@ -66,14 +104,46 @@ public:
return TabletHelper( return TabletHelper(
InformationMaintainer::Global().AddTablet(tag, num_samples)); InformationMaintainer::Global().AddTablet(tag, num_samples));
} }
void ClearTablets() {
InformationMaintainer::Global().storage().mutable_data()->clear_tablets();
}
void PersistToDisk() const;
}; };
static ImHelper& get_im() { namespace components {
/*
 * Reader and writer for the Scalar component, working on a storage::Tablet.
 *
 * T is the value type of the scalar records.
 */
template <typename T>
class ScalarHelper {
public:
  // wrap `target`; the pointer is stored as-is
  ScalarHelper(storage::Tablet *target) : data_(target) {}

  // ---- writers ----
  void SetCaptions(const std::vector<std::string> &captions);
  void AddRecord(int id, const std::vector<T> &values);

  // ---- readers ----
  std::vector<std::vector<T>> GetRecords() const;
  std::vector<int> GetIds() const;
  std::vector<int> GetTimestamps() const;
  std::vector<std::string> GetCaptions() const;

private:
  // the tablet this helper reads from and writes to
  storage::Tablet *data_;
};
} // namespace components
static ImHelper &get_im() {
static ImHelper im; static ImHelper im;
return im; return im;
} }
} // namespace visualdl } // namespace visualdl
#include "visualdl/backend/logic/sdk.hpp"
#endif // VISUALDL_BACKEND_LOGIC_SDK_H #endif // VISUALDL_BACKEND_LOGIC_SDK_H
#include "visualdl/backend/logic/im.h"
namespace visualdl {
/*
 * Typed writer for a storage::Entry.
 */
template <typename T>
struct Entry {
  // the wrapped entry, held by pointer to avoid copying it
  storage::Entry *entry{nullptr};

  Entry(storage::Entry *target) : entry(target) {}

  /*
   * Write a single value.
   */
  void Set(T v);

  /*
   * Append a value to the repeated field.
   */
  void Add(T v);
};
template <typename T>
void TabletHelper::AddScalarRecord(int id, T value) {
auto* record = data_->add_records();
record->set_id(id);
Entry<T> entry_helper(record->mutable_data());
entry_helper.Set(value);
}
} // namespace visualdl
...@@ -37,11 +37,43 @@ message Entry { ...@@ -37,11 +37,43 @@ message Entry {
repeated float fs = 9; repeated float fs = 9;
repeated double ds = 10; repeated double ds = 10;
repeated int32 i32s = 11; repeated int32 i32s = 11;
repeated bool bs = 12; repeated string ss = 12;
repeated bool bs = 13;
} }
/*
The Record proto is designed to represent any data structure for any component, for
example, to store a record of Scalar component
Record {
// training error is 0.1, testing error is 0.2
data = [0.1, 0.2],
timestamp = xxxx,
// step id
id = xxxx
}
to store a record of Image component
Record {
// RGB pixel weights of an image
data = [[0.1, 0.2, ...], [...], [...]],
timestamp = xxxx,
// image shape
meta = [100, 200]
}
For other complex structures which have more fields than `timestamp`, `id` and `meta`, the additional fields
can be stored in the `data` field, since it can hold a list of values of different basic data types.
A component handler in the logic layer should know how to write or load a record for the corresponding component.
*/
message Record { message Record {
Entry data = 1; // one record might have multiple fields, one specific component should know how
// to parse the records.
repeated Entry data = 1;
// NOTE the timestamp, id and dtype might be redundant, since all the meta info can
// be stored in the `data` field.
int64 timestamp = 2; int64 timestamp = 2;
// store the count of writing operations to the tablet. // store the count of writing operations to the tablet.
int64 id = 3; int64 id = 3;
...@@ -76,6 +108,9 @@ message Tablet { ...@@ -76,6 +108,9 @@ message Tablet {
Entry meta = 5; Entry meta = 5;
// the unique identification for this `Tablet`. // the unique identification for this `Tablet`.
string tag = 6; string tag = 6;
// one tablet might have multiple captions, for example, a scalar component might have
// two plots labeled "train" and "test".
repeated string captions = 7;
} }
/* /*
......
import sys import sys
import unittest import unittest
import numpy as np
sys.path.append('../../build') sys.path.append('../../build')
import core import core
...@@ -9,6 +10,7 @@ im = core.im() ...@@ -9,6 +10,7 @@ im = core.im()
class StorageTester(unittest.TestCase): class StorageTester(unittest.TestCase):
def setUp(self): def setUp(self):
im.clear_tablets()
self.storage = im.storage() self.storage = im.storage()
def test_size(self): def test_size(self):
...@@ -34,15 +36,41 @@ class StorageTester(unittest.TestCase): ...@@ -34,15 +36,41 @@ class StorageTester(unittest.TestCase):
class TabletTester(unittest.TestCase): class TabletTester(unittest.TestCase):
def setUp(self): def setUp(self):
im.clear_tablets()
self.tablet = im.add_tablet("tag101", 20) self.tablet = im.add_tablet("tag101", 20)
def test_add_scalar(self):
self.tablet.add_scalar_float(1, 0.3)
self.assertEqual(self.tablet.records_size(), 1)
def test_human_readable_buffer(self): def test_human_readable_buffer(self):
print self.tablet.human_readable_buffer() print self.tablet.human_readable_buffer()
def test_scalar(self):
    # Write three float records through the tablet's scalar view, then read
    # them back and compare values and step ids.
    scalar = self.tablet.as_float_scalar()
    captions = ["train", "test"]
    expected_ids = [10, 20, 30]
    expected_records = [[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]

    scalar.set_captions(captions)
    for step_id, values in zip(expected_ids, expected_records):
        scalar.add_record(step_id, values)

    records = scalar.get_records()
    ids = scalar.get_ids()
    for idx, values in enumerate(expected_records):
        self.assertTrue(np.isclose(values, records[idx]).all())
        self.assertEqual(expected_ids[idx], ids[idx])
class ImTester(unittest.TestCase):
    # Checks the information maintainer exposed by the `core` module.

    def test_persist(self):
        # start from an empty storage so exactly one tablet exists afterwards
        im.clear_tablets()
        new_tablet = im.add_tablet("tab0", 111)
        tablet_count = im.storage().tablets_size()
        self.assertEqual(tablet_count, 1)
        # the storage directory is set before persisting
        im.storage().set_dir("./1")
        im.persist_to_disk()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册