未验证 提交 ed7903cd 编写于 作者: K Kaipeng Deng 提交者: GitHub

make DataLoader warning less noisy. test=develop (#34001)

上级 8417ad60
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
import logging import logging
from ..log_helper import get_logger from ..log_helper import get_logger
from collections.abc import Sequence, Mapping
from collections.abc import Sequence _WARNING_TO_LOG = True
class _DatasetFetcher(object): class _DatasetFetcher(object):
...@@ -24,13 +25,17 @@ class _DatasetFetcher(object): ...@@ -24,13 +25,17 @@ class _DatasetFetcher(object):
self.auto_collate_batch = auto_collate_batch self.auto_collate_batch = auto_collate_batch
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.drop_last = drop_last self.drop_last = drop_last
self._is_warning_logged = False
def fetch(self, batch_indices): def fetch(self, batch_indices):
raise NotImplementedError("'fetch' not implement for class {}".format( raise NotImplementedError("'fetch' not implement for class {}".format(
self.__class__.__name__)) self.__class__.__name__))
def _log_warning(self): def _log_warning(self):
# only log warning on GPU 0 when distributed launch
from ...distributed import get_world_size, get_rank
if get_world_size() >= 2 and get_rank() != 0:
return
warn_str = "Detect dataset only contains single fileds, return format " \ warn_str = "Detect dataset only contains single fileds, return format " \
"changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \ "changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \
"a list surround output data(e.g. return [data]), and in " \ "a list surround output data(e.g. return [data]), and in " \
...@@ -77,10 +82,12 @@ class _IterableDatasetFetcher(_DatasetFetcher): ...@@ -77,10 +82,12 @@ class _IterableDatasetFetcher(_DatasetFetcher):
if len(data) == 0 or (self.drop_last and if len(data) == 0 or (self.drop_last and
len(data) < len(batch_indices)): len(data) < len(batch_indices)):
raise StopIteration raise StopIteration
if not isinstance(data[0],
Sequence) and not self._is_warning_logged: global _WARNING_TO_LOG
if not isinstance(data[0], (Sequence, Mapping)) \
and _WARNING_TO_LOG:
self._log_warning() self._log_warning()
self._is_warning_logged = True _WARNING_TO_LOG = False
else: else:
data = next(self.dataset_iter) data = next(self.dataset_iter)
...@@ -98,10 +105,11 @@ class _MapDatasetFetcher(_DatasetFetcher): ...@@ -98,10 +105,11 @@ class _MapDatasetFetcher(_DatasetFetcher):
if self.auto_collate_batch: if self.auto_collate_batch:
data = [self.dataset[idx] for idx in batch_indices] data = [self.dataset[idx] for idx in batch_indices]
if not isinstance(data[0], global _WARNING_TO_LOG
Sequence) and not self._is_warning_logged: if not isinstance(data[0], (Sequence, Mapping)) \
and _WARNING_TO_LOG:
self._log_warning() self._log_warning()
self._is_warning_logged = True _WARNING_TO_LOG = False
else: else:
data = self.dataset[batch_indices] data = self.dataset[batch_indices]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册