提交 5895b7e6 编写于 作者: S superjom

add pybind

上级 befaf66c
...@@ -43,7 +43,7 @@ add_executable(vl_test ...@@ -43,7 +43,7 @@ add_executable(vl_test
${PROJECT_SOURCE_DIR}/visualdl/utils/concurrency.h ${PROJECT_SOURCE_DIR}/visualdl/utils/concurrency.h
${PROJECT_SOURCE_DIR}/visualdl/utils/filesystem.h ${PROJECT_SOURCE_DIR}/visualdl/utils/filesystem.h
) )
target_link_libraries(vl_test entry sdk storage im gtest glog protobuf gflags pthread) target_link_libraries(vl_test sdk storage entry im gtest glog protobuf gflags pthread)
enable_testing () enable_testing ()
......
...@@ -6,6 +6,6 @@ add_dependencies(sdk entry storage storage_proto) ...@@ -6,6 +6,6 @@ add_dependencies(sdk entry storage storage_proto)
## pybind ## pybind
add_library(core SHARED ${PROJECT_SOURCE_DIR}/visualdl/logic/pybind.cc) add_library(core SHARED ${PROJECT_SOURCE_DIR}/visualdl/logic/pybind.cc)
add_dependencies(core pybind python im storage sdk protobuf glog) add_dependencies(core pybind python im entry storage sdk protobuf glog)
target_link_libraries(core PRIVATE pybind python im storage sdk protobuf glog) target_link_libraries(core PRIVATE pybind entry python im storage sdk protobuf glog)
set_target_properties(core PROPERTIES PREFIX "" SUFFIX ".so") set_target_properties(core PROPERTIES PREFIX "" SUFFIX ".so")
...@@ -45,7 +45,7 @@ PYBIND11_PLUGIN(core) { ...@@ -45,7 +45,7 @@ PYBIND11_PLUGIN(core) {
// clang-format off // clang-format off
ADD_SCALAR(float) ADD_SCALAR(float)
ADD_SCALAR(double) ADD_SCALAR(double)
ADD_SCALAR(int) ADD_SCALAR(int);
// clang-format on // clang-format on
#undef ADD_SCALAR #undef ADD_SCALAR
...@@ -64,7 +64,7 @@ PYBIND11_PLUGIN(core) { ...@@ -64,7 +64,7 @@ PYBIND11_PLUGIN(core) {
// clang-format off // clang-format off
ADD_SCALAR(float) ADD_SCALAR(float)
ADD_SCALAR(double) ADD_SCALAR(double)
ADD_SCALAR(int) ADD_SCALAR(int);
// clang-format on // clang-format on
#undef ADD_SCALAR #undef ADD_SCALAR
......
...@@ -4,61 +4,61 @@ namespace visualdl { ...@@ -4,61 +4,61 @@ namespace visualdl {
namespace components { namespace components {
template <typename T> // template <typename T>
void components::Scalar<T>::AddRecord(int id, const std::vector<T> &values) { // void components::Scalar<T>::AddRecord(int id, const std::vector<T> &values) {
// add record data // // add record data
auto record = tablet_.AddRecord(); // auto record = tablet_.AddRecord();
auto entry = record.AddData<T>(); // auto entry = record.AddData<T>();
for (auto v : values) { // for (auto v : values) {
entry.Add(v); // entry.Add(v);
} // }
// set record id // // set record id
record.SetId(id); // record.SetId(id);
// set record timestamp // // set record timestamp
record.SetTimeStamp(time(NULL)); // record.SetTimeStamp(time(NULL));
} // }
template <typename T> // template <typename T>
std::vector<T> ScalarReader<T>::records() const { // std::vector<T> ScalarReader<T>::records() const {
std::vector<T> res; // std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) { // for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).data<T>(0)); // res.push_back(reader_.record(i).data<T>(0));
} // }
return res; // return res;
} // }
template <typename T> // template <typename T>
std::vector<int> ScalarReader<T>::ids() const { // std::vector<int> ScalarReader<T>::ids() const {
std::vector<int> res; // std::vector<int> res;
for (int i = 0; i < reader_.total_records(); i++) { // for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).id()); // res.push_back(reader_.record(i).id());
} // }
return res; // return res;
} // }
template <typename T> // template <typename T>
std::vector<int> ScalarReader<T>::timestamps() const { // std::vector<int> ScalarReader<T>::timestamps() const {
std::vector<T> res; // std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) { // for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).timestamp()); // res.push_back(reader_.record(i).timestamp());
} // }
return res; // return res;
} // }
template <typename T> // template <typename T>
std::vector<std::string> ScalarReader<T>::captions() const { // std::vector<std::string> ScalarReader<T>::captions() const {
return reader_.captions(); // return reader_.captions();
} // }
template <typename T> // template <typename T>
size_t ScalarReader<T>::size() const { // size_t ScalarReader<T>::size() const {
return reader_.total_records(); // return reader_.total_records();
} // }
template class Scalar<int>; // template class Scalar<int>;
template class Scalar<int64_t>; // template class Scalar<int64_t>;
template class Scalar<float>; // template class Scalar<float>;
template class Scalar<double>; // template class Scalar<double>;
} // namespace components } // namespace components
......
...@@ -57,7 +57,12 @@ struct Scalar { ...@@ -57,7 +57,12 @@ struct Scalar {
tablet_.SetCaptions(std::vector<std::string>({cap})); tablet_.SetCaptions(std::vector<std::string>({cap}));
} }
void AddRecord(int id, const std::vector<T>& values); void AddRecord(int id, T value) {
auto record = tablet_.AddRecord();
record.SetId(id);
auto entry = record.AddData<T>();
entry.Set(value);
}
private: private:
Tablet tablet_; Tablet tablet_;
...@@ -68,8 +73,8 @@ struct ScalarReader { ...@@ -68,8 +73,8 @@ struct ScalarReader {
ScalarReader(TabletReader&& reader) : reader_(reader) {} ScalarReader(TabletReader&& reader) : reader_(reader) {}
std::vector<T> records() const; std::vector<T> records() const;
std::vector<int> ids() const; std::vector<T> ids() const;
std::vector<int> timestamps() const; std::vector<T> timestamps() const;
std::vector<std::string> captions() const; std::vector<std::string> captions() const;
size_t size() const; size_t size() const;
...@@ -77,6 +82,43 @@ private: ...@@ -77,6 +82,43 @@ private:
TabletReader reader_; TabletReader reader_;
}; };
template <typename T>
std::vector<T> ScalarReader<T>::records() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).data<T>(0).Get());
}
return res;
}
template <typename T>
std::vector<T> ScalarReader<T>::ids() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).id());
}
return res;
}
template <typename T>
std::vector<T> ScalarReader<T>::timestamps() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).timestamp());
}
return res;
}
template <typename T>
std::vector<std::string> ScalarReader<T>::captions() const {
return reader_.captions();
}
template <typename T>
size_t ScalarReader<T>::size() const {
return reader_.total_records();
}
} // namespace components } // namespace components
} // namespace visualdl } // namespace visualdl
......
...@@ -13,7 +13,7 @@ TEST(Scalar, write) { ...@@ -13,7 +13,7 @@ TEST(Scalar, write) {
auto tablet = storage.AddTablet("scalar0"); auto tablet = storage.AddTablet("scalar0");
components::Scalar<int> scalar(tablet); components::Scalar<int> scalar(tablet);
scalar.SetCaption("train"); scalar.SetCaption("train");
scalar.AddRecord(0, std::vector<int>({12})); scalar.AddRecord(0, 12);
// read from disk // read from disk
StorageReader reader(dir); StorageReader reader(dir);
...@@ -24,9 +24,8 @@ TEST(Scalar, write) { ...@@ -24,9 +24,8 @@ TEST(Scalar, write) {
ASSERT_EQ(scalar_reader.total_records(), 1); ASSERT_EQ(scalar_reader.total_records(), 1);
auto record = scalar_reader.record(0); auto record = scalar_reader.record(0);
// check the first entry of first record // check the first entry of first record
auto vs = record.data<int>(0).GetMulti(); auto vs = record.data<int>(0).Get();
ASSERT_EQ(vs.size(), 1); ASSERT_EQ(vs, 12);
ASSERT_EQ(vs.front(), 12);
} }
} // namespace visualdl } // namespace visualdl
__all__ = [
'StorageReader',
'StorageWriter',
]
import core
dtypes = ("float", "double", "int32", "int64")
class StorageReader(object):
def __init__(self, mode, dir):
self.reader = core.Reader(mode, dir)
def scalar(self, tag, type='float'):
type2scalar = {
'float': self.reader.get_scalar_float,
'double': self.reader.get_scalar_double,
'int': self.reader.get_scalar_int,
}
return type2scalar[type](tag)
class StorageWriter(object):
def __init__(self, mode, dir):
self.writer = core.Writer(mode, dir)
def scalar(self, tag, type='float'):
type2scalar = {
'float': self.writer.new_scalar_float,
'double': self.writer.new_scalar_double,
'int': self.writer.new_scalar_int,
}
return type2scalar[type](tag)
__all__ = [
'set_storage',
'scalar',
]
import core
dtypes = ("float", "double", "int32", "int64")
def IM(dir, mode="read", msecs=500):
im = core.Im()
READ = "read"
WRITE = "write"
if mode == READ:
im.start_read_service(dir, msecs)
else:
im.start_write_service(dir, msecs)
return im
class _Scalar(object):
'''
Python syntax wrapper for the core.ScalarHelper object.
'''
def __init__(self, core_object):
self._core_object = core_object
def add(self, id, vs):
'''
add a scalar record
:param id: int
id in the x-corrdinate
:param vs: list
values
:return: None
'''
self._core_object.add_record(id, vs)
def set_captions(self, cs):
'''
set the captions, one caption for one line.
:param cs: list of str
:return: None
'''
self._core_object.set_captions(cs)
@property
def captions(self):
return self._core_object.get_captions()
@property
def records(self):
'''
get all the records, format like
[
[0.1, 0.2], # first record
[0.2, 0.3], # second record
# ...
]
:return: list of list
'''
return self._core_object.get_records()
@property
def ids(self):
'''
get all the ids for the records
:return: list of int
'''
return self._core_object.get_ids()
@property
def timestamps(self):
'''
get all the timestamps for the records
:return: list of int
'''
return self._core_object.get_timestamps()
@property
def size(self):
return self._core_object.get_record_size()
def scalar(im, tag, dtype='float'):
'''
create a scalar component.
:param tag: str
name of this component.
:param dtype: string
the data type that will be used in underlying storage.
:return: object of core.Tablet
'''
assert dtype in dtypes, "invalid dtype(%s), should be one of %s" % (
dtype, str(dtypes))
tablet = im.add_tablet(tag, -1)
dtype2obj = {
'float': tablet.as_float_scalar,
'double': tablet.as_double_scalar,
'int32': tablet.as_int32_scalar,
'int64': tablet.as_int64_scalar,
}
obj = dtype2obj[dtype](im)
return _Scalar(obj)
import summary import storage
import numpy as np import numpy as np
import unittest import unittest
import random import random
import time import time
once_flag = False class StorageTest(unittest.TestCase):
class ScalarTester(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = "tmp/summary.test" self.dir = "./tmp/storage_test"
# clean path self.writer = storage.StorageWriter("train", self.dir)
try:
os.rmdir(self.dir)
except:
pass
self.im = summary.IM(self.dir, "write", 200)
self.tablet_name = "scalar0"
self.scalar = summary.scalar(self.im, self.tablet_name)
self.py_captions = ["train cost", "test cost"]
self.scalar.set_captions(self.py_captions)
self.py_records = [] def test_write(self):
self.py_ids = [] scalar = self.writer.scalar("model/scalar/min")
# write scalar.set_caption("model/scalar/min")
for i in range(10): for i in range(10):
record = [0.1 * i, 0.2 * i] scalar.add_record(i, [1.0])
id = i * 10
self.py_records.append(record)
self.py_ids.append(id)
self.scalar.add(id, record)
def test_records(self):
self.assertEqual(self.scalar.size, len(self.py_records))
for i, record in enumerate(self.scalar.records):
self.assertTrue(np.isclose(record, self.py_records[i]).all())
def test_ids(self):
self.assertEqual(len(self.py_ids), self.scalar.size)
for i, id in enumerate(self.scalar.ids):
self.assertEqual(self.py_ids[i], id)
def test_captions(self):
self.assertEqual(self.scalar.captions, self.py_captions)
def test_read_records(self):
time.sleep(1)
im = summary.IM(self.dir, "read", 200)
time.sleep(1)
scalar = summary.scalar(im, self.tablet_name)
records = scalar.records
self.assertEqual(len(self.py_records), scalar.size)
for i, record in enumerate(self.scalar.records):
self.assertTrue(np.isclose(record, records[i]).all())
def test_read_ids(self):
time.sleep(0.6)
im = summary.IM(self.dir, "read", msecs=200)
time.sleep(0.6)
scalar = summary.scalar(im, self.tablet_name)
self.assertEqual(len(self.py_ids), scalar.size)
for i, id in enumerate(scalar.ids):
self.assertEqual(self.py_ids[i], id)
def test_read_captions(self):
time.sleep(0.6)
im = summary.IM(self.dir, "read", msecs=200)
time.sleep(0.6)
scalar = summary.scalar(im, self.tablet_name)
self.assertEqual(scalar.captions, self.py_captions)
def test_mix_read_write(self): def test_read(self):
write_im = summary.IM(self.dir, "write", msecs=200) self.reader = storage.StorageReader("train", self.dir)
time.sleep(0.6) scalar = self.reader.scalar("model/scalar/min")
read_im = summary.IM(self.dir, "read", msecs=200) self.assertEqual(scalar.get_caption(), "model/scalar/min")
scalar_writer = summary.scalar(write_im, self.tablet_name)
scalar_reader = summary.scalar(read_im, self.tablet_name)
scalar_writer.set_captions(["train cost", "test cost"])
for i in range(1000):
scalar_writer.add(i, [random.random(), random.random()])
scalar_reader.records
for i in range(500):
scalar_writer.add(i, [random.random(), random.random()])
scalar_reader.records
for i in range(500):
scalar_writer.add(i, [random.random(), random.random()])
for i in range(10):
scalar_reader.records
scalar_reader.captions
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -9,12 +9,12 @@ namespace visualdl { ...@@ -9,12 +9,12 @@ namespace visualdl {
entry->opr__(v); \ entry->opr__(v); \
} }
IMPL_ENTRY_SET_OR_ADD(Set, int32_t, kInt32, set_i32); IMPL_ENTRY_SET_OR_ADD(Set, int, kInt32, set_i32);
IMPL_ENTRY_SET_OR_ADD(Set, int64_t, kInt64, set_i64); IMPL_ENTRY_SET_OR_ADD(Set, int64_t, kInt64, set_i64);
IMPL_ENTRY_SET_OR_ADD(Set, bool, kBool, set_b); IMPL_ENTRY_SET_OR_ADD(Set, bool, kBool, set_b);
IMPL_ENTRY_SET_OR_ADD(Set, float, kFloat, set_f); IMPL_ENTRY_SET_OR_ADD(Set, float, kFloat, set_f);
IMPL_ENTRY_SET_OR_ADD(Set, double, kDouble, set_d); IMPL_ENTRY_SET_OR_ADD(Set, double, kDouble, set_d);
IMPL_ENTRY_SET_OR_ADD(Add, int32_t, kInt32s, add_i32s); IMPL_ENTRY_SET_OR_ADD(Add, int, kInt32s, add_i32s);
IMPL_ENTRY_SET_OR_ADD(Add, int64_t, kInt64s, add_i64s); IMPL_ENTRY_SET_OR_ADD(Add, int64_t, kInt64s, add_i64s);
IMPL_ENTRY_SET_OR_ADD(Add, float, kFloats, add_fs); IMPL_ENTRY_SET_OR_ADD(Add, float, kFloats, add_fs);
IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds); IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds);
...@@ -27,7 +27,7 @@ IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs); ...@@ -27,7 +27,7 @@ IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
data_.fieldname__(); \ data_.fieldname__(); \
} }
IMPL_ENTRY_GET(int32_t, i32); IMPL_ENTRY_GET(int, i32);
IMPL_ENTRY_GET(int64_t, i64); IMPL_ENTRY_GET(int64_t, i64);
IMPL_ENTRY_GET(float, f); IMPL_ENTRY_GET(float, f);
IMPL_ENTRY_GET(double, d); IMPL_ENTRY_GET(double, d);
...@@ -41,11 +41,21 @@ IMPL_ENTRY_GET(bool, b); ...@@ -41,11 +41,21 @@ IMPL_ENTRY_GET(bool, b);
data_.fieldname__().end()); \ data_.fieldname__().end()); \
} }
IMPL_ENTRY_GET_MULTI(int32_t, i32s); IMPL_ENTRY_GET_MULTI(int, i32s);
IMPL_ENTRY_GET_MULTI(int64_t, i64s);
IMPL_ENTRY_GET_MULTI(float, fs); IMPL_ENTRY_GET_MULTI(float, fs);
IMPL_ENTRY_GET_MULTI(double, ds); IMPL_ENTRY_GET_MULTI(double, ds);
IMPL_ENTRY_GET_MULTI(std::string, ss); IMPL_ENTRY_GET_MULTI(std::string, ss);
IMPL_ENTRY_GET_MULTI(bool, bs); IMPL_ENTRY_GET_MULTI(bool, bs);
template class Entry<int>;
template class Entry<float>;
template class Entry<double>;
template class Entry<bool>;
template class EntryReader<int>;
template class EntryReader<float>;
template class EntryReader<double>;
template class EntryReader<bool>;
} // namespace visualdl } // namespace visualdl
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册