Unverified · Commit 91bab752, authored by Kaipeng Deng, committed by GitHub

fix dataloader default value and doc (#28728)

* fix dataloader. test=develop
Parent 0ed80e09
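What the change means in practice: in dynamic graph mode a DataLoader now yields a list of Tensors per batch by default, with no need to pass `return_list=True` explicitly. A minimal illustrative sketch (the dataset and shapes below are made up, not taken from this commit):

```python
import numpy as np
from paddle.io import Dataset, DataLoader

# Illustrative dataset, not part of this commit.
class RandomDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label

# return_list now defaults to True, so in dynamic graph mode each batch
# arrives as a list of Tensors: [image_batch, label_batch].
loader = DataLoader(RandomDataset(), batch_size=32, shuffle=True)
for image_batch, label_batch in loader:
    pass  # training step would go here
```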
@@ -196,7 +196,7 @@ class DataLoader(object):
             the key of the dict is the name of each fed variables. If
             :attr:`return_list=True`, the return value on each device would
             be a list(Tensor). :attr:`return_list` can only be True
-            in dynamic graph mode. Default False.
+            in dynamic graph mode. Default True.
         batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
             to generate batch indices to draw samples from :attr:`dataset`
             and combine a batch. Default None.
@@ -308,7 +308,7 @@ class DataLoader(object):
                  dataset,
                  feed_list=None,
                  places=None,
-                 return_list=False,
+                 return_list=True,
                  batch_sampler=None,
                  batch_size=1,
                  shuffle=False,
@@ -403,10 +403,10 @@ class DataLoader(object):
         if self.dataset_kind == _DatasetKind.ITER:
             raise ValueError("length of IterableDataset not supported")
         else:
-            if self.batch_size is None:
-                return len(self.dataset)
-            else:
+            if self.auto_collate_batch:
                 return len(self.batch_sampler)
+            else:
+                return len(self.dataset)

     def __iter__(self):
         if self.num_workers == 0:
...
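The `__len__` change switches the branch condition from `batch_size is None` to the loader's `auto_collate_batch` flag: when the loader collates batches itself, the length is the number of batches the sampler yields; when the dataset already produces whole batches (`batch_size=None`), the length is the dataset's own length. A condensed standalone sketch of that logic, with the surrounding class stripped away (the `auto_collate_batch` initialization here is an assumption for illustration):

```python
class _LenSketch:
    """Condensed illustration of the new DataLoader.__len__ branch."""

    def __init__(self, dataset, batch_sampler=None):
        self.dataset = dataset
        self.batch_sampler = batch_sampler
        # Assumption for this sketch: auto-collation is on whenever a
        # batch sampler is in play, off when the dataset is pre-batched.
        self.auto_collate_batch = batch_sampler is not None

    def __len__(self):
        if self.auto_collate_batch:
            # the loader builds batches, so count what the sampler yields
            return len(self.batch_sampler)
        else:
            # the dataset is already batched; its length is the batch count
            return len(self.dataset)
```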
@@ -112,6 +112,7 @@ class TestStaticDataLoader(unittest.TestCase):
             places=places,
             num_workers=num_workers,
             batch_size=BATCH_SIZE,
+            return_list=False,
             drop_last=True)
         # assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
@@ -199,6 +200,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
             places=places,
             num_workers=num_workers,
             batch_size=None,
+            return_list=False,
             drop_last=True)
         exe = fluid.Executor(place=places[0])
...
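Since the default flipped to True, the static-graph tests now pass `return_list=False` explicitly so each batch still comes back in the feed-dict form that `Executor.run(feed=...)` expects. Roughly, with placeholder variable and program names (this is a sketch, not code from the tests):

```python
# Sketch only: image_var, label_var, prog, loss_var and places are placeholders.
loader = DataLoader(
    dataset,
    feed_list=[image_var, label_var],  # static-graph feed variables
    places=places,
    batch_size=BATCH_SIZE,
    return_list=False,                 # keep feed-dict output for exe.run
    drop_last=True)

exe = fluid.Executor(place=places[0])
for data in loader:                    # one feed structure per iteration
    exe.run(prog, feed=data, fetch_list=[loss_var])
```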
@@ -113,6 +113,7 @@ class TestStaticDataLoader(unittest.TestCase):
             places=places,
             num_workers=num_workers,
             batch_size=BATCH_SIZE,
+            return_list=False,
             drop_last=True)
         assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
@@ -226,7 +227,8 @@ class RandomBatchedDataset(Dataset):
         labels = []
         for _ in range(BATCH_SIZE):
             image = np.random.random([IMAGE_SIZE]).astype('float32')
-            label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64')
+            label = np.random.randint(0, self.class_num - 1,
+                                      (1, )).astype('int64')
             images.append(image)
             labels.append(label)
         return np.stack(images, axis=0), np.stack(labels, axis=0)
@@ -248,6 +250,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
             places=places,
             num_workers=num_workers,
             batch_size=None,
+            return_list=False,
             drop_last=True)
         assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
...
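The `RandomBatchedDataset` path exercises the pre-batched case: each `__getitem__` already returns a full batch, so the loader is built with `batch_size=None` and the new `__len__` falls through to `len(self.dataset)`. A condensed sketch of that pattern (constants are illustrative, not the test's real values):

```python
import numpy as np
from paddle.io import Dataset

BATCH_SIZE = 8
IMAGE_SIZE = 784
SAMPLE_NUM = 80
CLASS_NUM = 10

class BatchedRandomDataset(Dataset):
    """Each item is already a full batch, so a DataLoader built with
    batch_size=None does no extra collation."""

    def __len__(self):
        return SAMPLE_NUM // BATCH_SIZE  # number of pre-built batches

    def __getitem__(self, idx):
        images, labels = [], []
        for _ in range(BATCH_SIZE):
            images.append(np.random.random([IMAGE_SIZE]).astype('float32'))
            labels.append(
                np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64'))
        return np.stack(images, axis=0), np.stack(labels, axis=0)
```

With `batch_size=None`, `len(dataloader)` equals `len(dataset)`, which is why the test's `assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)` still holds after the `__len__` rewrite.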