未验证 提交 308467c3 编写于 作者: K Kaipeng Deng 提交者: GitHub

Add warning for dataloader incompatable upgrade (#32967)

* add warning log for DataLoader output format imcompatible upgrade. test=develop
上级 fe94db6c
...@@ -12,6 +12,11 @@ ...@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from ..log_helper import get_logger
from collections.abc import Sequence
class _DatasetFetcher(object): class _DatasetFetcher(object):
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
...@@ -19,11 +24,39 @@ class _DatasetFetcher(object): ...@@ -19,11 +24,39 @@ class _DatasetFetcher(object):
self.auto_collate_batch = auto_collate_batch self.auto_collate_batch = auto_collate_batch
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.drop_last = drop_last self.drop_last = drop_last
self._is_warning_logged = False
def fetch(self, batch_indices): def fetch(self, batch_indices):
raise NotImplementedError("'fetch' not implement for class {}".format( raise NotImplementedError("'fetch' not implement for class {}".format(
self.__class__.__name__)) self.__class__.__name__))
def _log_warning(self):
warn_str = "Detect dataset only contains single fileds, return format " \
"changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \
"a list surround output data(e.g. return [data]), and in " \
"Paddle >= 2.1, DataLoader return the single filed directly " \
"(e.g. return data). For example, in following code: \n\n"
warn_str += \
"import numpy as np\n" \
"from paddle.io import DataLoader, Dataset\n\n" \
"class RandomDataset(Dataset):\n" \
" def __getitem__(self, idx):\n" \
" data = np.random.random((2, 3)).astype('float32')\n\n" \
" return data\n\n" \
" def __len__(self):\n" \
" return 10\n\n" \
"dataset = RandomDataset()\n" \
"loader = DataLoader(dataset, batch_size=1)\n" \
"data = next(loader())\n\n"
warn_str += "In Paddle <= 2.0, data is in format '[Tensor(shape=(1, 2, 3), " \
"dtype=float32)]', and in Paddle >= 2.1, data is in format" \
" 'Tensor(shape=(1, 2, 3), dtype=float32)'\n"
logger = get_logger(
"DataLoader", logging.INFO, fmt='%(levelname)s: %(message)s')
logger.warning(warn_str)
class _IterableDatasetFetcher(_DatasetFetcher): class _IterableDatasetFetcher(_DatasetFetcher):
def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
...@@ -40,9 +73,14 @@ class _IterableDatasetFetcher(_DatasetFetcher): ...@@ -40,9 +73,14 @@ class _IterableDatasetFetcher(_DatasetFetcher):
data.append(next(self.dataset_iter)) data.append(next(self.dataset_iter))
except StopIteration: except StopIteration:
break break
if len(data) == 0 or (self.drop_last and if len(data) == 0 or (self.drop_last and
len(data) < len(batch_indices)): len(data) < len(batch_indices)):
raise StopIteration raise StopIteration
if not isinstance(data[0],
Sequence) and not self._is_warning_logged:
self._log_warning()
self._is_warning_logged = True
else: else:
data = next(self.dataset_iter) data = next(self.dataset_iter)
...@@ -59,6 +97,11 @@ class _MapDatasetFetcher(_DatasetFetcher): ...@@ -59,6 +97,11 @@ class _MapDatasetFetcher(_DatasetFetcher):
def fetch(self, batch_indices): def fetch(self, batch_indices):
if self.auto_collate_batch: if self.auto_collate_batch:
data = [self.dataset[idx] for idx in batch_indices] data = [self.dataset[idx] for idx in batch_indices]
if not isinstance(data[0],
Sequence) and not self._is_warning_logged:
self._log_warning()
self._is_warning_logged = True
else: else:
data = self.dataset[batch_indices] data = self.dataset[batch_indices]
......
...@@ -330,6 +330,59 @@ class TestComplextDataset(unittest.TestCase): ...@@ -330,6 +330,59 @@ class TestComplextDataset(unittest.TestCase):
self.run_main(num_workers) self.run_main(num_workers)
class SingleFieldDataset(Dataset):
def __init__(self, sample_num):
self.sample_num = sample_num
def __len__(self):
return self.sample_num
def __getitem__(self, idx):
return np.random.random((2, 3)).astype('float32')
class TestSingleFieldDataset(unittest.TestCase):
def init_dataset(self):
self.sample_num = 16
self.dataset = SingleFieldDataset(self.sample_num)
def run_main(self, num_workers):
paddle.static.default_startup_program().random_seed = 1
paddle.static.default_main_program().random_seed = 1
place = paddle.CPUPlace()
with fluid.dygraph.guard(place):
self.init_dataset()
dataloader = DataLoader(
self.dataset,
places=place,
num_workers=num_workers,
batch_size=2,
drop_last=True)
for i, data in enumerate(dataloader()):
assert isinstance(data, paddle.Tensor)
assert data.shape == [2, 2, 3]
def test_main(self):
for num_workers in [0, 2]:
self.run_main(num_workers)
class SingleFieldIterableDataset(IterableDataset):
def __init__(self, sample_num):
self.sample_num = sample_num
def __iter__(self):
for _ in range(self.sample_num):
yield np.random.random((2, 3)).astype('float32')
class TestSingleFieldIterableDataset(TestSingleFieldDataset):
def init_dataset(self):
self.sample_num = 16
self.dataset = SingleFieldIterableDataset(self.sample_num)
class TestDataLoaderGenerateStates(unittest.TestCase): class TestDataLoaderGenerateStates(unittest.TestCase):
def setUp(self): def setUp(self):
self.inputs = [(0, 1), (0, 2), (1, 3)] self.inputs = [(0, 1), (0, 2), (1, 3)]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册