提交 c377a081 编写于 作者: Y Yan Chunwei 提交者: GitHub

Merge pull request #26 from Superjom/feature/backend_scalar_interface

......@@ -16,7 +16,6 @@ link_directories(${PROJECT_SOURCE_DIR}/thirdparty/local/lib)
add_library(storage
${PROJECT_SOURCE_DIR}/visualdl/backend/storage/storage.cc
${PROJECT_SOURCE_DIR}/visualdl/backend/storage/storage.pb.cc)
add_library(c_api ${PROJECT_SOURCE_DIR}/visualdl/backend/logic/c_api.cc)
add_library(sdk ${PROJECT_SOURCE_DIR}/visualdl/backend/logic/sdk.cc)
add_library(im ${PROJECT_SOURCE_DIR}/visualdl/backend/logic/im.cc)
......@@ -29,4 +28,3 @@ add_executable(vl_test
${PROJECT_SOURCE_DIR}/visualdl/backend/test.cc
${PROJECT_SOURCE_DIR}/visualdl/backend/logic/im_test.cc)
target_link_libraries(vl_test storage im gtest glog protobuf gflags)
#include <ctime>
#include <glog/logging.h>
#include <ctime>
#include "visualdl/backend/logic/im.h"
......@@ -45,11 +45,17 @@ void InformationMaintainer::AddRecord(const std::string &tag,
auto *tablet = storage_.Find(tag);
CHECK(tablet);
auto num_records = tablet->num_records();
auto num_records = tablet->total_records();
const auto num_samples = tablet->num_samples();
const auto offset = ReserviorSample(num_samples, num_records + 1);
if (offset < 0)
return;
int offset;
// use reservoir sampling or not
if (num_samples > 0) {
offset = ReserviorSample(num_samples, num_records + 1);
if (offset < 0) return;
} else {
offset = num_records;
}
storage::Record *record;
if (offset >= num_records) {
......@@ -59,13 +65,21 @@ void InformationMaintainer::AddRecord(const std::string &tag,
}
*record = data;
tablet->set_total_records(num_records + 1);
}
tablet->set_num_records(num_records + 1);
// Reset the storage proto held by this maintainer: the timestamp, the target
// directory and all tablets are cleared.
void InformationMaintainer::Clear() {
  auto *proto = storage().mutable_data();
  proto->clear_timestamp();
  proto->clear_dir();
  proto->clear_tablets();
}
void InformationMaintainer::PersistToDisk() {
CHECK(!storage_.data().dir().empty()) << "path of storage should be set";
storage_.Save(storage_.data().dir());
// TODO make dir first
// MakeDir(storage_.data().dir());
storage_.Save(storage_.data().dir() + "/storage.pb");
}
} // namespace visualdl
} // namespace visualdl
......@@ -29,14 +29,18 @@ public:
/*
* @tag: tag of the target Tablet.
* @type: type of target Tablet.
* @data: storage Record.
* @record: a record
*
* NOTE pass in the serialized protobuf message will trigger copying, but
* simpler to support different Tablet data formats.
*/
void AddRecord(const std::string &tag, const storage::Record &record);
/*
* delete all the information.
*/
void Clear();
/*
* Save the Storage Protobuf to disk.
*/
......
#include <ctype.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "visualdl/backend/logic/sdk.h"
......@@ -10,19 +11,6 @@ PYBIND11_MODULE(core, m) {
m.doc() = "visualdl python core API";
py::class_<vs::TabletHelper>(m, "Tablet")
// interfaces for components
.def("add_scalar_int32",
&vs::TabletHelper::AddScalarRecord<int32_t>,
"add a scalar int32 record",
pybind11::arg("id"),
pybind11::arg("value"))
.def("add_scalar_int64",
&vs::TabletHelper::AddScalarRecord<int64_t>,
"add a scalar int64 record",
pybind11::arg("id"),
pybind11::arg("value"))
.def("add_scalar_float", &vs::TabletHelper::AddScalarRecord<float>)
.def("add_scalar_double", &vs::TabletHelper::AddScalarRecord<double>)
// other member setter and getter
.def("record_buffer", &vs::TabletHelper::record_buffer)
.def("records_size", &vs::TabletHelper::records_size)
......@@ -30,7 +18,23 @@ PYBIND11_MODULE(core, m) {
.def("human_readable_buffer", &vs::TabletHelper::human_readable_buffer)
.def("set_buffer",
(void (vs::TabletHelper::*)(const std::string&)) &
vs::TabletHelper::SetBuffer);
vs::TabletHelper::SetBuffer)
// scalar interface
.def("as_int32_scalar",
[](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<int32_t>(&self.data());
})
.def("as_int64_scalar",
[](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<int64_t>(&self.data());
})
.def("as_float_scalar",
[](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<float>(&self.data());
})
.def("as_double_scalar", [](const vs::TabletHelper& self) {
return vs::components::ScalarHelper<double>(&self.data());
});
py::class_<vs::StorageHelper>(m, "Storage")
.def("timestamp", &vs::StorageHelper::timestamp)
......@@ -46,7 +50,24 @@ PYBIND11_MODULE(core, m) {
py::class_<vs::ImHelper>(m, "Im")
.def("storage", &vs::ImHelper::storage)
.def("tablet", &vs::ImHelper::tablet)
.def("add_tablet", &vs::ImHelper::AddTablet);
.def("add_tablet", &vs::ImHelper::AddTablet)
.def("persist_to_disk", &vs::ImHelper::PersistToDisk)
.def("clear_tablets", &vs::ImHelper::ClearTablets);
m.def("im", &vs::get_im, "global information-maintainer object.");
// Interfaces for components.
// Registers one Python class per scalar value type. Each class wraps
// vs::components::ScalarHelper<T> (the type returned by the Tablet
// `as_*_scalar` bindings) and exposes its record, caption, id and timestamp
// accessors. `name__` is stringified and used as the Python class name.
#define ADD_SCALAR_TYPED_INTERFACE(T, name__) \
py::class_<vs::components::ScalarHelper<T>>(m, #name__) \
.def("add_record", &vs::components::ScalarHelper<T>::AddRecord) \
.def("set_captions", &vs::components::ScalarHelper<T>::SetCaptions) \
.def("get_records", &vs::components::ScalarHelper<T>::GetRecords) \
.def("get_captions", &vs::components::ScalarHelper<T>::GetCaptions) \
.def("get_ids", &vs::components::ScalarHelper<T>::GetIds) \
.def("get_timestamps", &vs::components::ScalarHelper<T>::GetTimestamps);
ADD_SCALAR_TYPED_INTERFACE(int32_t, ScalarInt32);
ADD_SCALAR_TYPED_INTERFACE(int64_t, ScalarInt64);
ADD_SCALAR_TYPED_INTERFACE(float, ScalarFloat);
ADD_SCALAR_TYPED_INTERFACE(double, ScalarDouble);
// the macro is only meaningful inside this module definition
#undef ADD_SCALAR_TYPED_INTERFACE
}
#include <google/protobuf/text_format.h>
#include "visualdl/backend/logic/sdk.h"
#include <google/protobuf/text_format.h>
namespace visualdl {
#define IMPL_ENTRY_SET_OR_ADD(method__, ctype__, dtype__, opr__) \
template <> void Entry<ctype__>::method__(ctype__ v) { \
entry->set_dtype(storage::DataType::dtype__); \
entry->opr__(v); \
#define IMPL_ENTRY_SET_OR_ADD(method__, ctype__, dtype__, opr__) \
template <> \
void EntryHelper<ctype__>::method__(ctype__ v) { \
entry->set_dtype(storage::DataType::dtype__); \
entry->opr__(v); \
}
IMPL_ENTRY_SET_OR_ADD(Set, int32_t, kInt32, set_i32);
......@@ -16,9 +17,36 @@ IMPL_ENTRY_SET_OR_ADD(Set, float, kFloat, set_f);
IMPL_ENTRY_SET_OR_ADD(Set, double, kDouble, set_d);
IMPL_ENTRY_SET_OR_ADD(Add, int32_t, kInt32s, add_i32s);
IMPL_ENTRY_SET_OR_ADD(Add, int64_t, kInt64s, add_i64s);
IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
IMPL_ENTRY_SET_OR_ADD(Add, float, kFloats, add_fs);
IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds);
IMPL_ENTRY_SET_OR_ADD(Add, std::string, kStrings, add_ss);
IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
// Explicit specializations of EntryHelper<T>::Get(): each one returns the
// single-value field of the wrapped storage::Entry that corresponds to T.
// `entry` is dereferenced without a null check, so the helper must already
// point at an Entry.
#define IMPL_ENTRY_GET(T, fieldname__) \
  template <>                          \
  T EntryHelper<T>::Get() const {      \
    return entry->fieldname__();       \
  }
IMPL_ENTRY_GET(int32_t, i32);
IMPL_ENTRY_GET(int64_t, i64);
IMPL_ENTRY_GET(float, f);
IMPL_ENTRY_GET(double, d);
IMPL_ENTRY_GET(std::string, s);
IMPL_ENTRY_GET(bool, b);
// keep the helper macro from leaking into the rest of the translation unit
#undef IMPL_ENTRY_GET
// Explicit specializations of EntryHelper<T>::GetMulti(): each one copies the
// repeated field of the wrapped storage::Entry that corresponds to T into a
// std::vector<T>. `entry` is dereferenced without a null check.
#define IMPL_ENTRY_GET_MULTI(T, fieldname__)             \
  template <>                                            \
  std::vector<T> EntryHelper<T>::GetMulti() const {      \
    /* fetch the repeated field once, then copy it */    \
    const auto &field = entry->fieldname__();            \
    return std::vector<T>(field.begin(), field.end());   \
  }
IMPL_ENTRY_GET_MULTI(int32_t, i32s);
IMPL_ENTRY_GET_MULTI(int64_t, i64s);
IMPL_ENTRY_GET_MULTI(float, fs);
IMPL_ENTRY_GET_MULTI(double, ds);
IMPL_ENTRY_GET_MULTI(std::string, ss);
IMPL_ENTRY_GET_MULTI(bool, bs);
// keep the helper macro from leaking into the rest of the translation unit
#undef IMPL_ENTRY_GET_MULTI
std::string StorageHelper::human_readable_buffer() const {
std::string buffer;
......@@ -32,4 +60,84 @@ std::string TabletHelper::human_readable_buffer() const {
return buffer;
}
} // namespace visualdl
// Forward to PersistToDisk() on the instance returned by
// InformationMaintainer::Global().
void ImHelper::PersistToDisk() const {
  auto &maintainer = InformationMaintainer::Global();
  maintainer.PersistToDisk();
}
// implementations for components
namespace components {
/*
 * Append every string in `captions`, in order, to the tablet's caption list.
 *
 * NOTE captions already stored in the tablet are kept, so calling this more
 * than once accumulates captions instead of replacing them.
 */
template <typename T>
void ScalarHelper<T>::SetCaptions(const std::vector<std::string> &captions) {
  // same precondition the other accessors of this helper enforce
  CHECK_NOTNULL(data_);
  for (const auto &caption : captions) {
    data_->add_captions(caption);
  }
}
/*
 * Append one record to the tablet.
 *
 * @id: stored as the record id.
 * @values: one value per caption; all of them are written into a single
 *          Entry of the new record.
 *
 * Aborts (glog CHECK) if the tablet pointer is null, if no captions have been
 * set yet, or if the number of values differs from the number of captions.
 * The record timestamp is taken from time() at the moment of the call.
 */
template <typename T>
void ScalarHelper<T>::AddRecord(int id, const std::vector<T> &values) {
  CHECK_NOTNULL(data_);
  // captions_size() returns a signed int; convert once so both checks compare
  // values of the same signedness.
  const auto num_captions = static_cast<size_t>(data_->captions_size());
  CHECK_GT(num_captions, 0UL) << "captions should be set first";
  CHECK_EQ(num_captions, values.size())
      << "number of values in a record should be compatible with the "
         "captions";
  // add record data
  auto *record = data_->add_records();
  EntryHelper<T> entry_helper(record->add_data());
  for (const T value : values) {
    entry_helper.Add(value);
  }
  // set record id
  record->set_id(id);
  // set record timestamp
  record->set_timestamp(time(nullptr));
}
/*
 * Return the values of every record in the tablet, one inner vector per
 * record, in storage order. Only the first Entry of each record is read,
 * which is where AddRecord writes the values.
 *
 * Aborts (glog CHECK) if the tablet pointer is null or if a record carries no
 * data entry.
 */
template <typename T>
std::vector<std::vector<T>> ScalarHelper<T>::GetRecords() const {
  CHECK_NOTNULL(data_);
  std::vector<std::vector<T>> result;
  result.reserve(data_->records_size());
  for (int i = 0; i < data_->records_size(); i++) {
    // EntryHelper wraps a non-const Entry pointer, hence the mutable accessors
    auto *record = data_->mutable_records(i);
    CHECK_GT(record->data_size(), 0) << "record " << i << " has no data entry";
    EntryHelper<T> entry_helper(record->mutable_data(0));
    result.push_back(entry_helper.GetMulti());
  }
  return result;
}
/*
 * Collect the id of every record in the tablet, in storage order.
 * The proto stores ids as int64; they are converted to int on the way out.
 */
template <typename T>
std::vector<int> ScalarHelper<T>::GetIds() const {
  CHECK_NOTNULL(data_);
  std::vector<int> ids;
  for (const auto &record : data_->records()) {
    ids.push_back(record.id());
  }
  return ids;
}
/*
 * Collect the timestamp of every record in the tablet, in storage order.
 * The proto stores timestamps as int64; they are converted to int on the way
 * out.
 */
template <typename T>
std::vector<int> ScalarHelper<T>::GetTimestamps() const {
  CHECK_NOTNULL(data_);
  std::vector<int> timestamps;
  for (const auto &record : data_->records()) {
    timestamps.push_back(record.timestamp());
  }
  return timestamps;
}
/*
 * Return a copy of the tablet's captions, in storage order.
 */
template <typename T>
std::vector<std::string> ScalarHelper<T>::GetCaptions() const {
  const auto &captions = data_->captions();
  std::vector<std::string> result(captions.begin(), captions.end());
  return result;
}
template class ScalarHelper<int32_t>;
template class ScalarHelper<int64_t>;
template class ScalarHelper<float>;
template class ScalarHelper<double>;
} // namespace components
} // namespace visualdl
#ifndef VISUALDL_BACKEND_LOGIC_SDK_H
#define VISUALDL_BACKEND_LOGIC_SDK_H
#include "visualdl/backend/logic/im.h"
#include <glog/logging.h>
#include <time.h>
#include <map>
#include "visualdl/backend/logic/im.h"
namespace visualdl {
/*
* Utility helper for storage::Entry.
*/
template <typename T>
struct EntryHelper {
// use pointer to avoid copy
storage::Entry *entry{nullptr};
EntryHelper() {}
explicit EntryHelper(storage::Entry *entry) : entry(entry) {}
void operator()(storage::Entry *entry) { this->entry = entry; }
/*
* Set a single value.
*/
void Set(T v);
/*
* Add a value to repeated message field.
*/
void Add(T v);
/*
* Get a single value.
*/
T Get() const;
/*
* Get repeated field.
*/
std::vector<T> GetMulti() const;
};
class TabletHelper {
public:
// method for each components
template <typename T>
void AddScalarRecord(int id, T value);
// basic member getter and setter
std::string record_buffer(int idx) const { return data_->records(idx).SerializeAsString(); }
std::string record_buffer(int idx) const {
return data_->records(idx).SerializeAsString();
}
size_t records_size() const { return data_->records_size(); }
std::string buffer() const { return data_->SerializeAsString(); }
std::string human_readable_buffer() const;
void SetBuffer(const storage::Tablet &t) { *data_ = t; }
void SetBuffer(const std::string &b) { data_->ParseFromString(b); }
storage::Tablet &data() const { return *data_; }
// constructor that enable concurrency.
TabletHelper(storage::Tablet *t) : data_(t) {}
// data updater that resuage of one instance.
TabletHelper &operator()(storage::Tablet *t) { data_ = t; return *this; }
TabletHelper &operator()(storage::Tablet *t) {
data_ = t;
return *this;
}
private:
storage::Tablet *data_;
......@@ -66,14 +104,46 @@ public:
return TabletHelper(
InformationMaintainer::Global().AddTablet(tag, num_samples));
}
// Remove every tablet from the storage reached through
// InformationMaintainer::Global().
void ClearTablets() {
  auto *proto = InformationMaintainer::Global().storage().mutable_data();
  proto->clear_tablets();
}
void PersistToDisk() const;
};
static ImHelper& get_im() {
namespace components {

/*
 * Reader and writer for a Tablet that backs a Scalar component.
 *
 * T is the value type stored in each record. The helper keeps a raw pointer
 * to the Tablet and never deletes it.
 */
template <typename T>
class ScalarHelper {
public:
  ScalarHelper(storage::Tablet *tablet) : data_(tablet) {}

  // writers
  void AddRecord(int id, const std::vector<T> &values);
  void SetCaptions(const std::vector<std::string> &captions);

  // readers
  std::vector<std::string> GetCaptions() const;
  std::vector<int> GetIds() const;
  std::vector<std::vector<T>> GetRecords() const;
  std::vector<int> GetTimestamps() const;

private:
  storage::Tablet *data_;
};

}  // namespace components
/*
 * Return the shared ImHelper instance, created on first use.
 *
 * Declared `inline` rather than `static` because this definition lives in a
 * header: `static` would give every translation unit that includes it a
 * private copy of the function and of the `im` object.
 */
inline ImHelper &get_im() {
  static ImHelper im;
  return im;
}
} // namespace visualdl
} // namespace visualdl
#include "visualdl/backend/logic/sdk.hpp"
#endif // VISUALDL_BACKEND_LOGIC_SDK_H
#include "visualdl/backend/logic/im.h"
namespace visualdl {
/*
 * Typed writer for a storage::Entry; exposes Set and Add for value type T.
 */
template <typename T>
struct Entry {
  // pointer to the Entry being written; a pointer is used to avoid a copy
  storage::Entry *entry{nullptr};

  Entry(storage::Entry *target) : entry(target) {}

  /*
   * Write a single value.
   */
  void Set(T v);

  /*
   * Append a value to the repeated field.
   */
  void Add(T v);
};
/*
 * Append a record holding one scalar to the tablet.
 *
 * @id: stored as the record id.
 * @value: written as the single value of the record's data entry.
 */
template <typename T>
void TabletHelper::AddScalarRecord(int id, T value) {
  auto *new_record = data_->add_records();
  new_record->set_id(id);
  Entry<T>(new_record->mutable_data()).Set(value);
}
} // namespace visualdl
#include <fstream>
#include <glog/logging.h>
#include <fstream>
#include "visualdl/backend/storage/storage.h"
......@@ -24,15 +24,15 @@ storage::Record *Storage::NewRecord(const std::string &tag) {
CHECK(tablet) << "Tablet" << tag << " should be create first";
auto *record = tablet->mutable_records()->Add();
// increase num_records
int num_records = tablet->num_records();
tablet->set_num_records(num_records + 1);
int num_records = tablet->total_records();
tablet->set_total_records(num_records + 1);
return record;
}
storage::Record *Storage::GetRecord(const std::string &tag, int offset) {
auto *tablet = Find(tag);
CHECK(tablet) << "Tablet" << tag << " should be create first";
auto num_records = tablet->num_records();
auto num_records = tablet->total_records();
CHECK_LT(offset, num_records) << "invalid offset";
return tablet->mutable_records()->Mutable(offset);
}
......@@ -60,4 +60,4 @@ void Storage::DeSerialize(const std::string &data) {
proto_.ParseFromString(data);
}
} // namespace visualdl
} // namespace visualdl
......@@ -2,88 +2,135 @@ syntax = "proto3";
package storage;
enum DataType {
// single entry
kInt32 = 0;
kInt64 = 1;
kFloat = 2;
kDouble = 3;
kString = 4;
kBool = 5;
// entrys
kInt64s = 6;
kFloats = 7;
kDoubles = 8;
kStrings = 9;
kInt32s = 10;
kBools = 11;
// single entry
kInt32 = 0;
kInt64 = 1;
kFloat = 2;
kDouble = 3;
kString = 4;
kBool = 5;
// entrys
kInt64s = 6;
kFloats = 7;
kDoubles = 8;
kStrings = 9;
kInt32s = 10;
kBools = 11;
kUnknown = 12;
kUnknown = 12;
}
// A data array, which type is `type`.
message Entry {
// if all the entries in a record share the same data type, ignore this field
// and store type to `dtype` in `Record`.
DataType dtype = 1;
// single element
int32 i32 = 2;
int64 i64 = 3;
string s = 4;
float f = 5;
double d = 6;
bool b = 7;
// array
repeated int64 i64s = 8;
repeated float fs = 9;
repeated double ds = 10;
repeated int32 i32s = 11;
repeated bool bs = 12;
// if all the entries in a record share the same data type, ignore this field
// and store type to `dtype` in `Record`.
DataType dtype = 1;
// single element
int32 i32 = 2;
int64 i64 = 3;
string s = 4;
float f = 5;
double d = 6;
bool b = 7;
// array
repeated int64 i64s = 8;
repeated float fs = 9;
repeated double ds = 10;
repeated int32 i32s = 11;
repeated string ss = 12;
repeated bool bs = 13;
}
/*
The Record proto is designed to represent any data structure for any component.
For example, to store a record of the Scalar component:
Record {
// training error is 0.1, testing error is 0.2
data = [0.1, 0.2],
timestamp = xxxx,
// step id
id = xxxx
}
To store a record of the Image component:
Record {
// RGB pixel weights of an image
data = [[0.1, 0.2, ...], [...], [...]],
timestamp = xxxx,
// image shape
meta = [100, 200]
}
For other complex structures which have more fields than `timestamp`, `id` and
`meta`, the additional fields can be stored in the `data` field, since it can
hold a list of values of different basic data types.
A component handler in the logic layer should know how to write or load a
record for the corresponding component.
*/
message Record {
Entry data = 1;
int64 timestamp = 2;
// store the count of writing operations to the tablet.
int64 id = 3;
DataType dtype = 4;
// shape or some other meta infomation for this record, if all the records
// share the same meta, just store one copy of meta in `Storage`, or create
// a unique copy for each `Record`.
Entry meta = 5;
// one record might have multiple fields; each specific component should know
// how to parse its records.
repeated Entry data = 1;
// NOTE the timestamp, id and dtype might be unnecessary, since all the meta
// info can be stored in the `data` field.
int64 timestamp = 2;
int64 id = 3;
DataType dtype = 4;
// shape or some other meta information for this record; if all the records
// share the same meta, just store one copy of meta in `Storage`, otherwise
// create a unique copy for each `Record`.
Entry meta = 5;
}
/*
A Tablet stores the records of a component which type is `component` and indidates as `tag`.
The records will be saved in a file which name contains `tag`. During the running period,
`num_records` will be accumulated, and `num_samples` indicates the size of sample set the
A Tablet stores the records of a component whose type is `component` and which
is identified by `tag`.
The records will be saved in a file whose name contains `tag`. During the
running period,
`total_records` will be accumulated, and `num_samples` indicates the size of
the sample set the
reservoir sampling algorithm will collect.
*/
message Tablet {
// the kinds of the components that supported
enum Type {
kScalar = 0;
kHistogram = 1;
kGraph = 2;
}
// type of the component, different component should have different storage format.
Type component = 1;
// records the total count of records, each Write operation should increate this value.
int64 num_records = 2;
// indicate the number of instances to sample, this should be a constant value.
int32 num_samples = 3;
repeated Record records = 4;
// store a meta infomation if all the records share.
Entry meta = 5;
// the unique identification for this `Tablet`.
string tag = 6;
// the kinds of the components that supported
enum Type {
kScalar = 0;
kHistogram = 1;
kGraph = 2;
}
// type of the component, different component should have different storage
// format.
Type component = 1;
// records the total count of records; each write operation should increase
// this value.
int64 total_records = 2;
// indicate the number of instances to sample, this should be a constant
// value.
int32 num_samples = 3;
repeated Record records = 4;
// stores the meta information shared by all the records, if any.
Entry meta = 5;
// the unique identification for this `Tablet`.
string tag = 6;
// one tablet might have multiple captions; for example, a scalar component
// might have two plots labeled "train" and "test".
repeated string captions = 7;
}
/*
The Storage stores all the records.
*/
message Storage {
// tags to Tablet, should be thread safe if fix the keys after initialization.
map<string, Tablet> tablets = 1;
string dir = 2;
int64 timestamp = 3;
// tags to Tablet, should be thread safe if fix the keys after initialization.
map<string, Tablet> tablets = 1;
string dir = 2;
int64 timestamp = 3;
}
\ No newline at end of file
import sys
import unittest
import numpy as np
sys.path.append('../../build')
import core
......@@ -9,6 +10,7 @@ im = core.im()
class StorageTester(unittest.TestCase):
def setUp(self):
im.clear_tablets()
self.storage = im.storage()
def test_size(self):
......@@ -34,15 +36,41 @@ class StorageTester(unittest.TestCase):
class TabletTester(unittest.TestCase):
def setUp(self):
im.clear_tablets()
self.tablet = im.add_tablet("tag101", 20)
def test_add_scalar(self):
self.tablet.add_scalar_float(1, 0.3)
self.assertEqual(self.tablet.records_size(), 1)
def test_human_readable_buffer(self):
print self.tablet.human_readable_buffer()
def test_scalar(self):
scalar = self.tablet.as_float_scalar()
py_captions = ["train", "test"]
step_ids = [10, 20, 30]
py_records = [
[0.1, 0.2],
[0.2, 0.3],
[0.3, 0.4]
]
scalar.set_captions(py_captions)
for i in range(len(py_records)):
scalar.add_record(step_ids[i], py_records[i])
records = scalar.get_records()
ids = scalar.get_ids()
for i in range(len(py_records)):
self.assertTrue(np.isclose(py_records[i], records[i]).all())
self.assertEqual(step_ids[i], ids[i])
class ImTester(unittest.TestCase):
    """Checks that the information maintainer can be persisted."""

    def test_persist(self):
        # start from an empty storage, then register exactly one tablet
        im.clear_tablets()
        tablet = im.add_tablet("tab0", 111)

        storage = im.storage()
        self.assertEqual(storage.tablets_size(), 1)

        # the target directory has to be set before persisting
        storage.set_dir("./1")
        im.persist_to_disk()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册