diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py
index 1c8af4951296f82c9cbbf64089335438a61e4ed8..356e1a79a9f4ed40c0320cb8322f46a9b88ffc2a 100644
--- a/imperative/python/megengine/data/sampler.py
+++ b/imperative/python/megengine/data/sampler.py
@@ -297,10 +297,10 @@ class ReplacementSampler(MapSampler):
 
     def sample(self) -> List:
         n = len(self.dataset)
-        if self.weights is None:
-            return self.rng.randint(n, size=self.num_samples).tolist()
-        else:
-            return self.rng.multinomial(n, self.weights, self.num_samples).tolist()
+        indices = self.rng.choice(
+            n, size=self.num_samples, replace=True, p=self.weights
+        )
+        return indices.tolist()
 
 
 class Infinite(MapSampler):
diff --git a/imperative/python/test/unit/data/test_sampler.py b/imperative/python/test/unit/data/test_sampler.py
index f4b33a90a422f5c508583a60523ea28abe6c396b..1e1119056d42ed82e75aa7c42e306a4d8fa935ab 100644
--- a/imperative/python/test/unit/data/test_sampler.py
+++ b/imperative/python/test/unit/data/test_sampler.py
@@ -58,8 +58,34 @@ def test_random_sampler_seed():
 
 def test_ReplacementSampler():
     num_samples = 30
-    indices = list(range(20))
-    weights = list(range(20))
+    num_data = 20
+    indices = list(range(num_data))
+    sampler = ReplacementSampler(
+        ArrayDataset(indices), num_samples=num_samples, weights=None
+    )
+    assert len(list(each[0] for each in sampler)) == num_samples
+
+    num_data = 8
+    weights = list(range(num_data))
+    indices = list(range(num_data))
+    sampler = ReplacementSampler(
+        ArrayDataset(indices), num_samples=num_samples, weights=weights
+    )
+    assert len(list(each[0] for each in sampler)) == num_samples
+    iter = 1000
+    hist = [0 for _ in range(num_data)]
+    for _ in range(iter):
+        for index in sampler:
+            index = index[0]
+            hist[index] += 1
+    actual_weights = np.array(hist) / sum(hist)
+    desired_weights = np.array(weights) / sum(weights)
+    np.testing.assert_allclose(actual_weights, desired_weights, rtol=8e-2)
+
+    num_data = 50000
+    num_samples = 50000 * 30
+    weights = list(range(num_data))
+    indices = list(range(num_data))
     sampler = ReplacementSampler(
         ArrayDataset(indices), num_samples=num_samples, weights=weights
     )