From 6f7a7cf1b8d45f47b3024e288c724ac93a214fe8 Mon Sep 17 00:00:00 2001 From: superjom Date: Thu, 23 Nov 2017 17:10:03 +0800 Subject: [PATCH] init python sdk --- visualdl/backend/logic/pybind.cc | 1 + visualdl/backend/logic/sdk.cc | 1 + visualdl/backend/logic/sdk.h | 2 + visualdl/backend/python/__init__.py | 0 visualdl/backend/python/summary.py | 106 ++++++++++++++++++++++++ visualdl/backend/python/test_summary.py | 44 ++++++++++ 6 files changed, 154 insertions(+) create mode 100644 visualdl/backend/python/__init__.py create mode 100644 visualdl/backend/python/summary.py create mode 100644 visualdl/backend/python/test_summary.py diff --git a/visualdl/backend/logic/pybind.cc b/visualdl/backend/logic/pybind.cc index 2256542d..0be6d5cb 100644 --- a/visualdl/backend/logic/pybind.cc +++ b/visualdl/backend/logic/pybind.cc @@ -64,6 +64,7 @@ PYBIND11_MODULE(core, m) { .def("get_records", &vs::components::ScalarHelper::GetRecords) \ .def("get_captions", &vs::components::ScalarHelper::GetCaptions) \ .def("get_ids", &vs::components::ScalarHelper::GetIds) \ + .def("get_record_size", &vs::components::ScalarHelper::GetSize) \ .def("get_timestamps", &vs::components::ScalarHelper::GetTimestamps); ADD_SCALAR_TYPED_INTERFACE(int32_t, ScalarInt32); ADD_SCALAR_TYPED_INTERFACE(int64_t, ScalarInt64); diff --git a/visualdl/backend/logic/sdk.cc b/visualdl/backend/logic/sdk.cc index a2c7d638..c676c550 100644 --- a/visualdl/backend/logic/sdk.cc +++ b/visualdl/backend/logic/sdk.cc @@ -69,6 +69,7 @@ namespace components { template void ScalarHelper::SetCaptions(const std::vector &captions) { + CHECK_EQ(data_->captions_size(), 0UL) << "the captions can set only once"; for (int i = 0; i < captions.size(); i++) { data_->add_captions(captions[i]); } diff --git a/visualdl/backend/logic/sdk.h b/visualdl/backend/logic/sdk.h index 54a50432..d463e3cd 100644 --- a/visualdl/backend/logic/sdk.h +++ b/visualdl/backend/logic/sdk.h @@ -133,6 +133,8 @@ public: std::vector GetCaptions() const; + size_t GetSize() const { return data_->records_size(); } + private: storage::Tablet *data_; }; diff --git a/visualdl/backend/python/__init__.py b/visualdl/backend/python/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/visualdl/backend/python/summary.py b/visualdl/backend/python/summary.py new file mode 100644 index 00000000..59877aca --- /dev/null +++ b/visualdl/backend/python/summary.py @@ -0,0 +1,106 @@ +__all__ = [ + 'set_storage', + 'scalar', +] +import core + +im = core.im() + +dtypes = ("float", "double", "int32", "int64") + + +def set_storage(dir): + ''' + :param dir: str + directory of summary to write log. + :return: None + ''' + im.storage().set_dir(dir) + + +class _Scalar(object): + ''' + Python syntax wrapper for the core.ScalarHelper object. + ''' + + def __init__(self, core_object): + self._core_object = core_object + + def add(self, id, vs): + ''' + add a scalar record + :param id: int + id in the x-corrdinate + :param vs: list + values + :return: None + ''' + self._core_object.add_record(id, vs) + + def set_captions(self, cs): + ''' + set the captions, one caption for one line. + :param cs: list of str + :return: None + ''' + self._core_object.set_captions(cs) + + @property + def captions(self): + return self._core_object.get_captions() + + @property + def records(self): + ''' + get all the records, format like + [ + [0.1, 0.2], # first record + [0.2, 0.3], # second record + # ... + ] + :return: list of list + ''' + return self._core_object.get_records() + + @property + def ids(self): + ''' + get all the ids for the records + :return: list of int + ''' + return self._core_object.get_ids() + + @property + def timestamps(self): + ''' + get all the timestamps for the records + :return: list of int + ''' + return self._core_object.get_timestamps() + + @property + def size(self): + return self._core_object.get_record_size() + + +def scalar(tag, dtype='float'): + ''' + create a scalar component. + + :param tag: str + name of this component. + :param dtype: string + the data type that will be used in underlying storage. + :return: object of core.Tablet + ''' + assert dtype in dtypes, "invalid dtype(%s), should be one of %s" % ( + dtype, str(dtypes)) + tablet = im.add_tablet(tag, -1) + dtype2obj = { + 'float': tablet.as_float_scalar, + 'double': tablet.as_double_scalar, + 'int32': tablet.as_int32_scalar, + 'int64': tablet.as_int64_scalar, + } + obj = dtype2obj[dtype]() + return _Scalar(obj) diff --git a/visualdl/backend/python/test_summary.py b/visualdl/backend/python/test_summary.py new file mode 100644 index 00000000..24fec43f --- /dev/null +++ b/visualdl/backend/python/test_summary.py @@ -0,0 +1,44 @@ +import summary +import numpy as np +import unittest + +summary.set_storage("tmp_dir") + +once_flag = False + + +class ScalarTester(unittest.TestCase): + def setUp(self): + global once_flag + self.scalar = summary.scalar("scalar0") + if not once_flag: + self.py_captions = ["train cost", "test cost"] + self.scalar.set_captions(self.py_captions) + + self.py_records = [] + self.py_ids = [] + for i in range(10): + record = [0.1 * i, 0.2 * i] + id = i * 10 + self.py_records.append(record) + self.py_ids.append(id) + if not once_flag: + self.scalar.add(id, record) + once_flag = True + + def test_records(self): + self.assertEqual(self.scalar.size, len(self.py_records)) + for i, record in enumerate(self.scalar.records): + self.assertTrue(np.isclose(record, self.py_records[i]).all()) + + def test_ids(self): + self.assertEqual(len(self.py_ids), self.scalar.size) + for i, id in enumerate(self.scalar.ids): + self.assertEqual(self.py_ids[i], id) + + def test_captions(self): + self.assertEqual(self.scalar.captions, self.py_captions) + + +if __name__ == '__main__': + unittest.main() -- GitLab