未验证 提交 ece1e4cd 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add weighted random sampler (#28545)

* add WeightedRandomSampler. test=develop
上级 2cb71c0c
......@@ -16,8 +16,11 @@ from __future__ import print_function
from __future__ import division
import numpy as np
from .. import core
# Public names exported by this module. This first assignment is immediately
# overridden by the assignment that follows it.
__all__ = ["Sampler", "SequenceSampler", "RandomSampler"]
# Effective export list: the same names plus WeightedRandomSampler.
__all__ = [
    "Sampler", "SequenceSampler", "RandomSampler", "WeightedRandomSampler"
]
class Sampler(object):
......@@ -234,3 +237,85 @@ class RandomSampler(Sampler):
def __len__(self):
    """Return the value stored in ``self.num_samples``."""
    return self.num_samples
def _weighted_sample(weights, num_samples, replacement=True):
if isinstance(weights, core.LoDTensor):
weights = weights.numpy()
if isinstance(weights, (list, tuple)):
weights = np.array(weights)
assert isinstance(weights, np.ndarray), \
"weights should be paddle.Tensor, numpy.ndarray, list or tuple"
assert len(weights.shape) <= 2, \
"weights should be a 1-D or 2-D array"
weights = weights.reshape((-1, weights.shape[-1]))
assert np.all(weights >= 0.), \
"weights should be positive value"
assert not np.any(weights == np.inf), \
"weights shoule not be INF"
assert not np.any(weights == np.nan), \
"weights shoule not be NaN"
non_zeros = np.sum(weights > 0., axis=1)
assert np.all(non_zeros > 0), \
"weights should have positive values"
if not replacement:
assert np.all(non_zeros >= num_samples), \
"weights positive value number should not " \
"less than num_samples when replacement=False"
weights = weights / weights.sum(axis=1)
rets = []
for i in range(weights.shape[0]):
ret = np.random.choice(weights.shape[1], num_samples, replacement,
weights[i])
rets.append(ret)
return np.array(rets)
class WeightedRandomSampler(Sampler):
    """
    Random sample with given weights (probabilities), sample index will be in
    range [0, len(weights) - 1], if :attr:`replacement` is True, index can be
    sampled multiple times.

    Args:
        weights(numpy.ndarray|paddle.Tensor|list|tuple): sequence of weights,
            should be numpy array, paddle.Tensor, list or tuple
        num_samples(int): set sample number to draw from sampler.
        replacement(bool): Whether to draw sample with replacements,
            default True

    Returns:
        Sampler: a Sampler yield sample index randomly by given weights

    Raises:
        ValueError: if :attr:`num_samples` is not a positive integer or
            :attr:`replacement` is not a boolean value.

    Examples:

        .. code-block:: python

            from paddle.io import WeightedRandomSampler

            sampler = WeightedRandomSampler(weights=[0.1, 0.3, 0.5, 0.7, 0.2],
                                            num_samples=5,
                                            replacement=True)

            for index in sampler:
                print(index)
    """

    def __init__(self, weights, num_samples, replacement=True):
        if not isinstance(num_samples, int) or num_samples <= 0:
            raise ValueError("num_samples should be a positive integer")
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value")
        # weights are stored as given; they are validated and converted
        # each time the sampler is iterated.
        self.weights = weights
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        """Draw a fresh set of indices and iterate over them flattened."""
        idxs = _weighted_sample(self.weights, self.num_samples,
                                self.replacement)
        return iter(idxs.reshape((-1)).tolist())

    def __len__(self):
        """Return ``num_samples`` times the number of weight rows."""
        # np.shape also handles list/tuple weights, which have no ``.shape``
        # attribute of their own.
        shape = np.shape(self.weights)
        # Product of every dimension except the last one; an empty product
        # is 1, so 1-D weights count as a single row.
        rows = int(np.prod(shape[:-1]))
        return self.num_samples * rows
......@@ -16,8 +16,10 @@ from __future__ import division
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, RandomSampler
from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, \
RandomSampler, WeightedRandomSampler
from paddle.io import DistributedBatchSampler
......@@ -195,14 +197,86 @@ class TestBatchSamplerWithSamplerShuffle(unittest.TestCase):
pass
class TestDistributedBatchSamplerWithSampler(TestBatchSampler):
    """TestBatchSampler variant whose sampler is a DistributedBatchSampler."""

    def init_batch_sampler(self):
        """Return a DistributedBatchSampler over ``RandomDataset(1000, 10)``.

        Batch size and drop-last behavior are taken from ``self.batch_size``
        and ``self.drop_last``.
        """
        return DistributedBatchSampler(
            dataset=RandomDataset(1000, 10),
            batch_size=self.batch_size,
            drop_last=self.drop_last)
class TestWeightedRandomSampler(unittest.TestCase):
    """Unit tests for WeightedRandomSampler sampling and argument checks."""

    def init_probs(self, total, pos):
        """Return a shuffled float32 vector of ``total`` weights.

        ``pos`` entries are drawn from ``np.random.random``; the remaining
        entries are zero.
        """
        pos_probs = np.random.random((pos, )).astype('float32')
        probs = np.zeros((total, )).astype('float32')
        probs[:pos] = pos_probs
        np.random.shuffle(probs)
        return probs

    def _assert_iter_fails(self, sampler):
        """Check that iterating ``sampler`` raises AssertionError.

        assertRaises is used instead of try / assertTrue(False) / except:
        the failure raised by assertTrue(False) is itself an AssertionError
        and would be swallowed by the except clause, so the check could
        never fail.
        """
        with self.assertRaises(AssertionError):
            for _ in iter(sampler):
                pass

    def test_replacement(self):
        probs = self.init_probs(20, 10)
        sampler = WeightedRandomSampler(probs, 30, True)
        self.assertEqual(len(sampler), 30)
        for idx in iter(sampler):
            self.assertGreater(probs[idx], 0.)

    def test_no_replacement(self):
        probs = self.init_probs(20, 10)
        sampler = WeightedRandomSampler(probs, 10, False)
        self.assertEqual(len(sampler), 10)
        idxs = []
        for idx in iter(sampler):
            self.assertGreater(probs[idx], 0.)
            idxs.append(idx)
        # Without replacement every index may appear at most once.
        self.assertEqual(len(set(idxs)), len(idxs))

    def test_assert(self):
        # all zeros
        probs = np.zeros((10, )).astype('float32')
        self._assert_iter_fails(WeightedRandomSampler(probs, 10, True))

        # not enough pos
        probs = self.init_probs(10, 5)
        self._assert_iter_fails(WeightedRandomSampler(probs, 10, False))

        # neg probs
        probs = -1.0 * np.ones((10, )).astype('float32')
        self._assert_iter_fails(WeightedRandomSampler(probs, 10, True))

    def test_raise(self):
        probs = self.init_probs(10, 5)

        # float num_samples
        with self.assertRaises(ValueError):
            WeightedRandomSampler(probs, 2.3, True)

        # neg num_samples
        with self.assertRaises(ValueError):
            WeightedRandomSampler(probs, -1, True)

        # no-bool replacement
        with self.assertRaises(ValueError):
            WeightedRandomSampler(probs, 5, 5)
if __name__ == '__main__':
......
......@@ -27,9 +27,10 @@ __all__ = [
'Sampler',
'SequenceSampler',
'RandomSampler',
'WeightedRandomSampler',
]
from ..fluid.io import DataLoader
from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \
TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler, \
ComposeDataset, ChainDataset
ComposeDataset, ChainDataset, WeightedRandomSampler
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册