提交 64c5ac37 编写于 作者: Y Yan Chunwei 提交者: GitHub

feature/add histogram component backend (#52)

上级 e92ef43c
......@@ -30,8 +30,6 @@ include(external/python) # find python and set path
include_directories(${PROJECT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
# TODO(ChunweiYan) debug, remote latter
#include_directories(/home/superjom/project/VisualDL/build/third_party/eigen3/src/extern_eigen3)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/visualdl/storage)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/visualdl/logic)
......@@ -40,6 +38,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/visualdl/python)
add_executable(vl_test
${PROJECT_SOURCE_DIR}/visualdl/test.cc
${PROJECT_SOURCE_DIR}/visualdl/logic/sdk_test.cc
${PROJECT_SOURCE_DIR}/visualdl/logic/histogram_test.cc
${PROJECT_SOURCE_DIR}/visualdl/storage/storage_test.cc
${PROJECT_SOURCE_DIR}/visualdl/utils/test_concurrency.cc
${PROJECT_SOURCE_DIR}/visualdl/utils/test_image.cc
......
......@@ -7,18 +7,16 @@ from tempfile import NamedTemporaryFile
import numpy as np
from PIL import Image
import storage
def get_modes(storage):
return storage.modes()
def get_scalar_tags(storage, mode):
def get_tags(storage, component):
result = {}
for mode in storage.modes():
with storage.mode(mode) as reader:
tags = reader.tags('scalar')
tags = reader.tags(component)
if tags:
result[mode] = {}
for tag in tags:
......@@ -29,6 +27,10 @@ def get_scalar_tags(storage, mode):
return result
def get_scalar_tags(storage):
return get_tags(storage, 'scalar')
def get_scalar(storage, mode, tag, num_records=300):
assert num_records > 1
......@@ -143,6 +145,38 @@ def get_invididual_image(storage, mode, tag, step_index, max_size=80):
return tempfile
def get_histogram_tags(storage):
return get_tags(storage, 'histogram')
def get_histogram(storage, mode, tag):
with storage.mode(mode) as reader:
histogram = reader.histogram(tag)
res = []
for i in xrange(histogram.num_records()):
try:
# some bug with protobuf, some times may overflow
record = histogram.record(i)
except:
continue
res.append([])
py_record = res[-1]
py_record.append(record.timestamp())
py_record.append(record.step())
py_record.append([])
data = py_record[-1]
for j in xrange(record.num_instances()):
instance = record.instance(j)
data.append(
[instance.left(),
instance.right(),
instance.frequency()])
return res
if __name__ == '__main__':
reader = storage.LogReader('./tmp/mock')
tags = get_image_tags(reader)
......
......@@ -3,13 +3,15 @@ import unittest
import lib
import storage
from storage_mock import add_image, add_scalar
import pprint
from storage_mock import add_scalar, add_image, add_histogram
class LibTest(unittest.TestCase):
def setUp(self):
dir = "./tmp/mock"
writer = storage.LogWriter(dir, sync_cycle=20)
writer = storage.LogWriter(dir, sync_cycle=10)
add_scalar(writer, "train", "layer/scalar0/min", 1000, 1)
add_scalar(writer, "test", "layer/scalar0/min", 1000, 10)
......@@ -22,23 +24,26 @@ class LibTest(unittest.TestCase):
add_image(writer, "train", "layer/image0", 7, 10, 1)
add_image(writer, "test", "layer/image0", 7, 10, 3)
add_image(writer, "train", "layer/image1", 7, 10, 1, shape=[30,30,2])
add_image(writer, "test", "layer/image1", 7, 10, 1, shape=[30,30,2])
add_image(writer, "train", "layer/image1", 7, 10, 1, shape=[30, 30, 2])
add_image(writer, "test", "layer/image1", 7, 10, 1, shape=[30, 30, 2])
add_histogram(writer, "train", "layer/histogram0", 100)
add_histogram(writer, "test", "layer/histogram0", 100)
self.reader = storage.LogReader(dir)
def test_modes(self):
modes = lib.get_modes(self.reader)
self.assertEqual(sorted(modes), sorted(["default", "train", "test", "valid"]))
self.assertEqual(
sorted(modes), sorted(["default", "train", "test", "valid"]))
def test_scalar(self):
for mode in "train test valid".split():
tags = lib.get_scalar_tags(self.reader, mode)
print 'scalar tags:'
pprint.pprint(tags)
self.assertEqual(len(tags), 3)
self.assertEqual(sorted(tags.keys()), sorted("train test valid".split()))
tags = lib.get_scalar_tags(self.reader)
print 'scalar tags:'
pprint.pprint(tags)
self.assertEqual(len(tags), 3)
self.assertEqual(
sorted(tags.keys()), sorted("train test valid".split()))
def test_image(self):
tags = lib.get_image_tags(self.reader)
......@@ -47,9 +52,17 @@ class LibTest(unittest.TestCase):
tags = lib.get_image_tag_steps(self.reader, 'train', 'layer/image0/0')
pprint.pprint(tags)
image = lib.get_invididual_image(self.reader, "train", 'layer/image0/0', 2)
image = lib.get_invididual_image(self.reader, "train",
'layer/image0/0', 2)
print image
def test_histogram(self):
tags = lib.get_histogram_tags(self.reader)
self.assertEqual(len(tags), 2)
res = lib.get_histogram(self.reader, "train", "layer/histogram0")
pprint.pprint(res)
if __name__ == '__main__':
unittest.main()
......@@ -35,15 +35,8 @@ def add_image(writer,
image_writer.set_sample(index, shape, list(data))
image_writer.finish_sampling()
if __name__ == '__main__':
add_scalar("train", "layer/scalar0/min", 1000, 1)
add_scalar("test", "layer/scalar0/min", 1000, 10)
add_scalar("valid", "layer/scalar0/min", 1000, 10)
add_scalar("train", "layer/scalar0/max", 1000, 1)
add_scalar("test", "layer/scalar0/max", 1000, 10)
add_scalar("valid", "layer/scalar0/max", 1000, 10)
add_image("train", "layer/image0", 7, 10, 1)
add_image("test", "layer/image0", 7, 10, 3)
def add_histogram(writer, mode, tag, num_buckets):
with writer.mode(mode) as writer:
histogram = writer.histogram(tag, num_buckets)
for i in range(10):
histogram.add_record(i, np.random.normal(0.1 + i * 0.01, size=1000))
......@@ -98,8 +98,7 @@ def scalar_tags():
if is_debug:
result = mock_tags.data()
else:
result = lib.get_scalar_tags(log_reader, mode)
print 'scalar tags (mode: %s)' % mode, result
result = lib.get_scalar_tags(log_reader)
result = gen_result(0, "", result)
return Response(json.dumps(result), mimetype='application/json')
......@@ -108,7 +107,14 @@ def scalar_tags():
def image_tags():
mode = request.args.get('run')
result = lib.get_image_tags(log_reader)
print 'image tags (mode: %s)'%mode, result
result = gen_result(0, "", result)
return Response(json.dumps(result), mimetype='application/json')
@app.route("/data/plugin/histograms/tags")
def histogram_tags():
mode = request.args.get('run')
result = lib.get_histogram_tags(log_reader)
result = gen_result(0, "", result)
return Response(json.dumps(result), mimetype='application/json')
......@@ -151,6 +157,15 @@ def individual_image():
return response
@app.route('/data/plugin/histograms/histograms')
def histogram():
run = request.args.get('run')
tag = request.args.get('tag')
result = lib.get_histogram(log_reader, run, tag)
result = gen_result(0, "", result)
return Response(json.dumps(result), mimetype='application/json')
@app.route('/data/plugin/graphs/graph')
def graph():
model_json = graph.load_model("")
......
#ifndef VISUALDL_LOGIC_HISTOGRAM_H
#define VISUALDL_LOGIC_HISTOGRAM_H
#include <glog/logging.h>
#include <cstdlib>
#include <limits>
#include <vector>
namespace visualdl {
// An interface to retrieve records of a histogram.
template <typename T>
struct HistogramRecord {
struct Instance {
T left;
T right;
int32_t frequency;
};
uint64_t timestamp;
int step;
HistogramRecord(uint64_t timestamp,
int step,
T left,
T right,
std::vector<int32_t>&& frequency)
: timestamp(timestamp),
step(step),
left(left),
right(right),
frequency(frequency),
span_(float(right - left) / frequency.size()) {}
Instance instance(int i) const {
CHECK_LT(i, frequency.size());
Instance res;
res.left = left + span_ * i;
res.right = res.left + span_;
res.frequency = frequency[i];
return res;
}
size_t num_instances() const { return frequency.size(); }
private:
T span_;
T left;
T right;
std::vector<int32_t> frequency;
};
// Create a histogram with default(10%) set of bucket boundaries.
// The buckets cover the range from min to max.
template <typename T>
struct HistogramBuilder {
HistogramBuilder(int num_buckets) : num_buckets_(num_buckets) {}
void operator()(const std::vector<T>& data) {
CHECK_GE(data.size(), num_buckets_);
UpdateBoundary(data);
CreateBuckets(data);
}
T left_boundary{std::numeric_limits<T>::max()};
T right_boundary{std::numeric_limits<T>::min()};
std::vector<int> buckets;
void Get(size_t n, T* left, T* right, int* frequency) {
CHECK(!buckets.empty()) << "need to CreateBuckets first.";
CHECK_LT(n, num_buckets_) << "n out of range.";
*left = left_boundary + span_ * n;
*right = *left + span_;
*frequency = buckets[n];
}
private:
// Get the left and right boundaries.
void UpdateBoundary(const std::vector<T>& data) {
for (auto v : data) {
if (v > right_boundary)
right_boundary = v;
else if (v < left_boundary)
left_boundary = v;
}
}
// Create `num_buckets` buckets.
void CreateBuckets(const std::vector<T>& data) {
span_ = (float)right_boundary / num_buckets_ -
(float)left_boundary / num_buckets_;
buckets.resize(num_buckets_);
for (auto v : data) {
int offset = std::min(int((v - left_boundary) / span_), num_buckets_ - 1);
buckets[offset]++;
}
}
float span_;
int num_buckets_;
};
} // namespace visualdl
#endif
#include "visualdl/logic/histogram.h"
#include <gtest/gtest.h>
#include <cstdlib>
using namespace std;
using namespace visualdl;
TEST(HistogramBuilder, build) {
const int size = 3000;
std::vector<float> data(size);
for (auto& v : data) {
v = (float)rand() / RAND_MAX - 0.5;
}
HistogramBuilder<float> builder(100);
builder(data);
float left, right;
int frequency;
for (int i = 0; i < 100; i++) {
builder.Get(i, &left, &right, &frequency);
ASSERT_GT(frequency, 0);
}
}
......@@ -32,13 +32,6 @@ void SimpleWriteSyncGuard<T>::Sync() {
template class SimpleWriteSyncGuard<Storage>;
template class SimpleWriteSyncGuard<Tablet>;
template class SimpleWriteSyncGuard<Record>;
template class SimpleWriteSyncGuard<Entry<float>>;
template class SimpleWriteSyncGuard<Entry<double>>;
template class SimpleWriteSyncGuard<Entry<bool>>;
template class SimpleWriteSyncGuard<Entry<long>>;
template class SimpleWriteSyncGuard<Entry<long long>>;
template class SimpleWriteSyncGuard<Entry<std::string>>;
template class SimpleWriteSyncGuard<Entry<std::vector<byte_t>>>;
template class SimpleWriteSyncGuard<Entry<int>>;
template class SimpleWriteSyncGuard<Entry>;
} // namespace visualdl
......@@ -8,14 +8,15 @@ namespace py = pybind11;
namespace vs = visualdl;
namespace cp = visualdl::components;
#define ADD_FULL_TYPE_IMPL(CODE) \
CODE(int32_t); \
CODE(int64_t); \
CODE(float); \
CODE(double);
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of VisualDL");
#define READER_ADD_SCALAR(T) \
.def("get_scalar_" #T, [](vs::LogReader& self, const std::string& tag) { \
auto tablet = self.tablet(tag); \
return vs::components::ScalarReader<T>(std::move(tablet)); \
})
py::class_<vs::LogReader>(m, "LogReader")
.def("__init__",
[](vs::LogReader& instance, const std::string& dir) {
......@@ -25,23 +26,35 @@ PYBIND11_PLUGIN(core) {
.def("set_mode", &vs::LogReader::SetMode)
.def("modes", [](vs::LogReader& self) { return self.storage().modes(); })
.def("tags", &vs::LogReader::tags)
// clang-format off
// clang-format off
#define READER_ADD_SCALAR(T) \
.def("get_scalar_" #T, [](vs::LogReader& self, const std::string& tag) { \
auto tablet = self.tablet(tag); \
return vs::components::ScalarReader<T>(std::move(tablet)); \
})
READER_ADD_SCALAR(float)
READER_ADD_SCALAR(double)
READER_ADD_SCALAR(int)
#undef READER_ADD_SCALAR
#define READER_ADD_HISTOGRAM(T) \
.def("get_histogram_" #T, [](vs::LogReader& self, const std::string& tag) { \
auto tablet = self.tablet(tag); \
return vs::components::HistogramReader<T>(std::move(tablet)); \
})
READER_ADD_HISTOGRAM(float)
READER_ADD_HISTOGRAM(double)
READER_ADD_HISTOGRAM(int)
#undef READER_ADD_HISTOGRAM
// clang-format on
.def("get_image", [](vs::LogReader& self, const std::string& tag) {
auto tablet = self.tablet(tag);
return vs::components::ImageReader(self.mode(), tablet);
});
#undef READER_ADD_SCALAR
#define WRITER_ADD_SCALAR(T) \
.def("new_scalar_" #T, [](vs::LogWriter& self, const std::string& tag) { \
auto tablet = self.AddTablet(tag); \
return cp::Scalar<T>(tablet); \
})
// clang-format on
py::class_<vs::LogWriter>(m, "LogWriter")
.def("__init__",
[](vs::LogWriter& instance, const std::string& dir, int sync_cycle) {
......@@ -49,10 +62,24 @@ PYBIND11_PLUGIN(core) {
})
.def("set_mode", &vs::LogWriter::SetMode)
.def("as_mode", &vs::LogWriter::AsMode)
// clang-format off
// clang-format off
#define WRITER_ADD_SCALAR(T) \
.def("new_scalar_" #T, [](vs::LogWriter& self, const std::string& tag) { \
auto tablet = self.AddTablet(tag); \
return cp::Scalar<T>(tablet); \
})
#define WRITER_ADD_HISTOGRAM(T) \
.def("new_histogram_" #T, \
[](vs::LogWriter& self, const std::string& tag, int num_buckets) { \
auto tablet = self.AddTablet(tag); \
return cp::Histogram<T>(tablet, num_buckets); \
})
WRITER_ADD_SCALAR(float)
WRITER_ADD_SCALAR(double)
WRITER_ADD_SCALAR(int)
WRITER_ADD_HISTOGRAM(float)
WRITER_ADD_HISTOGRAM(double)
WRITER_ADD_HISTOGRAM(int)
// clang-format on
.def("new_image",
[](vs::LogWriter& self,
......@@ -108,7 +135,44 @@ PYBIND11_PLUGIN(core) {
.def("record", &cp::ImageReader::record)
.def("timestamp", &cp::ImageReader::timestamp);
// .def("data", &cp::ImageReader::data)
// .def("shape", &cp::ImageReader::shape);
#define ADD_HISTOGRAM_WRITER(T) \
py::class_<cp::Histogram<T>>(m, "HistogramWriter__" #T) \
.def("add_record", &cp::Histogram<T>::AddRecord);
ADD_FULL_TYPE_IMPL(ADD_HISTOGRAM_WRITER)
#undef ADD_HISTOGRAM_WRITER
#define ADD_HISTOGRAM_RECORD_INSTANCE(T) \
py::class_<vs::HistogramRecord<T>::Instance>(m, "HistogramInstance__" #T) \
.def("left", \
[](typename vs::HistogramRecord<T>::Instance& self) { \
return self.left; \
}) \
.def("right", \
[](typename vs::HistogramRecord<T>::Instance& self) { \
return self.right; \
}) \
.def("frequency", [](typename vs::HistogramRecord<T>::Instance& self) { \
return self.frequency; \
});
ADD_FULL_TYPE_IMPL(ADD_HISTOGRAM_RECORD_INSTANCE)
#undef ADD_HISTOGRAM_RECORD_INSTANCE
#define ADD_HISTOGRAM_RECORD(T) \
py::class_<vs::HistogramRecord<T>>(m, "HistogramRecord__" #T) \
.def("step", [](vs::HistogramRecord<T>& self) { return self.step; }) \
.def("timestamp", \
[](vs::HistogramRecord<T>& self) { return self.timestamp; }) \
.def("instance", &vs::HistogramRecord<T>::instance) \
.def("num_instances", &vs::HistogramRecord<T>::num_instances);
ADD_FULL_TYPE_IMPL(ADD_HISTOGRAM_RECORD)
#undef ADD_HISTOGRAM_RECORD
#define ADD_HISTOGRAM_READER(T) \
py::class_<cp::HistogramReader<T>>(m, "HistogramReader__" #T) \
.def("num_records", &cp::HistogramReader<T>::num_records) \
.def("record", &cp::HistogramReader<T>::record);
ADD_FULL_TYPE_IMPL(ADD_HISTOGRAM_READER)
#undef ADD_HISTOGRAM_READER
} // end pybind
#include "visualdl/logic/sdk.h"
#include "visualdl/logic/histogram.h"
#include "visualdl/utils/image.h"
#include "visualdl/utils/macro.h"
namespace visualdl {
......@@ -10,7 +12,7 @@ template <typename T>
std::vector<T> ScalarReader<T>::records() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).template data<T>(0).Get());
res.push_back(reader_.record(i).data(0).template Get<T>());
}
return res;
}
......@@ -44,11 +46,6 @@ size_t ScalarReader<T>::size() const {
return reader_.total_records();
}
template class ScalarReader<int>;
template class ScalarReader<int64_t>;
template class ScalarReader<float>;
template class ScalarReader<double>;
void Image::StartSampling() {
if (!ToSampleThisStep()) return;
......@@ -60,7 +57,7 @@ void Image::StartSampling() {
// resize record
for (int i = 0; i < num_samples_; i++) {
step_.AddData<value_t>();
step_.AddData();
}
num_records_ = 0;
}
......@@ -129,11 +126,14 @@ void Image::SetSample(int index,
!is_same_type<value_t, shape_t>::value,
"value_t should not use int64_t field, this type is used to store shape");
// set meta with hack
Entry<shape_t> meta;
meta.set_parent(entry.parent());
meta.entry = entry.entry;
meta.SetMulti(shape);
// set meta.
entry.SetMulti(shape);
// // set meta with hack
// Entry<shape_t> meta;
// meta.set_parent(entry.parent());
// meta.entry = entry.entry;
// meta.SetMulti(shape);
}
std::string ImageReader::caption() {
......@@ -149,18 +149,56 @@ std::string ImageReader::caption() {
ImageReader::ImageRecord ImageReader::record(int offset, int index) {
ImageRecord res;
auto record = reader_.record(offset);
auto data_entry = record.data<std::vector<byte_t>>(index);
auto shape_entry = record.data<shape_t>(index);
auto data_str = data_entry.GetRaw();
auto entry = record.data(index);
auto data_str = entry.GetRaw();
std::transform(data_str.begin(),
data_str.end(),
std::back_inserter(res.data),
[](byte_t i) { return (int)(i); });
res.shape = shape_entry.GetMulti();
res.shape = entry.GetMulti<shape_t>();
res.step_id = record.id();
return res;
}
template <typename T>
void Histogram<T>::AddRecord(int step, const std::vector<T>& data) {
HistogramBuilder<T> builder(num_buckets_);
builder(data);
auto record = writer_.AddRecord();
record.SetId(step);
time_t time = std::time(nullptr);
record.SetTimeStamp(time);
// set frequencies.
auto entry = record.AddData();
entry.SetMulti<int32_t>(builder.buckets);
// Serialize left and right boundaries.
std::string boundaries_str = std::to_string(builder.left_boundary) + " " +
std::to_string(builder.right_boundary);
entry.SetRaw(boundaries_str);
}
template <typename T>
HistogramRecord<T> HistogramReader<T>::record(int i) {
CHECK_LT(i, reader_.total_records());
auto r = reader_.record(i);
auto d = r.data(0);
auto boundaries_str = d.GetRaw();
std::stringstream ss(boundaries_str);
T left, right;
ss >> left >> right;
auto frequency = d.GetMulti<int32_t>();
auto timestamp = r.timestamp();
auto step = r.id();
return HistogramRecord<T>(timestamp, step, left, right, std::move(frequency));
}
DECL_BASIC_TYPES_CLASS_IMPL(class, ScalarReader)
DECL_BASIC_TYPES_CLASS_IMPL(struct, Histogram)
DECL_BASIC_TYPES_CLASS_IMPL(struct, HistogramReader)
} // namespace components
} // namespace visualdl
#ifndef VISUALDL_LOGIC_SDK_H
#define VISUALDL_LOGIC_SDK_H
#include "visualdl/logic/histogram.h"
#include "visualdl/storage/storage.h"
#include "visualdl/storage/tablet.h"
#include "visualdl/utils/string.h"
namespace visualdl {
const static std::string kDefaultMode{"default"};
......@@ -82,7 +84,8 @@ public:
std::vector<std::string> tags(const std::string& component) {
auto type = Tablet::type(component);
auto tags = reader_.tags(type);
CHECK(!tags.empty());
CHECK(!tags.empty()) << "component " << component
<< " has no taged records";
std::vector<std::string> res;
for (const auto& tag : tags) {
if (TagMatchMode(tag, mode_)) {
......@@ -130,10 +133,10 @@ struct Scalar {
void AddRecord(int id, T value) {
auto record = tablet_.AddRecord();
record.SetId(id);
auto entry = record.AddData();
time_t time = std::time(nullptr);
record.SetTimeStamp(time);
auto entry = record.template AddData<T>();
entry.Set(value);
}
......@@ -262,6 +265,32 @@ private:
std::string mode_;
};
template <typename T>
struct Histogram {
Histogram(Tablet tablet, int num_buckets)
: writer_(tablet), num_buckets_(num_buckets) {
writer_.SetType(Tablet::Type::kHistogram);
}
void AddRecord(int step, const std::vector<T>& data);
private:
int num_buckets_;
Tablet writer_;
};
template <typename T>
struct HistogramReader {
HistogramReader(TabletReader tablet) : reader_(tablet) {}
size_t num_records() { return reader_.total_records(); }
HistogramRecord<T> record(int i);
private:
TabletReader reader_;
};
} // namespace components
} // namespace visualdl
......
......@@ -81,6 +81,22 @@ TEST(Image, test) {
CHECK_EQ(image2read.num_records(), num_steps);
}
TEST(Histogram, AddRecord) {
const auto dir = "./tmp/sdk_test.histogram";
LogWriter writer__(dir, 1);
auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("histogram0");
components::Histogram<float> histogram(tablet, 10);
std::vector<float> data(1000);
for (auto& v : data) {
v = (float)rand() / RAND_MAX;
}
histogram.AddRecord(10, data);
}
TEST(Scalar, more_than_one_mode) {
const auto dir = "./tmp/sdk_multi_mode";
LogWriter log(dir, 20);
......
......@@ -40,6 +40,14 @@ class LogReader(object):
def image(self, tag):
return self.reader.get_image(tag)
def histogram(self, tag, type='float'):
type2scalar = {
'float': self.reader.get_histogram_float,
'double': self.reader.get_histogram_double,
'int': self.reader.get_histogram_int,
}
return type2scalar[type](tag)
def __enter__(self):
return self
......@@ -65,6 +73,9 @@ class LogWriter(object):
return LogWriter.cur_mode
def scalar(self, tag, type='float'):
'''
Create a scalar component.
'''
type2scalar = {
'float': self.writer.new_scalar_float,
'double': self.writer.new_scalar_double,
......@@ -73,8 +84,22 @@ class LogWriter(object):
return type2scalar[type](tag)
def image(self, tag, num_samples, step_cycle):
'''
Create an image component.
'''
return self.writer.new_image(tag, num_samples, step_cycle)
def histogram(self, tag, num_buckets, type='float'):
'''
Create a histogram component.
'''
types = {
'float': self.writer.new_histogram_float,
'double': self.writer.new_histogram_double,
'int': self.writer.new_histogram_int,
}
return types[type](tag, num_buckets)
def __enter__(self):
return self
......
......@@ -4,7 +4,7 @@ namespace visualdl {
#define IMPL_ENTRY_SET_OR_ADD(method__, ctype__, dtype__, opr__) \
template <> \
void Entry<ctype__>::method__(ctype__ v) { \
void Entry::method__<ctype__>(ctype__ v) { \
entry->set_dtype(storage::DataType::dtype__); \
entry->opr__(v); \
WRITE_GUARD \
......@@ -12,7 +12,7 @@ namespace visualdl {
#define IMPL_ENTRY_SETMUL(ctype__, dtype__, field__) \
template <> \
void Entry<ctype__>::SetMulti(const std::vector<ctype__>& vs) { \
void Entry::SetMulti<ctype__>(const std::vector<ctype__>& vs) { \
entry->set_dtype(storage::DataType::dtype__); \
entry->clear_##field__(); \
for (auto v : vs) { \
......@@ -22,14 +22,14 @@ namespace visualdl {
}
template <>
void Entry<std::vector<byte_t>>::Set(std::vector<byte_t> v) {
void Entry::Set<std::vector<byte_t>>(std::vector<byte_t> v) {
entry->set_dtype(storage::DataType::kBytes);
entry->set_y(std::string(v.begin(), v.end()));
WRITE_GUARD
}
template <>
void Entry<std::vector<byte_t>>::Add(std::vector<byte_t> v) {
void Entry::Add<std::vector<byte_t>>(std::vector<byte_t> v) {
entry->set_dtype(storage::DataType::kBytess);
*entry->add_ys() = std::string(v.begin(), v.end());
WRITE_GUARD
......@@ -56,7 +56,7 @@ IMPL_ENTRY_SETMUL(bool, kBool, bs);
#define IMPL_ENTRY_GET(T, fieldname__) \
template <> \
T EntryReader<T>::Get() const { \
T EntryReader::Get<T>() const { \
return data_.fieldname__(); \
}
......@@ -68,14 +68,14 @@ IMPL_ENTRY_GET(std::string, s);
IMPL_ENTRY_GET(bool, b);
template <>
std::vector<uint8_t> EntryReader<std::vector<byte_t>>::Get() const {
std::vector<uint8_t> EntryReader::Get<std::vector<byte_t>>() const {
const auto& y = data_.y();
return std::vector<byte_t>(y.begin(), y.end());
}
#define IMPL_ENTRY_GET_MULTI(T, fieldname__) \
template <> \
std::vector<T> EntryReader<T>::GetMulti() const { \
std::vector<T> EntryReader::GetMulti<T>() const { \
return std::vector<T>(data_.fieldname__().begin(), \
data_.fieldname__().end()); \
}
......@@ -87,16 +87,4 @@ IMPL_ENTRY_GET_MULTI(double, ds);
IMPL_ENTRY_GET_MULTI(std::string, ss);
IMPL_ENTRY_GET_MULTI(bool, bs);
template class Entry<int>;
template class Entry<float>;
template class Entry<double>;
template class Entry<bool>;
template class Entry<std::vector<byte_t>>;
template class EntryReader<int>;
template class EntryReader<float>;
template class EntryReader<double>;
template class EntryReader<bool>;
template class EntryReader<std::vector<byte_t>>;
} // namespace visualdl
......@@ -14,15 +14,14 @@ using byte_t = unsigned char;
/*
* Utility helper for storage::Entry.
*/
template <typename T>
struct Entry {
DECL_GUARD(Entry<T>)
DECL_GUARD(Entry)
// use pointer to avoid copy
storage::Entry* entry{nullptr};
Entry() {}
Entry(storage::Entry* entry, Storage* parent) : entry(entry), x_(parent) {}
Entry(const Entry<T>& other) : entry(other.entry), x_(other.x_) {}
Entry(const Entry& other) : entry(other.entry), x_(other.x_) {}
void operator()(storage::Entry* entry, Storage* parent) {
this->entry = entry;
......@@ -30,13 +29,16 @@ struct Entry {
}
// Set a single value.
template <typename T>
void Set(T v);
void SetRaw(const std::string& bytes) { entry->set_y(bytes); }
// Add a value to repeated message field.
template <typename T>
void Add(T v);
template <typename T>
void SetMulti(const std::vector<T>& v);
Storage* parent() { return x_; }
......@@ -46,12 +48,13 @@ private:
Storage* x_;
};
template <typename T>
struct EntryReader {
EntryReader(storage::Entry x) : data_(x) {}
// Get a single value.
template <typename T>
T Get() const;
// Get repeated field.
template <typename T>
std::vector<T> GetMulti() const;
std::string GetRaw() { return data_.y(); }
......
......@@ -51,20 +51,19 @@ struct Record {
}
template <typename T>
Entry<T> MutableMeta() {
return Entry<T>(data_->mutable_meta(), parent());
Entry MutableMeta() {
return Entry(data_->mutable_meta(), parent());
}
template <typename T>
Entry<T> AddData() {
Entry AddData() {
WRITE_GUARD
return Entry<T>(data_->add_data(), parent());
return Entry(data_->add_data(), parent());
}
template <typename T>
Entry<T> MutableData(int i) {
Entry MutableData(int i) {
WRITE_GUARD
return Entry<T>(data_->mutable_data(i), parent());
return Entry(data_->mutable_data(i), parent());
}
Storage* parent() { return x_; }
......@@ -80,19 +79,13 @@ struct RecordReader {
// read operations
size_t data_size() const { return data_.data_size(); }
template <typename T>
EntryReader<T> data(int i) {
return EntryReader<T>(data_.data(i));
}
EntryReader data(int i) { return EntryReader(data_.data(i)); }
int64_t timestamp() const { return data_.timestamp(); }
int64_t id() const { return data_.id(); }
Record::Dtype dtype() const { return (Record::Dtype)data_.dtype(); }
template <typename T>
Entry<T> meta() const {
return data_.meta();
}
EntryReader meta() const { return data_.meta(); }
private:
storage::Record data_;
......
......@@ -21,7 +21,7 @@ TEST_F(StorageTest, main) {
auto tag0 = storage.AddTablet("tag0");
auto tag1 = storage.AddTablet("tag1");
auto record = tag0.AddRecord();
auto entry = record.AddData<int>();
auto entry = record.AddData();
entry.Set(12);
StorageReader reader("./tmp/storage_test");
......
......@@ -60,8 +60,8 @@ struct Tablet {
}
template <typename T>
Entry<T> MutableMeta() {
Entry<T> x(data_->mutable_meta(), parent());
Entry MutableMeta() {
Entry x(data_->mutable_meta(), parent());
}
void SetCaptions(const std::vector<std::string>& xs) {
......@@ -104,8 +104,8 @@ struct TabletReader {
int32_t num_samples() const { return data_.num_samples(); }
RecordReader record(int i) const { return RecordReader(data_.records(i)); }
template <typename T>
EntryReader<T> meta() const {
return EntryReader<T>(data_.meta());
EntryReader meta() const {
return EntryReader(data_.meta());
}
std::vector<std::string> captions() const {
std::vector<std::string> x(data_.captions().begin(),
......
......@@ -67,6 +67,7 @@ static void NormalizeImage(Uint8Image* image,
scale = (image_max < kZeroThreshold ? 0.0f : 255.0f) / image_max;
offset = 0.0f;
}
// Transform image, turning nonfinite values to bad_color
for (int i = 0; i < depth; i++) {
auto tmp = scale * values.row(i).array() + offset;
......
#ifndef VISUALDL_UTILS_MACRO_H
#define VISUALDL_UTILS_MACRO_H
#define DECL_BASIC_TYPES_CLASS_IMPL(class__, name__) \
template class__ name__<int32_t>; \
template class__ name__<int64_t>; \
template class__ name__<float>; \
template class__ name__<double>;
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册