From 376de50a1d049e577a8628a30b20b83a79ae3ad0 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 30 Jan 2018 21:11:54 +0800 Subject: [PATCH] fix protobuf index < current_size_ bug (#225) --- visualdl/logic/sdk.cc | 4 ++-- visualdl/logic/sdk.h | 4 ++-- visualdl/logic/sdk_test.cc | 7 +++++-- visualdl/python/test_storage.py | 5 +++-- visualdl/storage/tablet.h | 5 ++++- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/visualdl/logic/sdk.cc b/visualdl/logic/sdk.cc index 33f00d42..94dce5c2 100644 --- a/visualdl/logic/sdk.cc +++ b/visualdl/logic/sdk.cc @@ -109,7 +109,7 @@ namespace components { template std::vector ScalarReader::records() const { std::vector res; - for (int i = 0; i < reader_.total_records(); i++) { + for (int i = 0; i < total_records(); i++) { res.push_back(reader_.record(i).data(0).template Get()); } return res; @@ -302,7 +302,7 @@ void Histogram::AddRecord(int step, const std::vector& data) { template HistogramRecord HistogramReader::record(int i) { - CHECK_LT(i, reader_.total_records()); + CHECK_LT(i, num_records()); auto r = reader_.record(i); auto d = r.data(0); auto boundaries_str = d.GetRaw(); diff --git a/visualdl/logic/sdk.h b/visualdl/logic/sdk.h index 2fcc7f42..810dc5c9 100644 --- a/visualdl/logic/sdk.h +++ b/visualdl/logic/sdk.h @@ -126,7 +126,7 @@ struct ScalarReader { std::vector ids() const; std::vector timestamps() const; std::string caption() const; - size_t total_records() { return reader_.total_records(); } + size_t total_records() const { return reader_.total_records() - 1; } size_t size() const; private: @@ -276,7 +276,7 @@ template struct HistogramReader { HistogramReader(TabletReader tablet) : reader_(tablet) {} - size_t num_records() { return reader_.total_records(); } + size_t num_records() { return reader_.total_records() - 1; } HistogramRecord record(int i); diff --git a/visualdl/logic/sdk_test.cc b/visualdl/logic/sdk_test.cc index dfda35b6..0a42ff23 100644 --- a/visualdl/logic/sdk_test.cc +++ b/visualdl/logic/sdk_test.cc @@ -28,6 +28,7 @@ TEST(Scalar, write) { auto tablet = writer.AddTablet("scalar0"); components::Scalar scalar(tablet); scalar.AddRecord(0, 12); + scalar.AddRecord(1, 13); auto tablet1 = writer.AddTablet("model/layer/min"); components::Scalar scalar1(tablet1); scalar1.SetCaption("customized caption"); @@ -39,9 +40,11 @@ TEST(Scalar, write) { auto scalar_reader = components::ScalarReader(std::move(tablet_reader)); auto captioin = scalar_reader.caption(); ASSERT_EQ(captioin, "train"); - ASSERT_EQ(scalar_reader.total_records(), 1); + // reference PR#225 + ASSERT_EQ(scalar_reader.total_records(), 2 - 1); auto record = scalar_reader.records(); - ASSERT_EQ(record.size(), 1); + // reference PR#225 + ASSERT_EQ(record.size(), 2 - 1); // check the first entry of first record ASSERT_EQ(record.front(), 12); diff --git a/visualdl/python/test_storage.py b/visualdl/python/test_storage.py index e2083fc3..6c84c301 100644 --- a/visualdl/python/test_storage.py +++ b/visualdl/python/test_storage.py @@ -29,7 +29,7 @@ class StorageTest(unittest.TestCase): records = scalar.records() ids = scalar.ids() self.assertTrue( - np.equal(records, [float(i) for i in range(10)]).all()) + np.equal(records, [float(i) for i in range(10 - 1)]).all()) self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all()) print 'records', records print 'ids', ids @@ -70,7 +70,7 @@ class StorageTest(unittest.TestCase): ''' print 'check image' tag = "layer1/check/image1" - image_writer = self.writer.image(tag, 10, 1) + image_writer = self.writer.image(tag, 10) image = Image.open("./dog.jpg") shape = [image.size[1], image.size[0], 3] @@ -81,6 +81,7 @@ class StorageTest(unittest.TestCase): image_writer.start_sampling() image_writer.add_sample(shape, list(origin_data)) + image_writer.finish_sampling() # read and check whether the original image will be displayed image_reader = reader.image(tag) diff --git a/visualdl/storage/tablet.h b/visualdl/storage/tablet.h index e71cc9fe..03bb1fe6 100644 --- a/visualdl/storage/tablet.h +++ b/visualdl/storage/tablet.h @@ -115,7 +115,10 @@ struct TabletReader { Tablet::Type type() const { return Tablet::Type(data_.component()); } int64_t total_records() const { return data_.records_size(); } int32_t num_samples() const { return data_.num_samples(); } - RecordReader record(int i) const { return RecordReader(data_.records(i)); } + RecordReader record(int i) const { + CHECK_LT(i, total_records()); + return RecordReader(data_.records(i)); + } template EntryReader meta() const { return EntryReader(data_.meta()); -- GitLab