提交 1127ace7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2228 cache get_dataset_size value

Merge pull request !2228 from yanghaitao/yht_get_dataset_size
......@@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset):
self.decode = decode
self.num_shards = num_shards
self.shard_id = shard_id
self.cur_dataset_size = None
def get_args(self):
args = super().get_args()
......@@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset):
Return:
Number, number of batches.
"""
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.num_samples is None:
num_samples = 0
else:
......@@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset):
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
self.cur_dataset_size = rows_per_shard
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
self.cur_dataset_size = min(rows_from_sampler, rows_per_shard)
return self.cur_dataset_size
def num_classes(self):
"""
......@@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset):
self.shuffle_option = shuffle
self.distribution = ""
self.sampler = sampler
self.cur_dataset_size = None
if num_shards is None or shard_id is None:
self.partitions = None
......@@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset):
Number, number of batches.
"""
if self._dataset_size is None:
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
......@@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset):
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
self.cur_dataset_size = num_rows
return num_rows
return self._dataset_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册