未验证 提交 dd3df693 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix RandomSampler & BatchSampler. test=develop (#26559)

* fix RandomSampler & BatchSampler. test=develop
上级 d6e888ca
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from .sampler import Sampler, SequenceSampler from .sampler import Sampler, SequenceSampler, RandomSampler
from .dataset import Dataset, IterableDataset from .dataset import Dataset, IterableDataset
__all__ = ["BatchSampler"] __all__ = ["BatchSampler"]
...@@ -86,7 +86,6 @@ class BatchSampler(Sampler): ...@@ -86,7 +86,6 @@ class BatchSampler(Sampler):
# init with sampler # init with sampler
sampler = RandomSampler(RandomDataset(100)) sampler = RandomSampler(RandomDataset(100))
bs = BatchSampler(sampler=sampler, bs = BatchSampler(sampler=sampler,
shuffle=True,
batch_size=8, batch_size=8,
drop_last=True) drop_last=True)
...@@ -118,14 +117,16 @@ class BatchSampler(Sampler): ...@@ -118,14 +117,16 @@ class BatchSampler(Sampler):
"dataset should not be a paddle.io.IterableDataset" "dataset should not be a paddle.io.IterableDataset"
assert sampler is None, \ assert sampler is None, \
"should not set both dataset and sampler" "should not set both dataset and sampler"
self.sampler = SequenceSampler(dataset) assert isinstance(shuffle, bool), \
"shuffle should be a boolean value, but got {}".format(type(shuffle))
if shuffle:
self.sampler = RandomSampler(dataset)
else:
self.sampler = SequenceSampler(dataset)
assert isinstance(batch_size, int) and batch_size > 0, \ assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer, but got {}".format(batch_size) "batch_size should be a positive integer, but got {}".format(batch_size)
self.batch_size = batch_size self.batch_size = batch_size
assert isinstance(shuffle, bool), \
"shuffle should be a boolean value, but got {}".format(type(shuffle))
self.shuffle = shuffle
assert isinstance(drop_last, bool), \ assert isinstance(drop_last, bool), \
"drop_last should be a boolean value, but got {}".format(type(drop_last)) "drop_last should be a boolean value, but got {}".format(type(drop_last))
self.drop_last = drop_last self.drop_last = drop_last
......
...@@ -177,7 +177,7 @@ class RandomSampler(Sampler): ...@@ -177,7 +177,7 @@ class RandomSampler(Sampler):
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples
sampler = RandomSampler(data_souce=RandomDataset(100)) sampler = RandomSampler(data_source=RandomDataset(100))
for index in sampler: for index in sampler:
print(index) print(index)
...@@ -216,7 +216,11 @@ class RandomSampler(Sampler): ...@@ -216,7 +216,11 @@ class RandomSampler(Sampler):
def __iter__(self): def __iter__(self):
n = len(self.data_source) n = len(self.data_source)
if self.generator: if self.generator:
for index in self.generator: for i in range(self.num_samples):
try:
index = next(self.generator)
except StopIteration:
return
yield index yield index
else: else:
if self.replacement: if self.replacement:
......
...@@ -88,6 +88,18 @@ class TestRandomSampler(unittest.TestCase): ...@@ -88,6 +88,18 @@ class TestRandomSampler(unittest.TestCase):
rets.append(i) rets.append(i)
assert tuple(sorted(rets)) == tuple(range(0, 60)) assert tuple(sorted(rets)) == tuple(range(0, 60))
def test_with_generator_num_samples(self):
    """Check that ``num_samples`` caps the indices drawn from a user generator.

    The supplied generator can produce 60 indices (0..59), but the sampler
    is configured with ``num_samples=50``; the sampler must report a length
    of 50 and iteration must yield exactly the indices 0..49.
    """
    index_source = iter(range(0, 60))
    sampler = RandomSampler(
        RandomDataset(100, 10),
        generator=index_source,
        num_samples=50,
        replacement=True)
    assert len(sampler) == 50

    # Drain the sampler once; order is not asserted, only the set of indices.
    drawn = [index for index in sampler]
    assert tuple(sorted(drawn)) == tuple(range(0, 50))
class TestBatchSampler(unittest.TestCase): class TestBatchSampler(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册