未验证 提交 91bab752 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix dataloader default value and doc (#28728)

* fix dataloader. test=develop
上级 0ed80e09
......@@ -196,7 +196,7 @@ class DataLoader(object):
the key of the dict is the name of each fed variables. If
:attr:`return_list=True`, the return value on each device would
be a list(Tensor). :attr:`return_list` can only be True
in dynamic graph mode. Default False.
in dynamic graph mode. Default True.
batch_sampler(BatchSampler): an instance of `paddle.io.BatchSampler`
to generate batch indices to draw samples from :attr:`dataset`
and combine a batch. Default None.
......@@ -308,7 +308,7 @@ class DataLoader(object):
dataset,
feed_list=None,
places=None,
return_list=False,
return_list=True,
batch_sampler=None,
batch_size=1,
shuffle=False,
......@@ -403,10 +403,10 @@ class DataLoader(object):
if self.dataset_kind == _DatasetKind.ITER:
raise ValueError("length of IterableDataset not supported")
else:
if self.batch_size is None:
return len(self.dataset)
else:
if self.auto_collate_batch:
return len(self.batch_sampler)
else:
return len(self.dataset)
def __iter__(self):
if self.num_workers == 0:
......
......@@ -112,6 +112,7 @@ class TestStaticDataLoader(unittest.TestCase):
places=places,
num_workers=num_workers,
batch_size=BATCH_SIZE,
return_list=False,
drop_last=True)
# assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
......@@ -199,6 +200,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
places=places,
num_workers=num_workers,
batch_size=None,
return_list=False,
drop_last=True)
exe = fluid.Executor(place=places[0])
......
......@@ -113,6 +113,7 @@ class TestStaticDataLoader(unittest.TestCase):
places=places,
num_workers=num_workers,
batch_size=BATCH_SIZE,
return_list=False,
drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
......@@ -226,7 +227,8 @@ class RandomBatchedDataset(Dataset):
labels = []
for _ in range(BATCH_SIZE):
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64')
label = np.random.randint(0, self.class_num - 1,
(1, )).astype('int64')
images.append(image)
labels.append(label)
return np.stack(images, axis=0), np.stack(labels, axis=0)
......@@ -248,6 +250,7 @@ class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader):
places=places,
num_workers=num_workers,
batch_size=None,
return_list=False,
drop_last=True)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册