未验证 提交 fb08734e 编写于 作者: Y Yan Chunwei 提交者: GitHub

Merge pull request #34 from Superjom/feature/init_python_sdk

......@@ -64,6 +64,7 @@ PYBIND11_MODULE(core, m) {
.def("get_records", &vs::components::ScalarHelper<T>::GetRecords) \
.def("get_captions", &vs::components::ScalarHelper<T>::GetCaptions) \
.def("get_ids", &vs::components::ScalarHelper<T>::GetIds) \
.def("get_record_size", &vs::components::ScalarHelper<T>::GetSize) \
.def("get_timestamps", &vs::components::ScalarHelper<T>::GetTimestamps);
ADD_SCALAR_TYPED_INTERFACE(int32_t, ScalarInt32);
ADD_SCALAR_TYPED_INTERFACE(int64_t, ScalarInt64);
......
......@@ -69,6 +69,7 @@ namespace components {
template <typename T>
void ScalarHelper<T>::SetCaptions(const std::vector<std::string> &captions) {
CHECK_EQ(data_->captions_size(), 0UL) << "the captions can set only once";
for (int i = 0; i < captions.size(); i++) {
data_->add_captions(captions[i]);
}
......
......@@ -133,6 +133,8 @@ public:
std::vector<std::string> GetCaptions() const;
size_t GetSize() const { return data_->records_size(); }
private:
storage::Tablet *data_;
};
......
__all__ = [
'set_storage',
'scalar',
]
import core
im = core.im()
dtypes = ("float", "double", "int32", "int64")
def set_storage(dir):
'''
:param dir: str
directory of summary to write log.
:return: None
'''
im.storage().set_dir(dir)
class _Scalar(object):
'''
Python syntax wrapper for the core.ScalarHelper object.
'''
def __init__(self, core_object):
self._core_object = core_object
def add(self, id, vs):
'''
add a scalar record
:param id: int
id in the x-corrdinate
:param vs: list
values
:return: None
'''
self._core_object.add_record(id, vs)
def set_captions(self, cs):
'''
set the captions, one caption for one line.
:param cs: list of str
:return: None
'''
self._core_object.set_captions(cs)
@property
def captions(self):
return self._core_object.get_captions()
@property
def records(self):
'''
get all the records, format like
[
[0.1, 0.2], # first record
[0.2, 0.3], # second record
# ...
]
:return: list of list
'''
return self._core_object.get_records()
@property
def ids(self):
'''
get all the ids for the records
:return: list of int
'''
return self._core_object.get_ids()
@property
def timestamps(self):
'''
get all the timestamps for the records
:return: list of int
'''
return self._core_object.get_timestamps()
@property
def size(self):
return self._core_object.get_record_size()
def scalar(tag, dtype='float'):
'''
create a scalar component.
:param tag: str
name of this component.
:param dtype: string
the data type that will be used in underlying storage.
:return: object of core.Tablet
'''
assert dtype in dtypes, "invalid dtype(%s), should be one of %s" % (
dtype, str(dtypes))
tablet = im.add_tablet(tag, -1)
dtype2obj = {
'float': tablet.as_float_scalar,
'double': tablet.as_double_scalar,
'int32': tablet.as_int32_scalar,
'int64': tablet.as_int64_scalar,
}
obj = dtype2obj[dtype]()
return _Scalar(obj)
import summary
import numpy as np
import unittest
summary.set_storage("tmp_dir")
once_flag = False
class ScalarTester(unittest.TestCase):
def setUp(self):
global once_flag
self.scalar = summary.scalar("scalar0")
if not once_flag:
self.py_captions = ["train cost", "test cost"]
self.scalar.set_captions(self.py_captions)
self.py_records = []
self.py_ids = []
for i in range(10):
record = [0.1 * i, 0.2 * i]
id = i * 10
self.py_records.append(record)
self.py_ids.append(id)
if not once_flag:
self.scalar.add(id, record)
once_flag = True
def test_records(self):
self.assertEqual(self.scalar.size, len(self.py_records))
for i, record in enumerate(self.scalar.records):
self.assertTrue(np.isclose(record, self.py_records[i]).all())
def test_ids(self):
self.assertEqual(len(self.py_ids), self.scalar.size)
for i, id in enumerate(self.scalar.ids):
self.assertEqual(self.py_ids[i], id)
def test_captions(self):
self.assertEqual(self.scalar.captions, self.py_captions)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册