From 3b42c360b6e7d157b2157bd6997e8fc8185b4d8d Mon Sep 17 00:00:00 2001 From: Zirui Wu Date: Tue, 30 Jun 2020 13:23:15 -0400 Subject: [PATCH] set dataset_size in generator when source has len --- mindspore/dataset/engine/datasets.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index d8fda008e..f2c1642df 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3157,6 +3157,9 @@ class GeneratorDataset(MappableDataset): self.column_names.append(col["name"]) self.column_types.append(DataType(col["type"])) + if source is not None and hasattr(source, "__len__"): + self._dataset_size = len(source) + def get_args(self): args = super().get_args() args["source"] = self.source @@ -3177,6 +3180,7 @@ class GeneratorDataset(MappableDataset): return self._dataset_size if self._dataset_size is None: return None + return min(rows_from_sampler, self._dataset_size) # manually set dataset_size as a temporary solution. -- GitLab