diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py index 05382b04dc457b4d131f898008e4e08eae1e0a0d..8ccec81810a0a60d75b2546bd7cad4ede226855b 100644 --- a/python/paddle/fluid/dataloader/fetcher.py +++ b/python/paddle/fluid/dataloader/fetcher.py @@ -14,8 +14,9 @@ import logging from ..log_helper import get_logger +from collections.abc import Sequence, Mapping -from collections.abc import Sequence +_WARNING_TO_LOG = True class _DatasetFetcher(object): @@ -24,13 +25,17 @@ class _DatasetFetcher(object): self.auto_collate_batch = auto_collate_batch self.collate_fn = collate_fn self.drop_last = drop_last - self._is_warning_logged = False def fetch(self, batch_indices): raise NotImplementedError("'fetch' not implement for class {}".format( self.__class__.__name__)) def _log_warning(self): + # only log warning on GPU 0 when distributed launch + from ...distributed import get_world_size, get_rank + if get_world_size() >= 2 and get_rank() != 0: + return + warn_str = "Detect dataset only contains single fileds, return format " \ "changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \ "a list surround output data(e.g. return [data]), and in " \ @@ -77,10 +82,12 @@ class _IterableDatasetFetcher(_DatasetFetcher): if len(data) == 0 or (self.drop_last and len(data) < len(batch_indices)): raise StopIteration - if not isinstance(data[0], - Sequence) and not self._is_warning_logged: + + global _WARNING_TO_LOG + if not isinstance(data[0], (Sequence, Mapping)) \ + and _WARNING_TO_LOG: self._log_warning() - self._is_warning_logged = True + _WARNING_TO_LOG = False else: data = next(self.dataset_iter) @@ -98,10 +105,11 @@ class _MapDatasetFetcher(_DatasetFetcher): if self.auto_collate_batch: data = [self.dataset[idx] for idx in batch_indices] - if not isinstance(data[0], - Sequence) and not self._is_warning_logged: + global _WARNING_TO_LOG + if not isinstance(data[0], (Sequence, Mapping)) \ + and _WARNING_TO_LOG: self._log_warning() - self._is_warning_logged = True + _WARNING_TO_LOG = False else: data = self.dataset[batch_indices]