diff --git a/visualdl/backend/logic/pybind.cc b/visualdl/backend/logic/pybind.cc index 2256542d880e2cc9199a4b9d6c27497491e8fb95..0be6d5cb56e66e1012228ab8c30748b7e73862fa 100644 --- a/visualdl/backend/logic/pybind.cc +++ b/visualdl/backend/logic/pybind.cc @@ -64,6 +64,7 @@ PYBIND11_MODULE(core, m) { .def("get_records", &vs::components::ScalarHelper::GetRecords) \ .def("get_captions", &vs::components::ScalarHelper::GetCaptions) \ .def("get_ids", &vs::components::ScalarHelper::GetIds) \ + .def("get_record_size", &vs::components::ScalarHelper::GetSize) \ .def("get_timestamps", &vs::components::ScalarHelper::GetTimestamps); ADD_SCALAR_TYPED_INTERFACE(int32_t, ScalarInt32); ADD_SCALAR_TYPED_INTERFACE(int64_t, ScalarInt64); diff --git a/visualdl/backend/logic/sdk.cc b/visualdl/backend/logic/sdk.cc index a2c7d638f847efc9af33487a725204a5b2a00df7..c676c5505ab215ca6e601b8307af443520238f2f 100644 --- a/visualdl/backend/logic/sdk.cc +++ b/visualdl/backend/logic/sdk.cc @@ -69,6 +69,7 @@ namespace components { template void ScalarHelper::SetCaptions(const std::vector &captions) { + CHECK_EQ(data_->captions_size(), 0UL) << "the captions can set only once"; for (int i = 0; i < captions.size(); i++) { data_->add_captions(captions[i]); } diff --git a/visualdl/backend/logic/sdk.h b/visualdl/backend/logic/sdk.h index 54a504325a1905f3a58c4e9a96817c64775275ec..d463e3cd5903afdfcaf3d2ccf413d5b8a54cc49b 100644 --- a/visualdl/backend/logic/sdk.h +++ b/visualdl/backend/logic/sdk.h @@ -133,6 +133,8 @@ public: std::vector GetCaptions() const; + size_t GetSize() const { return data_->records_size(); } + private: storage::Tablet *data_; }; diff --git a/visualdl/backend/python/__init__.py b/visualdl/backend/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/visualdl/backend/python/summary.py b/visualdl/backend/python/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..59877aca69df649599e55aa112a4a8cf9d26c107 --- /dev/null +++ b/visualdl/backend/python/summary.py @@ -0,0 +1,106 @@ +__all__ = [ + 'set_storage', + 'scalar', +] +import core + +im = core.im() + +dtypes = ("float", "double", "int32", "int64") + + +def set_storage(dir): + ''' + :param dir: str + directory of summary to write log. + :return: None + ''' + im.storage().set_dir(dir) + + +class _Scalar(object): + ''' + Python syntax wrapper for the core.ScalarHelper object. + ''' + + def __init__(self, core_object): + self._core_object = core_object + + def add(self, id, vs): + ''' + add a scalar record + :param id: int + id in the x-corrdinate + :param vs: list + values + :return: None + ''' + self._core_object.add_record(id, vs) + + def set_captions(self, cs): + ''' + set the captions, one caption for one line. + :param cs: list of str + :return: None + ''' + self._core_object.set_captions(cs) + + @property + def captions(self): + return self._core_object.get_captions() + + @property + def records(self): + ''' + get all the records, format like + [ + [0.1, 0.2], # first record + [0.2, 0.3], # second record + # ... + ] + :return: list of list + ''' + return self._core_object.get_records() + + @property + def ids(self): + ''' + get all the ids for the records + :return: list of int + ''' + return self._core_object.get_ids() + + @property + def timestamps(self): + ''' + get all the timestamps for the records + :return: list of int + ''' + return self._core_object.get_timestamps() + + @property + def size(self): + return self._core_object.get_record_size() + + +def scalar(tag, dtype='float'): + ''' + create a scalar component. + + :param tag: str + name of this component. + :param dtype: string + the data type that will be used in underlying storage. + :return: object of core.Tablet + ''' + assert dtype in dtypes, "invalid dtype(%s), should be one of %s" % ( + dtype, str(dtypes)) + tablet = im.add_tablet(tag, -1) + dtype2obj = { + 'float': tablet.as_float_scalar, + 'double': tablet.as_double_scalar, + 'int32': tablet.as_int32_scalar, + 'int64': tablet.as_int64_scalar, + } + obj = dtype2obj[dtype]() + return _Scalar(obj) diff --git a/visualdl/backend/python/test_summary.py b/visualdl/backend/python/test_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..24fec43f5690da88c4d0d8fd2bad697794e4b032 --- /dev/null +++ b/visualdl/backend/python/test_summary.py @@ -0,0 +1,44 @@ +import summary +import numpy as np +import unittest + +summary.set_storage("tmp_dir") + +once_flag = False + + +class ScalarTester(unittest.TestCase): + def setUp(self): + global once_flag + self.scalar = summary.scalar("scalar0") + if not once_flag: + self.py_captions = ["train cost", "test cost"] + self.scalar.set_captions(self.py_captions) + + self.py_records = [] + self.py_ids = [] + for i in range(10): + record = [0.1 * i, 0.2 * i] + id = i * 10 + self.py_records.append(record) + self.py_ids.append(id) + if not once_flag: + self.scalar.add(id, record) + once_flag = True + + def test_records(self): + self.assertEqual(self.scalar.size, len(self.py_records)) + for i, record in enumerate(self.scalar.records): + self.assertTrue(np.isclose(record, self.py_records[i]).all()) + + def test_ids(self): + self.assertEqual(len(self.py_ids), self.scalar.size) + for i, id in enumerate(self.scalar.ids): + self.assertEqual(self.py_ids[i], id) + + def test_captions(self): + self.assertEqual(self.scalar.captions, self.py_captions) + + +if __name__ == '__main__': + unittest.main()