未验证 提交 c4ddc3ab 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix dataloader collate return list mix tensor and numpy array (#30904)

* fix dataloader collate return list mix tensor and numpy array. test=develop
上级 5b267474
......@@ -440,11 +440,16 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
if use_shared_memory:
# FIXME(dkp): _convert_to_tensor_list only support np.array
# list now, should support paddle.Tensor list
if isinstance(batch[0][0], paddle.Tensor):
np_batch = []
for sample in batch:
np_batch.append([s.numpy() for s in sample])
batch = np_batch
new_batch = []
for sample in batch:
new_sample = []
for s in sample:
if isinstance(s, paddle.Tensor):
new_sample.append(s.numpy())
else:
new_sample.append(s)
new_batch.append(new_sample)
batch = new_batch
tensor_list = core._convert_to_tensor_list(batch)
out_queue.put((idx, tensor_list))
......
......@@ -235,5 +235,43 @@ class TestChainDataset(unittest.TestCase):
self.run_main(num_workers=0, places=p)
class NumpyMixTensorDataset(Dataset):
def __init__(self, sample_num):
self.sample_num = sample_num
def __len__(self):
return self.sample_num
def __getitem__(self, idx):
np.random.seed(idx)
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return paddle.to_tensor(image, place=paddle.CPUPlace()), label
class TestNumpyMixTensorDataset(TestTensorDataset):
def run_main(self, num_workers, places):
paddle.static.default_startup_program().random_seed = 1
paddle.static.default_main_program().random_seed = 1
place = paddle.CPUPlace()
with fluid.dygraph.guard(place):
dataset = NumpyMixTensorDataset(16)
assert len(dataset) == 16
dataloader = DataLoader(
dataset,
places=place,
num_workers=num_workers,
batch_size=1,
drop_last=True)
for i, (input, label) in enumerate(dataloader()):
assert len(input) == 1
assert len(label) == 1
assert input.shape == [1, IMAGE_SIZE]
assert label.shape == [1, 1]
assert isinstance(input, paddle.Tensor)
assert isinstance(label, paddle.Tensor)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册