diff --git a/visualdl/logic/sdk.cc b/visualdl/logic/sdk.cc index 33f00d42330c23bebab39ae6c159e9f24c5cd57c..94dce5c2bad3e200d8d08bb7294d6614a9b11f20 100644 --- a/visualdl/logic/sdk.cc +++ b/visualdl/logic/sdk.cc @@ -109,7 +109,7 @@ namespace components { template std::vector ScalarReader::records() const { std::vector res; - for (int i = 0; i < reader_.total_records(); i++) { + for (int i = 0; i < total_records(); i++) { res.push_back(reader_.record(i).data(0).template Get()); } return res; @@ -302,7 +302,7 @@ void Histogram::AddRecord(int step, const std::vector& data) { template HistogramRecord HistogramReader::record(int i) { - CHECK_LT(i, reader_.total_records()); + CHECK_LT(i, num_records()); auto r = reader_.record(i); auto d = r.data(0); auto boundaries_str = d.GetRaw(); diff --git a/visualdl/logic/sdk.h b/visualdl/logic/sdk.h index 2fcc7f4240835cb2b8464e44e9694bc842db2580..810dc5c9f6314da9c05bfd6fb557e0d24b12c82c 100644 --- a/visualdl/logic/sdk.h +++ b/visualdl/logic/sdk.h @@ -126,7 +126,7 @@ struct ScalarReader { std::vector ids() const; std::vector timestamps() const; std::string caption() const; - size_t total_records() { return reader_.total_records(); } + size_t total_records() const { return reader_.total_records() - 1; } size_t size() const; private: @@ -276,7 +276,7 @@ template struct HistogramReader { HistogramReader(TabletReader tablet) : reader_(tablet) {} - size_t num_records() { return reader_.total_records(); } + size_t num_records() { return reader_.total_records() - 1; } HistogramRecord record(int i); diff --git a/visualdl/logic/sdk_test.cc b/visualdl/logic/sdk_test.cc index dfda35b655c7e868144b33bdb27ea2a4008925f0..0a42ff23cf5ebd3f4d04e74a629f3311b112f64b 100644 --- a/visualdl/logic/sdk_test.cc +++ b/visualdl/logic/sdk_test.cc @@ -28,6 +28,7 @@ TEST(Scalar, write) { auto tablet = writer.AddTablet("scalar0"); components::Scalar scalar(tablet); scalar.AddRecord(0, 12); + scalar.AddRecord(1, 13); auto tablet1 = writer.AddTablet("model/layer/min"); components::Scalar scalar1(tablet1); scalar1.SetCaption("customized caption"); @@ -39,9 +40,11 @@ TEST(Scalar, write) { auto scalar_reader = components::ScalarReader(std::move(tablet_reader)); auto captioin = scalar_reader.caption(); ASSERT_EQ(captioin, "train"); - ASSERT_EQ(scalar_reader.total_records(), 1); + // reference PR#225 + ASSERT_EQ(scalar_reader.total_records(), 2 - 1); auto record = scalar_reader.records(); - ASSERT_EQ(record.size(), 1); + // reference PR#225 + ASSERT_EQ(record.size(), 2 - 1); // check the first entry of first record ASSERT_EQ(record.front(), 12); diff --git a/visualdl/python/test_storage.py b/visualdl/python/test_storage.py index e2083fc3289bf33ea673ccd6336f12cf68740576..6c84c301fbdb62a59376253b97005762d57ba58e 100644 --- a/visualdl/python/test_storage.py +++ b/visualdl/python/test_storage.py @@ -29,7 +29,7 @@ class StorageTest(unittest.TestCase): records = scalar.records() ids = scalar.ids() self.assertTrue( - np.equal(records, [float(i) for i in range(10)]).all()) + np.equal(records, [float(i) for i in range(10 - 1)]).all()) self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all()) print 'records', records print 'ids', ids @@ -70,7 +70,7 @@ class StorageTest(unittest.TestCase): ''' print 'check image' tag = "layer1/check/image1" - image_writer = self.writer.image(tag, 10, 1) + image_writer = self.writer.image(tag, 10) image = Image.open("./dog.jpg") shape = [image.size[1], image.size[0], 3] @@ -81,6 +81,7 @@ class StorageTest(unittest.TestCase): image_writer.start_sampling() image_writer.add_sample(shape, list(origin_data)) + image_writer.finish_sampling() # read and check whether the original image will be displayed image_reader = reader.image(tag) diff --git a/visualdl/storage/tablet.h b/visualdl/storage/tablet.h index e71cc9fef859d04dc4a4e62ebe941aa3a4dbb15d..03bb1fe678769752c2511ffa479ee00ed832fa12 100644 --- a/visualdl/storage/tablet.h +++ b/visualdl/storage/tablet.h @@ -115,7 +115,10 @@ struct TabletReader { Tablet::Type type() const { return Tablet::Type(data_.component()); } int64_t total_records() const { return data_.records_size(); } int32_t num_samples() const { return data_.num_samples(); } - RecordReader record(int i) const { return RecordReader(data_.records(i)); } + RecordReader record(int i) const { + CHECK_LT(i, total_records()); + return RecordReader(data_.records(i)); + } template EntryReader meta() const { return EntryReader(data_.meta());