提交 5895b7e6 编写于 作者: S superjom

add pybind

上级 befaf66c
......@@ -43,7 +43,7 @@ add_executable(vl_test
${PROJECT_SOURCE_DIR}/visualdl/utils/concurrency.h
${PROJECT_SOURCE_DIR}/visualdl/utils/filesystem.h
)
target_link_libraries(vl_test entry sdk storage im gtest glog protobuf gflags pthread)
target_link_libraries(vl_test sdk storage entry im gtest glog protobuf gflags pthread)
enable_testing ()
......
......@@ -6,6 +6,6 @@ add_dependencies(sdk entry storage storage_proto)
## pybind
add_library(core SHARED ${PROJECT_SOURCE_DIR}/visualdl/logic/pybind.cc)
add_dependencies(core pybind python im storage sdk protobuf glog)
target_link_libraries(core PRIVATE pybind python im storage sdk protobuf glog)
add_dependencies(core pybind python im entry storage sdk protobuf glog)
target_link_libraries(core PRIVATE pybind entry python im storage sdk protobuf glog)
set_target_properties(core PROPERTIES PREFIX "" SUFFIX ".so")
......@@ -45,7 +45,7 @@ PYBIND11_PLUGIN(core) {
// clang-format off
ADD_SCALAR(float)
ADD_SCALAR(double)
ADD_SCALAR(int)
ADD_SCALAR(int);
// clang-format on
#undef ADD_SCALAR
......@@ -64,7 +64,7 @@ PYBIND11_PLUGIN(core) {
// clang-format off
ADD_SCALAR(float)
ADD_SCALAR(double)
ADD_SCALAR(int)
ADD_SCALAR(int);
// clang-format on
#undef ADD_SCALAR
......
......@@ -4,61 +4,61 @@ namespace visualdl {
namespace components {
template <typename T>
void components::Scalar<T>::AddRecord(int id, const std::vector<T> &values) {
// add record data
auto record = tablet_.AddRecord();
auto entry = record.AddData<T>();
for (auto v : values) {
entry.Add(v);
}
// set record id
record.SetId(id);
// set record timestamp
record.SetTimeStamp(time(NULL));
}
// template <typename T>
// void components::Scalar<T>::AddRecord(int id, const std::vector<T> &values) {
// // add record data
// auto record = tablet_.AddRecord();
// auto entry = record.AddData<T>();
// for (auto v : values) {
// entry.Add(v);
// }
// // set record id
// record.SetId(id);
// // set record timestamp
// record.SetTimeStamp(time(NULL));
// }
template <typename T>
std::vector<T> ScalarReader<T>::records() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).data<T>(0));
}
return res;
}
// template <typename T>
// std::vector<T> ScalarReader<T>::records() const {
// std::vector<T> res;
// for (int i = 0; i < reader_.total_records(); i++) {
// res.push_back(reader_.record(i).data<T>(0));
// }
// return res;
// }
template <typename T>
std::vector<int> ScalarReader<T>::ids() const {
std::vector<int> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).id());
}
return res;
}
// template <typename T>
// std::vector<int> ScalarReader<T>::ids() const {
// std::vector<int> res;
// for (int i = 0; i < reader_.total_records(); i++) {
// res.push_back(reader_.record(i).id());
// }
// return res;
// }
template <typename T>
std::vector<int> ScalarReader<T>::timestamps() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).timestamp());
}
return res;
}
// template <typename T>
// std::vector<int> ScalarReader<T>::timestamps() const {
// std::vector<T> res;
// for (int i = 0; i < reader_.total_records(); i++) {
// res.push_back(reader_.record(i).timestamp());
// }
// return res;
// }
template <typename T>
std::vector<std::string> ScalarReader<T>::captions() const {
return reader_.captions();
}
// template <typename T>
// std::vector<std::string> ScalarReader<T>::captions() const {
// return reader_.captions();
// }
template <typename T>
size_t ScalarReader<T>::size() const {
return reader_.total_records();
}
// template <typename T>
// size_t ScalarReader<T>::size() const {
// return reader_.total_records();
// }
template class Scalar<int>;
template class Scalar<int64_t>;
template class Scalar<float>;
template class Scalar<double>;
// template class Scalar<int>;
// template class Scalar<int64_t>;
// template class Scalar<float>;
// template class Scalar<double>;
} // namespace components
......
......@@ -57,7 +57,12 @@ struct Scalar {
tablet_.SetCaptions(std::vector<std::string>({cap}));
}
void AddRecord(int id, const std::vector<T>& values);
void AddRecord(int id, T value) {
auto record = tablet_.AddRecord();
record.SetId(id);
auto entry = record.AddData<T>();
entry.Set(value);
}
private:
Tablet tablet_;
......@@ -68,8 +73,8 @@ struct ScalarReader {
ScalarReader(TabletReader&& reader) : reader_(reader) {}
std::vector<T> records() const;
std::vector<int> ids() const;
std::vector<int> timestamps() const;
std::vector<T> ids() const;
std::vector<T> timestamps() const;
std::vector<std::string> captions() const;
size_t size() const;
......@@ -77,6 +82,43 @@ private:
TabletReader reader_;
};
template <typename T>
std::vector<T> ScalarReader<T>::records() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).data<T>(0).Get());
}
return res;
}
template <typename T>
std::vector<T> ScalarReader<T>::ids() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).id());
}
return res;
}
template <typename T>
std::vector<T> ScalarReader<T>::timestamps() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
res.push_back(reader_.record(i).timestamp());
}
return res;
}
template <typename T>
std::vector<std::string> ScalarReader<T>::captions() const {
return reader_.captions();
}
template <typename T>
size_t ScalarReader<T>::size() const {
return reader_.total_records();
}
} // namespace components
} // namespace visualdl
......
......@@ -13,7 +13,7 @@ TEST(Scalar, write) {
auto tablet = storage.AddTablet("scalar0");
components::Scalar<int> scalar(tablet);
scalar.SetCaption("train");
scalar.AddRecord(0, std::vector<int>({12}));
scalar.AddRecord(0, 12);
// read from disk
StorageReader reader(dir);
......@@ -24,9 +24,8 @@ TEST(Scalar, write) {
ASSERT_EQ(scalar_reader.total_records(), 1);
auto record = scalar_reader.record(0);
// check the first entry of first record
auto vs = record.data<int>(0).GetMulti();
ASSERT_EQ(vs.size(), 1);
ASSERT_EQ(vs.front(), 12);
auto vs = record.data<int>(0).Get();
ASSERT_EQ(vs, 12);
}
} // namespace visualdl
__all__ = [
'StorageReader',
'StorageWriter',
]
import core
dtypes = ("float", "double", "int32", "int64")
class StorageReader(object):
def __init__(self, mode, dir):
self.reader = core.Reader(mode, dir)
def scalar(self, tag, type='float'):
type2scalar = {
'float': self.reader.get_scalar_float,
'double': self.reader.get_scalar_double,
'int': self.reader.get_scalar_int,
}
return type2scalar[type](tag)
class StorageWriter(object):
def __init__(self, mode, dir):
self.writer = core.Writer(mode, dir)
def scalar(self, tag, type='float'):
type2scalar = {
'float': self.writer.new_scalar_float,
'double': self.writer.new_scalar_double,
'int': self.writer.new_scalar_int,
}
return type2scalar[type](tag)
__all__ = [
'set_storage',
'scalar',
]
import core
dtypes = ("float", "double", "int32", "int64")
def IM(dir, mode="read", msecs=500):
im = core.Im()
READ = "read"
WRITE = "write"
if mode == READ:
im.start_read_service(dir, msecs)
else:
im.start_write_service(dir, msecs)
return im
class _Scalar(object):
'''
Python syntax wrapper for the core.ScalarHelper object.
'''
def __init__(self, core_object):
self._core_object = core_object
def add(self, id, vs):
'''
add a scalar record
:param id: int
id in the x-corrdinate
:param vs: list
values
:return: None
'''
self._core_object.add_record(id, vs)
def set_captions(self, cs):
'''
set the captions, one caption for one line.
:param cs: list of str
:return: None
'''
self._core_object.set_captions(cs)
@property
def captions(self):
return self._core_object.get_captions()
@property
def records(self):
'''
get all the records, format like
[
[0.1, 0.2], # first record
[0.2, 0.3], # second record
# ...
]
:return: list of list
'''
return self._core_object.get_records()
@property
def ids(self):
'''
get all the ids for the records
:return: list of int
'''
return self._core_object.get_ids()
@property
def timestamps(self):
'''
get all the timestamps for the records
:return: list of int
'''
return self._core_object.get_timestamps()
@property
def size(self):
return self._core_object.get_record_size()
def scalar(im, tag, dtype='float'):
'''
create a scalar component.
:param tag: str
name of this component.
:param dtype: string
the data type that will be used in underlying storage.
:return: object of core.Tablet
'''
assert dtype in dtypes, "invalid dtype(%s), should be one of %s" % (
dtype, str(dtypes))
tablet = im.add_tablet(tag, -1)
dtype2obj = {
'float': tablet.as_float_scalar,
'double': tablet.as_double_scalar,
'int32': tablet.as_int32_scalar,
'int64': tablet.as_int64_scalar,
}
obj = dtype2obj[dtype](im)
return _Scalar(obj)
import summary
import storage
import numpy as np
import unittest
import random
import time
once_flag = False
class ScalarTester(unittest.TestCase):
class StorageTest(unittest.TestCase):
def setUp(self):
self.dir = "tmp/summary.test"
# clean path
try:
os.rmdir(self.dir)
except:
pass
self.im = summary.IM(self.dir, "write", 200)
self.tablet_name = "scalar0"
self.scalar = summary.scalar(self.im, self.tablet_name)
self.py_captions = ["train cost", "test cost"]
self.scalar.set_captions(self.py_captions)
self.dir = "./tmp/storage_test"
self.writer = storage.StorageWriter("train", self.dir)
self.py_records = []
self.py_ids = []
# write
def test_write(self):
scalar = self.writer.scalar("model/scalar/min")
scalar.set_caption("model/scalar/min")
for i in range(10):
record = [0.1 * i, 0.2 * i]
id = i * 10
self.py_records.append(record)
self.py_ids.append(id)
self.scalar.add(id, record)
def test_records(self):
self.assertEqual(self.scalar.size, len(self.py_records))
for i, record in enumerate(self.scalar.records):
self.assertTrue(np.isclose(record, self.py_records[i]).all())
def test_ids(self):
self.assertEqual(len(self.py_ids), self.scalar.size)
for i, id in enumerate(self.scalar.ids):
self.assertEqual(self.py_ids[i], id)
def test_captions(self):
self.assertEqual(self.scalar.captions, self.py_captions)
def test_read_records(self):
time.sleep(1)
im = summary.IM(self.dir, "read", 200)
time.sleep(1)
scalar = summary.scalar(im, self.tablet_name)
records = scalar.records
self.assertEqual(len(self.py_records), scalar.size)
for i, record in enumerate(self.scalar.records):
self.assertTrue(np.isclose(record, records[i]).all())
def test_read_ids(self):
time.sleep(0.6)
im = summary.IM(self.dir, "read", msecs=200)
time.sleep(0.6)
scalar = summary.scalar(im, self.tablet_name)
self.assertEqual(len(self.py_ids), scalar.size)
for i, id in enumerate(scalar.ids):
self.assertEqual(self.py_ids[i], id)
def test_read_captions(self):
time.sleep(0.6)
im = summary.IM(self.dir, "read", msecs=200)
time.sleep(0.6)
scalar = summary.scalar(im, self.tablet_name)
self.assertEqual(scalar.captions, self.py_captions)
scalar.add_record(i, [1.0])
def test_mix_read_write(self):
write_im = summary.IM(self.dir, "write", msecs=200)
time.sleep(0.6)
read_im = summary.IM(self.dir, "read", msecs=200)
scalar_writer = summary.scalar(write_im, self.tablet_name)
scalar_reader = summary.scalar(read_im, self.tablet_name)
scalar_writer.set_captions(["train cost", "test cost"])
for i in range(1000):
scalar_writer.add(i, [random.random(), random.random()])
scalar_reader.records
for i in range(500):
scalar_writer.add(i, [random.random(), random.random()])
scalar_reader.records
for i in range(500):
scalar_writer.add(i, [random.random(), random.random()])
for i in range(10):
scalar_reader.records
scalar_reader.captions
def test_read(self):
self.reader = storage.StorageReader("train", self.dir)
scalar = self.reader.scalar("model/scalar/min")
self.assertEqual(scalar.get_caption(), "model/scalar/min")
if __name__ == '__main__':
......
......@@ -9,12 +9,12 @@ namespace visualdl {
entry->opr__(v); \
}
IMPL_ENTRY_SET_OR_ADD(Set, int32_t, kInt32, set_i32);
IMPL_ENTRY_SET_OR_ADD(Set, int, kInt32, set_i32);
IMPL_ENTRY_SET_OR_ADD(Set, int64_t, kInt64, set_i64);
IMPL_ENTRY_SET_OR_ADD(Set, bool, kBool, set_b);
IMPL_ENTRY_SET_OR_ADD(Set, float, kFloat, set_f);
IMPL_ENTRY_SET_OR_ADD(Set, double, kDouble, set_d);
IMPL_ENTRY_SET_OR_ADD(Add, int32_t, kInt32s, add_i32s);
IMPL_ENTRY_SET_OR_ADD(Add, int, kInt32s, add_i32s);
IMPL_ENTRY_SET_OR_ADD(Add, int64_t, kInt64s, add_i64s);
IMPL_ENTRY_SET_OR_ADD(Add, float, kFloats, add_fs);
IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds);
......@@ -27,7 +27,7 @@ IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
data_.fieldname__(); \
}
IMPL_ENTRY_GET(int32_t, i32);
IMPL_ENTRY_GET(int, i32);
IMPL_ENTRY_GET(int64_t, i64);
IMPL_ENTRY_GET(float, f);
IMPL_ENTRY_GET(double, d);
......@@ -41,11 +41,21 @@ IMPL_ENTRY_GET(bool, b);
data_.fieldname__().end()); \
}
IMPL_ENTRY_GET_MULTI(int32_t, i32s);
IMPL_ENTRY_GET_MULTI(int64_t, i64s);
IMPL_ENTRY_GET_MULTI(int, i32s);
IMPL_ENTRY_GET_MULTI(float, fs);
IMPL_ENTRY_GET_MULTI(double, ds);
IMPL_ENTRY_GET_MULTI(std::string, ss);
IMPL_ENTRY_GET_MULTI(bool, bs);
template class Entry<int>;
template class Entry<float>;
template class Entry<double>;
template class Entry<bool>;
template class EntryReader<int>;
template class EntryReader<float>;
template class EntryReader<double>;
template class EntryReader<bool>;
} // namespace visualdl
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册