提交 03804075 编写于 作者: Y yanghaitao1

store get dataset size

上级 1e90e7be
...@@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset): ...@@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset):
self.decode = decode self.decode = decode
self.num_shards = num_shards self.num_shards = num_shards
self.shard_id = shard_id self.shard_id = shard_id
self.cur_dataset_size = None
def get_args(self): def get_args(self):
args = super().get_args() args = super().get_args()
...@@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset): ...@@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.num_samples is None: if self.num_samples is None:
num_samples = 0 num_samples = 0
else: else:
...@@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset): ...@@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset):
rows_from_sampler = self._get_sampler_dataset_size() rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None: if rows_from_sampler is None:
self.cur_dataset_size = rows_per_shard
return rows_per_shard return rows_per_shard
return min(rows_from_sampler, rows_per_shard) self.cur_dataset_size = min(rows_from_sampler, rows_per_shard)
return self.cur_dataset_size
def num_classes(self): def num_classes(self):
""" """
...@@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset): ...@@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset):
self.shuffle_option = shuffle self.shuffle_option = shuffle
self.distribution = "" self.distribution = ""
self.sampler = sampler self.sampler = sampler
self.cur_dataset_size = None
if num_shards is None or shard_id is None: if num_shards is None or shard_id is None:
self.partitions = None self.partitions = None
...@@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset): ...@@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset):
Number, number of batches. Number, number of batches.
""" """
if self._dataset_size is None: if self._dataset_size is None:
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.load_dataset: if self.load_dataset:
dataset_file = [self.dataset_file] dataset_file = [self.dataset_file]
else: else:
...@@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset): ...@@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset):
raise RuntimeError( raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.") "Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1 num_rows = num_rows // self.partitions[0] + 1
self.cur_dataset_size = num_rows
return num_rows return num_rows
return self._dataset_size return self._dataset_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册