提交 09fd47a2 编写于 作者: M ms_yan

repair get_sampler_size problem

上级 92052713
......@@ -1227,7 +1227,10 @@ class MappableDataset(SourceDataset):
def _get_sampler_dataset_size(self):
if self.sampler is not None:
return self.sampler.get_dataset_size()
if hasattr(self.sampler, 'get_dataset_size'):
return self.sampler.get_dataset_size()
if hasattr(self.sampler, '__len__'):
return len(self.sampler)
return None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册