提交 66a17d70 编写于 作者: S superjom

finish test

上级 5895b7e6
......@@ -15,7 +15,7 @@ PYBIND11_PLUGIN(core) {
py::class_<cp::ScalarReader<T>>(m, "ScalarReader__" #T) \
.def("records", &cp::ScalarReader<T>::records) \
.def("timestamps", &cp::ScalarReader<T>::timestamps) \
.def("captions", &cp::ScalarReader<T>::captions);
.def("caption", &cp::ScalarReader<T>::caption);
ADD_SCALAR(int);
ADD_SCALAR(float);
ADD_SCALAR(double);
......@@ -56,15 +56,18 @@ PYBIND11_PLUGIN(core) {
})
py::class_<vs::Writer>(m, "Writer")
.def(
"__init__",
[](vs::Writer& instance,
const std::string& mode,
const std::string& dir) { new (&instance) vs::Writer(mode, dir); })
.def("__init__",
[](vs::Writer& instance,
const std::string& dir,
int sync_cycle) {
new (&instance) vs::Writer(dir);
instance.storage().meta.cycle = sync_cycle;
})
.def("as_mode", &vs::Writer::AsMode)
// clang-format off
ADD_SCALAR(float)
ADD_SCALAR(double)
ADD_SCALAR(int);
ADD_SCALAR(float)
ADD_SCALAR(double)
ADD_SCALAR(int);
// clang-format on
#undef ADD_SCALAR
......
......@@ -8,15 +8,23 @@ namespace visualdl {
class Writer {
public:
Writer(const std::string& mode, const std::string& dir) : mode_(mode) {
Writer(const std::string& dir) {
storage_.SetDir(dir);
}
Writer& AsMode(const std::string& mode) {
mode_ = mode;
storage_.AddMode(mode);
return *this;
}
Tablet AddTablet(const std::string& tag) {
// TODO(ChunweiYan) add string check here.
auto tmp = mode_ + "/" + tag;
string::TagEncode(tmp);
return storage_.AddTablet(tmp);
auto res = storage_.AddTablet(tmp);
res.SetCaptions(std::vector<std::string>({mode_}));
return res;
}
Storage& storage() { return storage_; }
......@@ -39,7 +47,7 @@ public:
private:
StorageReader reader_;
std::string mode_;
std::string mode_{"default"};
};
namespace components {
......@@ -75,7 +83,8 @@ struct ScalarReader {
std::vector<T> records() const;
std::vector<T> ids() const;
std::vector<T> timestamps() const;
std::vector<std::string> captions() const;
std::string caption() const;
size_t total_records() {return reader_.total_records();}
size_t size() const;
private:
......@@ -110,8 +119,9 @@ std::vector<T> ScalarReader<T>::timestamps() const {
}
template <typename T>
std::vector<std::string> ScalarReader<T>::captions() const {
return reader_.captions();
std::string ScalarReader<T>::caption() const {
CHECK(!reader_.captions().empty()) << "no caption";
return reader_.captions().front();
}
template <typename T>
......
......@@ -14,18 +14,19 @@ TEST(Scalar, write) {
components::Scalar<int> scalar(tablet);
scalar.SetCaption("train");
scalar.AddRecord(0, 12);
storage.PersistToDisk();
// read from disk
StorageReader reader(dir);
auto scalar_reader = reader.tablet("scalar0");
auto captioins = scalar_reader.captions();
ASSERT_EQ(captioins.size(), 1);
ASSERT_EQ(captioins.front(), "train");
auto tablet_reader = reader.tablet("scalar0");
auto scalar_reader = components::ScalarReader<int>(std::move(tablet_reader));
auto captioin = scalar_reader.caption();
ASSERT_EQ(captioin, "train");
ASSERT_EQ(scalar_reader.total_records(), 1);
auto record = scalar_reader.record(0);
auto record = scalar_reader.records();
ASSERT_EQ(record.size(), 1);
// check the first entry of first record
auto vs = record.data<int>(0).Get();
ASSERT_EQ(vs, 12);
ASSERT_EQ(record.front(), 12);
}
} // namespace visualdl
......@@ -23,8 +23,12 @@ class StorageReader(object):
class StorageWriter(object):
def __init__(self, mode, dir):
self.writer = core.Writer(mode, dir)
def __init__(self, dir, sync_cycle):
self.writer = core.Writer(dir, sync_cycle)
def as_mode(self, mode):
self.writer = self.writer.as_mode(mode)
return self
def scalar(self, tag, type='float'):
type2scalar = {
......
......@@ -7,18 +7,18 @@ import time
class StorageTest(unittest.TestCase):
def setUp(self):
self.dir = "./tmp/storage_test"
self.writer = storage.StorageWriter("train", self.dir)
self.writer = storage.StorageWriter(self.dir, sync_cycle=1).as_mode("train")
def test_write(self):
scalar = self.writer.scalar("model/scalar/min")
scalar.set_caption("model/scalar/min")
for i in range(10):
scalar.add_record(i, [1.0])
scalar.add_record(i, 1.0)
def test_read(self):
self.reader = storage.StorageReader("train", self.dir)
scalar = self.reader.scalar("model/scalar/min")
self.assertEqual(scalar.get_caption(), "model/scalar/min")
self.assertEqual(scalar.caption(), "train")
if __name__ == '__main__':
......
......@@ -2,10 +2,11 @@
#define VISUALDL_STORAGE_STORAGE_H
#include <glog/logging.h>
#include <visualdl/utils/guard.h>
#include <vector>
#include <set>
#include "visualdl/logic/im.h"
#include "visualdl/utils/guard.h"
#include "visualdl/storage/storage.pb.h"
#include "visualdl/storage/tablet.h"
#include "visualdl/utils/filesystem.h"
......@@ -49,7 +50,10 @@ struct Storage {
// write operations
void AddMode(const std::string& x) {
// avoid duplicate modes.
if (modes_.count(x) != 0) return;
*data_->add_modes() = x;
modes_.insert(x);
WRITE_GUARD
}
......@@ -68,7 +72,7 @@ struct Storage {
* Save memory to disk.
*/
void PersistToDisk(const std::string& dir) {
LOG(INFO) << "persist to disk " << dir;
// LOG(INFO) << "persist to disk " << dir;
CHECK(!dir.empty()) << "dir should be set.";
fs::TryRecurMkdir(dir);
......@@ -91,6 +95,7 @@ private:
std::string dir_;
std::map<std::string, storage::Tablet> tablets_;
std::shared_ptr<storage::Storage> data_;
std::set<std::string> modes_;
};
/*
......
......@@ -48,6 +48,7 @@ struct Tablet {
}
void SetCaptions(const std::vector<std::string>& xs) {
data_->clear_captions();
for (const auto& x : xs) {
*data_->add_captions() = x;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册