diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index cee38dfa35c9b18c887a0344a3f23fd77b64264f..56ef705f608b0fe2f70a1754c24f3713c999e130 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -506,6 +506,9 @@ class SubsetRandomSampler(BuiltinSampler): def get_num_samples(self): num_samples = super().get_num_samples() + if num_samples is None: + return len(self.indices) + return min(len(self.indices), num_samples)