diff --git a/python/paddle/fluid/dataloader/batch_sampler.py b/python/paddle/fluid/dataloader/batch_sampler.py index 8043237c0d97d4be09e675464c01a6a9cc70392c..1d180329b72510de5e7e9362e4c002f4508ba1be 100644 --- a/python/paddle/fluid/dataloader/batch_sampler.py +++ b/python/paddle/fluid/dataloader/batch_sampler.py @@ -16,7 +16,7 @@ from __future__ import print_function from __future__ import division import numpy as np -from .sampler import Sampler, SequenceSampler +from .sampler import Sampler, SequenceSampler, RandomSampler from .dataset import Dataset, IterableDataset __all__ = ["BatchSampler"] @@ -86,7 +86,6 @@ class BatchSampler(Sampler): # init with sampler sampler = RandomSampler(RandomDataset(100)) bs = BatchSampler(sampler=sampler, - shuffle=True, batch_size=8, drop_last=True) @@ -118,14 +117,16 @@ class BatchSampler(Sampler): "dataset should not be a paddle.io.IterableDataset" assert sampler is None, \ "should not set both dataset and sampler" - self.sampler = SequenceSampler(dataset) + assert isinstance(shuffle, bool), \ + "shuffle should be a boolean value, but got {}".format(type(shuffle)) + if shuffle: + self.sampler = RandomSampler(dataset) + else: + self.sampler = SequenceSampler(dataset) assert isinstance(batch_size, int) and batch_size > 0, \ "batch_size should be a positive integer, but got {}".format(batch_size) self.batch_size = batch_size - assert isinstance(shuffle, bool), \ - "shuffle should be a boolean value, but got {}".format(type(shuffle)) - self.shuffle = shuffle assert isinstance(drop_last, bool), \ "drop_last should be a boolean value, but got {}".format(type(drop_last)) self.drop_last = drop_last diff --git a/python/paddle/fluid/dataloader/sampler.py b/python/paddle/fluid/dataloader/sampler.py index d2f3231cc6b12f1190ad9453ea48213edc9122b8..5c75fafe8b22380090ba6fb580777cdbe6570ad6 100644 --- a/python/paddle/fluid/dataloader/sampler.py +++ b/python/paddle/fluid/dataloader/sampler.py @@ -177,7 +177,7 @@ class RandomSampler(Sampler): def __len__(self): return self.num_samples - sampler = RandomSampler(data_souce=RandomDataset(100)) + sampler = RandomSampler(data_source=RandomDataset(100)) for index in sampler: print(index) @@ -216,7 +216,11 @@ class RandomSampler(Sampler): def __iter__(self): n = len(self.data_source) if self.generator: - for index in self.generator: + for i in range(self.num_samples): + try: + index = next(self.generator) + except StopIteration: + return yield index else: if self.replacement: diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py index 2e2a6144fd0119ef920031a623f7eedd8ae5dbe7..6ec6fdb59f200ce1dc9b6418b7f11329f85ba5dd 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -88,6 +88,18 @@ class TestRandomSampler(unittest.TestCase): rets.append(i) assert tuple(sorted(rets)) == tuple(range(0, 60)) + def test_with_generator_num_samples(self): + dataset = RandomDataset(100, 10) + generator = iter(range(0, 60)) + sampler = RandomSampler( + dataset, generator=generator, num_samples=50, replacement=True) + assert len(sampler) == 50 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert tuple(sorted(rets)) == tuple(range(0, 50)) + class TestBatchSampler(unittest.TestCase): def setUp(self):