未验证 提交 ece1e4cd 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add weighted random sampler (#28545)

* add WeightedRandomSampler. test=develop
上级 2cb71c0c
...@@ -16,8 +16,11 @@ from __future__ import print_function ...@@ -16,8 +16,11 @@ from __future__ import print_function
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from .. import core
__all__ = ["Sampler", "SequenceSampler", "RandomSampler"] __all__ = [
"Sampler", "SequenceSampler", "RandomSampler", "WeightedRandomSampler"
]
class Sampler(object): class Sampler(object):
...@@ -234,3 +237,85 @@ class RandomSampler(Sampler): ...@@ -234,3 +237,85 @@ class RandomSampler(Sampler):
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples
def _weighted_sample(weights, num_samples, replacement=True):
if isinstance(weights, core.LoDTensor):
weights = weights.numpy()
if isinstance(weights, (list, tuple)):
weights = np.array(weights)
assert isinstance(weights, np.ndarray), \
"weights should be paddle.Tensor, numpy.ndarray, list or tuple"
assert len(weights.shape) <= 2, \
"weights should be a 1-D or 2-D array"
weights = weights.reshape((-1, weights.shape[-1]))
assert np.all(weights >= 0.), \
"weights should be positive value"
assert not np.any(weights == np.inf), \
"weights shoule not be INF"
assert not np.any(weights == np.nan), \
"weights shoule not be NaN"
non_zeros = np.sum(weights > 0., axis=1)
assert np.all(non_zeros > 0), \
"weights should have positive values"
if not replacement:
assert np.all(non_zeros >= num_samples), \
"weights positive value number should not " \
"less than num_samples when replacement=False"
weights = weights / weights.sum(axis=1)
rets = []
for i in range(weights.shape[0]):
ret = np.random.choice(weights.shape[1], num_samples, replacement,
weights[i])
rets.append(ret)
return np.array(rets)
class WeightedRandomSampler(Sampler):
"""
Random sample with given weights (probabilities), sampe index will be in range
[0, len(weights) - 1], if :attr:`replacement` is True, index can be sampled
multiple times.
Args:
weights(numpy.ndarray|paddle.Tensor|list|tuple): sequence of weights,
should be numpy array, paddle.Tensor, list or tuple
num_samples(int): set sample number to draw from sampler.
replacement(bool): Whether to draw sample with replacements, default True
Returns:
Sampler: a Sampler yield sample index randomly by given weights
Examples:
.. code-block:: python
from paddle.io import WeightedRandomSampler
sampler = WeightedRandomSampler(weights=[0.1, 0.3, 0.5, 0.7, 0.2],
num_samples=5,
replacement=True)
for index in sampler:
print(index)
"""
def __init__(self, weights, num_samples, replacement=True):
if not isinstance(num_samples, int) or num_samples <= 0:
raise ValueError("num_samples should be a positive integer")
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value")
self.weights = weights
self.num_samples = num_samples
self.replacement = replacement
def __iter__(self):
idxs = _weighted_sample(self.weights, self.num_samples,
self.replacement)
return iter(idxs.reshape((-1)).tolist())
def __len__(self):
mul = np.prod(self.weights.shape) // self.weights.shape[-1]
return self.num_samples * mul
...@@ -16,8 +16,10 @@ from __future__ import division ...@@ -16,8 +16,10 @@ from __future__ import division
import unittest import unittest
import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, RandomSampler from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, \
RandomSampler, WeightedRandomSampler
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
...@@ -195,14 +197,86 @@ class TestBatchSamplerWithSamplerShuffle(unittest.TestCase): ...@@ -195,14 +197,86 @@ class TestBatchSamplerWithSamplerShuffle(unittest.TestCase):
pass pass
class TestDistributedBatchSamplerWithSampler(TestBatchSampler): class TestWeightedRandomSampler(unittest.TestCase):
def init_batch_sampler(self): def init_probs(self, total, pos):
dataset = RandomDataset(1000, 10) pos_probs = np.random.random((pos, )).astype('float32')
bs = DistributedBatchSampler( probs = np.zeros((total, )).astype('float32')
dataset=dataset, probs[:pos] = pos_probs
batch_size=self.batch_size, np.random.shuffle(probs)
drop_last=self.drop_last) return probs
return bs
def test_replacement(self):
probs = self.init_probs(20, 10)
sampler = WeightedRandomSampler(probs, 30, True)
assert len(sampler) == 30
for idx in iter(sampler):
assert probs[idx] > 0.
def test_no_replacement(self):
probs = self.init_probs(20, 10)
sampler = WeightedRandomSampler(probs, 10, False)
assert len(sampler) == 10
idxs = []
for idx in iter(sampler):
assert probs[idx] > 0.
idxs.append(idx)
assert len(set(idxs)) == len(idxs)
def test_assert(self):
# all zeros
probs = np.zeros((10, )).astype('float32')
sampler = WeightedRandomSampler(probs, 10, True)
try:
for idx in iter(sampler):
pass
self.assertTrue(False)
except AssertionError:
self.assertTrue(True)
# not enough pos
probs = self.init_probs(10, 5)
sampler = WeightedRandomSampler(probs, 10, False)
try:
for idx in iter(sampler):
pass
self.assertTrue(False)
except AssertionError:
self.assertTrue(True)
# neg probs
probs = -1.0 * np.ones((10, )).astype('float32')
sampler = WeightedRandomSampler(probs, 10, True)
try:
for idx in iter(sampler):
pass
self.assertTrue(False)
except AssertionError:
self.assertTrue(True)
def test_raise(self):
# float num_samples
probs = self.init_probs(10, 5)
try:
sampler = WeightedRandomSampler(probs, 2.3, True)
self.assertTrue(False)
except ValueError:
self.assertTrue(True)
# neg num_samples
probs = self.init_probs(10, 5)
try:
sampler = WeightedRandomSampler(probs, -1, True)
self.assertTrue(False)
except ValueError:
self.assertTrue(True)
# no-bool replacement
probs = self.init_probs(10, 5)
try:
sampler = WeightedRandomSampler(probs, 5, 5)
self.assertTrue(False)
except ValueError:
self.assertTrue(True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -27,9 +27,10 @@ __all__ = [ ...@@ -27,9 +27,10 @@ __all__ = [
'Sampler', 'Sampler',
'SequenceSampler', 'SequenceSampler',
'RandomSampler', 'RandomSampler',
'WeightedRandomSampler',
] ]
from ..fluid.io import DataLoader from ..fluid.io import DataLoader
from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \ from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \
TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler, \ TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler, \
ComposeDataset, ChainDataset ComposeDataset, ChainDataset, WeightedRandomSampler
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册