diff --git a/python/paddle/io/dataloader/batch_sampler.py b/python/paddle/io/dataloader/batch_sampler.py index 190e9240900f8b44e169aa8198aaf55f342aa36b..2f235f408d625d3fa6e4535471d394adf6eeedab 100644 --- a/python/paddle/io/dataloader/batch_sampler.py +++ b/python/paddle/io/dataloader/batch_sampler.py @@ -133,6 +133,11 @@ class BatchSampler(Sampler): ), "batch_size should be a positive integer, but got {}".format( batch_size ) + assert batch_size <= len( + self.sampler + ), "batch_size should not bigger than num of samples, but got {}".format( + batch_size + ) self.batch_size = batch_size assert isinstance( drop_last, bool