diff --git a/imperative/python/megengine/data/sampler.py b/imperative/python/megengine/data/sampler.py index 356e1a79a9f4ed40c0320cb8322f46a9b88ffc2a..5ee708a0d9b6993310dd030462f6d413e7de1377 100644 --- a/imperative/python/megengine/data/sampler.py +++ b/imperative/python/megengine/data/sampler.py @@ -58,10 +58,9 @@ class MapSampler(Sampler): "drop_last should be a boolean value, but got " "drop_last={}".format(drop_last) ) + if num_samples is not None and ( - not isinstance(num_samples, int) - or isinstance(num_samples, bool) - or num_samples <= 0 + not isinstance(num_samples, int) or num_samples <= 0 ): raise ValueError( "num_samples should be a positive integer " @@ -83,6 +82,14 @@ class MapSampler(Sampler): num_samples = len(self.dataset) self.num_samples = int(math.ceil(num_samples / self.world_size)) + if self.num_samples < self.batch_size: + raise ValueError( + "num_samples should be greater than batch_size " + ", but got num_samples={} and batch_size={}".format( + self.num_samples, self.batch_size + ) + ) + # Make sure seeds are the same at each rank if seed is None and self.world_size > 1: seed = 0 @@ -297,13 +304,12 @@ class ReplacementSampler(MapSampler): def sample(self) -> List: n = len(self.dataset) - indices = self.rng.choice( - n, size=self.num_samples, replace=True, p=self.weights - ) + sample_size = self.num_samples * self.world_size + indices = self.rng.choice(n, size=sample_size, replace=True, p=self.weights) return indices.tolist() -class Infinite(MapSampler): +class Infinite(Sampler): r"""Infinite Sampler warper for basic sampler.""" def sample(self):