“6871fe635be3465a06341dd22d5ccc1382d7de8d”上不存在“...train/git@gitcode.net:s920243400/PaddleDetection.git”
未验证 提交 376de50a 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix protobuf index < current_size_ bug (#225)

上级 56b66549
......@@ -109,7 +109,7 @@ namespace components {
template <typename T>
std::vector<T> ScalarReader<T>::records() const {
std::vector<T> res;
for (int i = 0; i < reader_.total_records(); i++) {
for (int i = 0; i < total_records(); i++) {
res.push_back(reader_.record(i).data(0).template Get<T>());
}
return res;
......@@ -302,7 +302,7 @@ void Histogram<T>::AddRecord(int step, const std::vector<T>& data) {
template <typename T>
HistogramRecord<T> HistogramReader<T>::record(int i) {
CHECK_LT(i, reader_.total_records());
CHECK_LT(i, num_records());
auto r = reader_.record(i);
auto d = r.data(0);
auto boundaries_str = d.GetRaw();
......
......@@ -126,7 +126,7 @@ struct ScalarReader {
std::vector<T> ids() const;
std::vector<T> timestamps() const;
std::string caption() const;
size_t total_records() { return reader_.total_records(); }
size_t total_records() const { return reader_.total_records() - 1; }
size_t size() const;
private:
......@@ -276,7 +276,7 @@ template <typename T>
struct HistogramReader {
HistogramReader(TabletReader tablet) : reader_(tablet) {}
size_t num_records() { return reader_.total_records(); }
size_t num_records() { return reader_.total_records() - 1; }
HistogramRecord<T> record(int i);
......
......@@ -28,6 +28,7 @@ TEST(Scalar, write) {
auto tablet = writer.AddTablet("scalar0");
components::Scalar<int> scalar(tablet);
scalar.AddRecord(0, 12);
scalar.AddRecord(1, 13);
auto tablet1 = writer.AddTablet("model/layer/min");
components::Scalar<float> scalar1(tablet1);
scalar1.SetCaption("customized caption");
......@@ -39,9 +40,11 @@ TEST(Scalar, write) {
auto scalar_reader = components::ScalarReader<int>(std::move(tablet_reader));
auto captioin = scalar_reader.caption();
ASSERT_EQ(captioin, "train");
ASSERT_EQ(scalar_reader.total_records(), 1);
// reference PR#225
ASSERT_EQ(scalar_reader.total_records(), 2 - 1);
auto record = scalar_reader.records();
ASSERT_EQ(record.size(), 1);
// reference PR#225
ASSERT_EQ(record.size(), 2 - 1);
// check the first entry of first record
ASSERT_EQ(record.front(), 12);
......
......@@ -29,7 +29,7 @@ class StorageTest(unittest.TestCase):
records = scalar.records()
ids = scalar.ids()
self.assertTrue(
np.equal(records, [float(i) for i in range(10)]).all())
np.equal(records, [float(i) for i in range(10 - 1)]).all())
self.assertTrue(np.equal(ids, [float(i) for i in range(10)]).all())
print 'records', records
print 'ids', ids
......@@ -70,7 +70,7 @@ class StorageTest(unittest.TestCase):
'''
print 'check image'
tag = "layer1/check/image1"
image_writer = self.writer.image(tag, 10, 1)
image_writer = self.writer.image(tag, 10)
image = Image.open("./dog.jpg")
shape = [image.size[1], image.size[0], 3]
......@@ -81,6 +81,7 @@ class StorageTest(unittest.TestCase):
image_writer.start_sampling()
image_writer.add_sample(shape, list(origin_data))
image_writer.finish_sampling()
# read and check whether the original image will be displayed
image_reader = reader.image(tag)
......
......@@ -115,7 +115,10 @@ struct TabletReader {
Tablet::Type type() const { return Tablet::Type(data_.component()); }
int64_t total_records() const { return data_.records_size(); }
int32_t num_samples() const { return data_.num_samples(); }
RecordReader record(int i) const { return RecordReader(data_.records(i)); }
RecordReader record(int i) const {
CHECK_LT(i, total_records());
return RecordReader(data_.records(i));
}
template <typename T>
EntryReader meta() const {
return EntryReader(data_.meta());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册