提交 66a17d70 编写于 作者: S superjom

finish test

上级 5895b7e6
...@@ -15,7 +15,7 @@ PYBIND11_PLUGIN(core) { ...@@ -15,7 +15,7 @@ PYBIND11_PLUGIN(core) {
py::class_<cp::ScalarReader<T>>(m, "ScalarReader__" #T) \ py::class_<cp::ScalarReader<T>>(m, "ScalarReader__" #T) \
.def("records", &cp::ScalarReader<T>::records) \ .def("records", &cp::ScalarReader<T>::records) \
.def("timestamps", &cp::ScalarReader<T>::timestamps) \ .def("timestamps", &cp::ScalarReader<T>::timestamps) \
.def("captions", &cp::ScalarReader<T>::captions); .def("caption", &cp::ScalarReader<T>::caption);
ADD_SCALAR(int); ADD_SCALAR(int);
ADD_SCALAR(float); ADD_SCALAR(float);
ADD_SCALAR(double); ADD_SCALAR(double);
...@@ -56,15 +56,18 @@ PYBIND11_PLUGIN(core) { ...@@ -56,15 +56,18 @@ PYBIND11_PLUGIN(core) {
}) })
py::class_<vs::Writer>(m, "Writer") py::class_<vs::Writer>(m, "Writer")
.def( .def("__init__",
"__init__", [](vs::Writer& instance,
[](vs::Writer& instance, const std::string& dir,
const std::string& mode, int sync_cycle) {
const std::string& dir) { new (&instance) vs::Writer(mode, dir); }) new (&instance) vs::Writer(dir);
instance.storage().meta.cycle = sync_cycle;
})
.def("as_mode", &vs::Writer::AsMode)
// clang-format off // clang-format off
ADD_SCALAR(float) ADD_SCALAR(float)
ADD_SCALAR(double) ADD_SCALAR(double)
ADD_SCALAR(int); ADD_SCALAR(int);
// clang-format on // clang-format on
#undef ADD_SCALAR #undef ADD_SCALAR
......
...@@ -8,15 +8,23 @@ namespace visualdl { ...@@ -8,15 +8,23 @@ namespace visualdl {
class Writer { class Writer {
public: public:
Writer(const std::string& mode, const std::string& dir) : mode_(mode) { Writer(const std::string& dir) {
storage_.SetDir(dir); storage_.SetDir(dir);
} }
Writer& AsMode(const std::string& mode) {
mode_ = mode;
storage_.AddMode(mode);
return *this;
}
Tablet AddTablet(const std::string& tag) { Tablet AddTablet(const std::string& tag) {
// TODO(ChunweiYan) add string check here. // TODO(ChunweiYan) add string check here.
auto tmp = mode_ + "/" + tag; auto tmp = mode_ + "/" + tag;
string::TagEncode(tmp); string::TagEncode(tmp);
return storage_.AddTablet(tmp); auto res = storage_.AddTablet(tmp);
res.SetCaptions(std::vector<std::string>({mode_}));
return res;
} }
Storage& storage() { return storage_; } Storage& storage() { return storage_; }
...@@ -39,7 +47,7 @@ public: ...@@ -39,7 +47,7 @@ public:
private: private:
StorageReader reader_; StorageReader reader_;
std::string mode_; std::string mode_{"default"};
}; };
namespace components { namespace components {
...@@ -75,7 +83,8 @@ struct ScalarReader { ...@@ -75,7 +83,8 @@ struct ScalarReader {
std::vector<T> records() const; std::vector<T> records() const;
std::vector<T> ids() const; std::vector<T> ids() const;
std::vector<T> timestamps() const; std::vector<T> timestamps() const;
std::vector<std::string> captions() const; std::string caption() const;
size_t total_records() {return reader_.total_records();}
size_t size() const; size_t size() const;
private: private:
...@@ -110,8 +119,9 @@ std::vector<T> ScalarReader<T>::timestamps() const { ...@@ -110,8 +119,9 @@ std::vector<T> ScalarReader<T>::timestamps() const {
} }
template <typename T> template <typename T>
std::vector<std::string> ScalarReader<T>::captions() const { std::string ScalarReader<T>::caption() const {
return reader_.captions(); CHECK(!reader_.captions().empty()) << "no caption";
return reader_.captions().front();
} }
template <typename T> template <typename T>
......
...@@ -14,18 +14,19 @@ TEST(Scalar, write) { ...@@ -14,18 +14,19 @@ TEST(Scalar, write) {
components::Scalar<int> scalar(tablet); components::Scalar<int> scalar(tablet);
scalar.SetCaption("train"); scalar.SetCaption("train");
scalar.AddRecord(0, 12); scalar.AddRecord(0, 12);
storage.PersistToDisk();
// read from disk // read from disk
StorageReader reader(dir); StorageReader reader(dir);
auto scalar_reader = reader.tablet("scalar0"); auto tablet_reader = reader.tablet("scalar0");
auto captioins = scalar_reader.captions(); auto scalar_reader = components::ScalarReader<int>(std::move(tablet_reader));
ASSERT_EQ(captioins.size(), 1); auto captioin = scalar_reader.caption();
ASSERT_EQ(captioins.front(), "train"); ASSERT_EQ(captioin, "train");
ASSERT_EQ(scalar_reader.total_records(), 1); ASSERT_EQ(scalar_reader.total_records(), 1);
auto record = scalar_reader.record(0); auto record = scalar_reader.records();
ASSERT_EQ(record.size(), 1);
// check the first entry of first record // check the first entry of first record
auto vs = record.data<int>(0).Get(); ASSERT_EQ(record.front(), 12);
ASSERT_EQ(vs, 12);
} }
} // namespace visualdl } // namespace visualdl
...@@ -23,8 +23,12 @@ class StorageReader(object): ...@@ -23,8 +23,12 @@ class StorageReader(object):
class StorageWriter(object): class StorageWriter(object):
def __init__(self, mode, dir): def __init__(self, dir, sync_cycle):
self.writer = core.Writer(mode, dir) self.writer = core.Writer(dir, sync_cycle)
def as_mode(self, mode):
self.writer = self.writer.as_mode(mode)
return self
def scalar(self, tag, type='float'): def scalar(self, tag, type='float'):
type2scalar = { type2scalar = {
......
...@@ -7,18 +7,18 @@ import time ...@@ -7,18 +7,18 @@ import time
class StorageTest(unittest.TestCase): class StorageTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = "./tmp/storage_test" self.dir = "./tmp/storage_test"
self.writer = storage.StorageWriter("train", self.dir) self.writer = storage.StorageWriter(self.dir, sync_cycle=1).as_mode("train")
def test_write(self): def test_write(self):
scalar = self.writer.scalar("model/scalar/min") scalar = self.writer.scalar("model/scalar/min")
scalar.set_caption("model/scalar/min") scalar.set_caption("model/scalar/min")
for i in range(10): for i in range(10):
scalar.add_record(i, [1.0]) scalar.add_record(i, 1.0)
def test_read(self): def test_read(self):
self.reader = storage.StorageReader("train", self.dir) self.reader = storage.StorageReader("train", self.dir)
scalar = self.reader.scalar("model/scalar/min") scalar = self.reader.scalar("model/scalar/min")
self.assertEqual(scalar.get_caption(), "model/scalar/min") self.assertEqual(scalar.caption(), "train")
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
#define VISUALDL_STORAGE_STORAGE_H #define VISUALDL_STORAGE_STORAGE_H
#include <glog/logging.h> #include <glog/logging.h>
#include <visualdl/utils/guard.h>
#include <vector> #include <vector>
#include <set>
#include "visualdl/logic/im.h" #include "visualdl/logic/im.h"
#include "visualdl/utils/guard.h"
#include "visualdl/storage/storage.pb.h" #include "visualdl/storage/storage.pb.h"
#include "visualdl/storage/tablet.h" #include "visualdl/storage/tablet.h"
#include "visualdl/utils/filesystem.h" #include "visualdl/utils/filesystem.h"
...@@ -49,7 +50,10 @@ struct Storage { ...@@ -49,7 +50,10 @@ struct Storage {
// write operations // write operations
void AddMode(const std::string& x) { void AddMode(const std::string& x) {
// avoid duplicate modes.
if (modes_.count(x) != 0) return;
*data_->add_modes() = x; *data_->add_modes() = x;
modes_.insert(x);
WRITE_GUARD WRITE_GUARD
} }
...@@ -68,7 +72,7 @@ struct Storage { ...@@ -68,7 +72,7 @@ struct Storage {
* Save memory to disk. * Save memory to disk.
*/ */
void PersistToDisk(const std::string& dir) { void PersistToDisk(const std::string& dir) {
LOG(INFO) << "persist to disk " << dir; // LOG(INFO) << "persist to disk " << dir;
CHECK(!dir.empty()) << "dir should be set."; CHECK(!dir.empty()) << "dir should be set.";
fs::TryRecurMkdir(dir); fs::TryRecurMkdir(dir);
...@@ -91,6 +95,7 @@ private: ...@@ -91,6 +95,7 @@ private:
std::string dir_; std::string dir_;
std::map<std::string, storage::Tablet> tablets_; std::map<std::string, storage::Tablet> tablets_;
std::shared_ptr<storage::Storage> data_; std::shared_ptr<storage::Storage> data_;
std::set<std::string> modes_;
}; };
/* /*
......
...@@ -48,6 +48,7 @@ struct Tablet { ...@@ -48,6 +48,7 @@ struct Tablet {
} }
void SetCaptions(const std::vector<std::string>& xs) { void SetCaptions(const std::vector<std::string>& xs) {
data_->clear_captions();
for (const auto& x : xs) { for (const auto& x : xs) {
*data_->add_captions() = x; *data_->add_captions() = x;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册