未验证 提交 376de50a 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix protobuf index < current_size_ bug (#225)

上级 56b66549
...@@ -109,7 +109,7 @@ namespace components { ...@@ -109,7 +109,7 @@ namespace components {
template <typename T> template <typename T>
std::vector<T> ScalarReader<T>::records() const { std::vector<T> ScalarReader<T>::records() const {
std::vector<T> res; std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) { for (int i = 0; i < total_records(); i++) {
res.push_back(reader_.record(i).data(0).template Get<T>()); res.push_back(reader_.record(i).data(0).template Get<T>());
} }
return res; return res;
...@@ -302,7 +302,7 @@ void Histogram<T>::AddRecord(int step, const std::vector<T>& data) { ...@@ -302,7 +302,7 @@ void Histogram<T>::AddRecord(int step, const std::vector<T>& data) {
template <typename T> template <typename T>
HistogramRecord<T> HistogramReader<T>::record(int i) { HistogramRecord<T> HistogramReader<T>::record(int i) {
CHECK_LT(i, reader_.total_records()); CHECK_LT(i, num_records());
auto r = reader_.record(i); auto r = reader_.record(i);
auto d = r.data(0); auto d = r.data(0);
auto boundaries_str = d.GetRaw(); auto boundaries_str = d.GetRaw();
......
...@@ -126,7 +126,7 @@ struct ScalarReader { ...@@ -126,7 +126,7 @@ struct ScalarReader {
std::vector<T> ids() const; std::vector<T> ids() const;
std::vector<T> timestamps() const; std::vector<T> timestamps() const;
std::string caption() const; std::string caption() const;
size_t total_records() { return reader_.total_records(); } size_t total_records() const { return reader_.total_records() - 1; }
size_t size() const; size_t size() const;
private: private:
...@@ -276,7 +276,7 @@ template <typename T> ...@@ -276,7 +276,7 @@ template <typename T>
struct HistogramReader { struct HistogramReader {
HistogramReader(TabletReader tablet) : reader_(tablet) {} HistogramReader(TabletReader tablet) : reader_(tablet) {}
size_t num_records() { return reader_.total_records(); } size_t num_records() { return reader_.total_records() - 1; }
HistogramRecord<T> record(int i); HistogramRecord<T> record(int i);
......
...@@ -28,6 +28,7 @@ TEST(Scalar, write) { ...@@ -28,6 +28,7 @@ TEST(Scalar, write) {
auto tablet = writer.AddTablet("scalar0"); auto tablet = writer.AddTablet("scalar0");
components::Scalar<int> scalar(tablet); components::Scalar<int> scalar(tablet);
scalar.AddRecord(0, 12); scalar.AddRecord(0, 12);
scalar.AddRecord(1, 13);
auto tablet1 = writer.AddTablet("model/layer/min"); auto tablet1 = writer.AddTablet("model/layer/min");
components::Scalar<float> scalar1(tablet1); components::Scalar<float> scalar1(tablet1);
scalar1.SetCaption("customized caption"); scalar1.SetCaption("customized caption");
...@@ -39,9 +40,11 @@ TEST(Scalar, write) { ...@@ -39,9 +40,11 @@ TEST(Scalar, write) {
auto scalar_reader = components::ScalarReader<int>(std::move(tablet_reader)); auto scalar_reader = components::ScalarReader<int>(std::move(tablet_reader));
auto captioin = scalar_reader.caption(); auto captioin = scalar_reader.caption();
ASSERT_EQ(captioin, "train"); ASSERT_EQ(captioin, "train");
ASSERT_EQ(scalar_reader.total_records(), 1); // reference PR#225
ASSERT_EQ(scalar_reader.total_records(), 2 - 1);
auto record = scalar_reader.records(); auto record = scalar_reader.records();
ASSERT_EQ(record.size(), 1); // reference PR#225
ASSERT_EQ(record.size(), 2 - 1);
// check the first entry of first record // check the first entry of first record
ASSERT_EQ(record.front(), 12); ASSERT_EQ(record.front(), 12);
......
...@@ -29,7 +29,7 @@ class StorageTest(unittest.TestCase): ...@@ -29,7 +29,7 @@ class StorageTest(unittest.TestCase):
records = scalar.records() records = scalar.records()
ids = scalar.ids() ids = scalar.ids()
self.assertTrue( self.assertTrue(
np.equal(records, [float(i) for i in range(10)]).all()) np.equal(records, [float(i) for i in range(10 - 1)]).all())
self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all()) self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all())
print 'records', records print 'records', records
print 'ids', ids print 'ids', ids
...@@ -70,7 +70,7 @@ class StorageTest(unittest.TestCase): ...@@ -70,7 +70,7 @@ class StorageTest(unittest.TestCase):
''' '''
print 'check image' print 'check image'
tag = "layer1/check/image1" tag = "layer1/check/image1"
image_writer = self.writer.image(tag, 10, 1) image_writer = self.writer.image(tag, 10)
image = Image.open("./dog.jpg") image = Image.open("./dog.jpg")
shape = [image.size[1], image.size[0], 3] shape = [image.size[1], image.size[0], 3]
...@@ -81,6 +81,7 @@ class StorageTest(unittest.TestCase): ...@@ -81,6 +81,7 @@ class StorageTest(unittest.TestCase):
image_writer.start_sampling() image_writer.start_sampling()
image_writer.add_sample(shape, list(origin_data)) image_writer.add_sample(shape, list(origin_data))
image_writer.finish_sampling()
# read and check whether the original image will be displayed # read and check whether the original image will be displayed
image_reader = reader.image(tag) image_reader = reader.image(tag)
......
...@@ -115,7 +115,10 @@ struct TabletReader { ...@@ -115,7 +115,10 @@ struct TabletReader {
Tablet::Type type() const { return Tablet::Type(data_.component()); } Tablet::Type type() const { return Tablet::Type(data_.component()); }
int64_t total_records() const { return data_.records_size(); } int64_t total_records() const { return data_.records_size(); }
int32_t num_samples() const { return data_.num_samples(); } int32_t num_samples() const { return data_.num_samples(); }
RecordReader record(int i) const { return RecordReader(data_.records(i)); } RecordReader record(int i) const {
CHECK_LT(i, total_records());
return RecordReader(data_.records(i));
}
template <typename T> template <typename T>
EntryReader meta() const { EntryReader meta() const {
return EntryReader(data_.meta()); return EntryReader(data_.meta());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册