提交 6df1a43e 编写于 作者: J jonyguo

fix: padded dataset with non div & repeat

上级 71f23bdb
......@@ -75,6 +75,9 @@ Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer
RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
} else if (cnt_ == samples_per_buffer_ && (non_empty_ || !even_dist_)) {
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
if (!samples_per_buffer_) {
non_empty_ = false;
}
} else if (!samples_per_buffer_ && !non_empty_) {
// If the buffer is empty, we add samples with subscript 0 in the current dataset.
// This step is to make up for the solution that the code default buffer is not empty before.
......
......@@ -454,6 +454,21 @@ def test_clue_padded_and_skip_with_0_samples():
count += 1
assert count == 2
def test_celeba_padded():
data = ds.CelebADataset("../data/dataset/testCelebAData/")
padded_samples = [{'image': np.zeros(1, np.uint8), 'attr': np.zeros(1, np.uint32)}]
padded_ds = ds.PaddedDataset(padded_samples)
data = data + padded_ds
dis_sampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
data.use_sampler(dis_sampler)
data = data.repeat(2)
count = 0
for _ in data.create_dict_iterator():
count = count + 1
assert count == 2
if __name__ == '__main__':
test_TFRecord_Padded()
test_GeneratorDataSet_Padded()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册