提交 1332a049 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2775 Set dataset size in generator when input has a length

Merge pull request !2775 from ZiruiWu/better_get_dataset_size_in_generator
......@@ -3157,6 +3157,9 @@ class GeneratorDataset(MappableDataset):
self.column_names.append(col["name"])
self.column_types.append(DataType(col["type"]))
if source is not None and hasattr(source, "__len__"):
self._dataset_size = len(source)
def get_args(self):
args = super().get_args()
args["source"] = self.source
......@@ -3177,6 +3180,7 @@ class GeneratorDataset(MappableDataset):
return self._dataset_size
if self._dataset_size is None:
return None
return min(rows_from_sampler, self._dataset_size)
# manually set dataset_size as a temporary solution.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册