From 45a9977dd2ea1ba932b2b85d437b561817e77aef Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 1 Dec 2022 11:28:07 +0800 Subject: [PATCH] fix(mge/data): fix weighted random sampler GitOrigin-RevId: d09cbbfffd2434dea3f56b3db2f642b5c236141a --- imperative/python/megengine/data/sampler.py | 8 ++--- .../python/test/unit/data/test_sampler.py | 30 +++++++++++++++++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py index 1c8af4951..356e1a79a 100644 --- a/imperative/python/megengine/data/sampler.py +++ b/imperative/python/megengine/data/sampler.py @@ -297,10 +297,10 @@ class ReplacementSampler(MapSampler): def sample(self) -> List: n = len(self.dataset) - if self.weights is None: - return self.rng.randint(n, size=self.num_samples).tolist() - else: - return self.rng.multinomial(n, self.weights, self.num_samples).tolist() + indices = self.rng.choice( + n, size=self.num_samples, replace=True, p=self.weights + ) + return indices.tolist() class Infinite(MapSampler): diff --git a/imperative/python/test/unit/data/test_sampler.py b/imperative/python/test/unit/data/test_sampler.py index f4b33a90a..1e1119056 100644 --- a/imperative/python/test/unit/data/test_sampler.py +++ b/imperative/python/test/unit/data/test_sampler.py @@ -58,8 +58,34 @@ def test_random_sampler_seed(): def test_ReplacementSampler(): num_samples = 30 - indices = list(range(20)) - weights = list(range(20)) + num_data = 20 + indices = list(range(num_data)) + sampler = ReplacementSampler( + ArrayDataset(indices), num_samples=num_samples, weights=None + ) + assert len(list(each[0] for each in sampler)) == num_samples + + num_data = 8 + weights = list(range(num_data)) + indices = list(range(num_data)) + sampler = ReplacementSampler( + ArrayDataset(indices), num_samples=num_samples, weights=weights + ) + assert len(list(each[0] for each in sampler)) == num_samples + iter = 1000 + hist = [0 for _ in range(num_data)] + for _ in range(iter): + for index in sampler: + index = index[0] + hist[index] += 1 + actual_weights = np.array(hist) / sum(hist) + desired_weights = np.array(weights) / sum(weights) + np.testing.assert_allclose(actual_weights, desired_weights, rtol=8e-2) + + num_data = 50000 + num_samples = 50000 * 30 + weights = list(range(num_data)) + indices = list(range(num_data)) sampler = ReplacementSampler( ArrayDataset(indices), num_samples=num_samples, weights=weights ) -- GitLab