未验证 提交 d4160941 编写于 作者: L LielinJiang 提交者: GitHub

fix dataloader (#28105)

上级 0b294906
......@@ -17,6 +17,7 @@ import six
import sys
import time
import signal
import numbers
import logging
import itertools
import threading
......@@ -81,12 +82,17 @@ def default_collate_fn(batch):
else:
slots[i].append(item)
if isinstance(slots[0][0], np.ndarray):
return [np.stack(slot, axis=0) for slot in slots]
elif isinstance(slots[0][0], paddle.Tensor):
return [layers.stack(slot, axis=0) for slot in slots]
outputs = []
for slot in slots:
if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
tmp = np.stack(slot, axis=0)
outputs.append(tmp)
elif isinstance(slot[0], paddle.Tensor):
tmp = layers.stack(slot, axis=0)
outputs.append(tmp)
else:
raise RuntimeError("Unknown data type {}".format(type(slots[0][0])))
raise RuntimeError("Unknown data type {}".format(type(slot[0])))
return outputs
class _DatasetKind(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册