diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index d8fda008e9047a636c2d98214d440661491b2a7d..f2c1642df58529f0a986924d97801cdada42fb03 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3157,6 +3157,9 @@ class GeneratorDataset(MappableDataset): self.column_names.append(col["name"]) self.column_types.append(DataType(col["type"])) + if source is not None and hasattr(source, "__len__"): + self._dataset_size = len(source) + def get_args(self): args = super().get_args() args["source"] = self.source @@ -3177,6 +3180,7 @@ class GeneratorDataset(MappableDataset): return self._dataset_size if self._dataset_size is None: return None + return min(rows_from_sampler, self._dataset_size) # manually set dataset_size as a temporary solution.