diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 63cb7d2d481149bd3017086073295d4469680e76..0058eebc5835c244ba1722d1092935c17950737c 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset): self.decode = decode self.num_shards = num_shards self.shard_id = shard_id + self.cur_dataset_size = None def get_args(self): args = super().get_args() @@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset): Return: Number, number of batches. """ + if self.cur_dataset_size is not None: + return self.cur_dataset_size + if self.num_samples is None: num_samples = 0 else: @@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset): rows_from_sampler = self._get_sampler_dataset_size() if rows_from_sampler is None: + self.cur_dataset_size = rows_per_shard return rows_per_shard - return min(rows_from_sampler, rows_per_shard) + self.cur_dataset_size = min(rows_from_sampler, rows_per_shard) + return self.cur_dataset_size def num_classes(self): """ @@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset): self.shuffle_option = shuffle self.distribution = "" self.sampler = sampler + self.cur_dataset_size = None if num_shards is None or shard_id is None: self.partitions = None @@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset): Number, number of batches. """ if self._dataset_size is None: + if self.cur_dataset_size is not None: + return self.cur_dataset_size + if self.load_dataset: dataset_file = [self.dataset_file] else: @@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset): raise RuntimeError( "Dataset size plus number of padded samples is not divisible by number of shards.") num_rows = num_rows // self.partitions[0] + 1 + self.cur_dataset_size = num_rows return num_rows return self._dataset_size