未验证 提交 d4160941 编写于 作者: L LielinJiang 提交者: GitHub

fix dataloader (#28105)

上级 0b294906
...@@ -17,6 +17,7 @@ import six ...@@ -17,6 +17,7 @@ import six
import sys import sys
import time import time
import signal import signal
import numbers
import logging import logging
import itertools import itertools
import threading import threading
...@@ -81,12 +82,17 @@ def default_collate_fn(batch): ...@@ -81,12 +82,17 @@ def default_collate_fn(batch):
else: else:
slots[i].append(item) slots[i].append(item)
if isinstance(slots[0][0], np.ndarray): outputs = []
return [np.stack(slot, axis=0) for slot in slots] for slot in slots:
elif isinstance(slots[0][0], paddle.Tensor): if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
return [layers.stack(slot, axis=0) for slot in slots] tmp = np.stack(slot, axis=0)
else: outputs.append(tmp)
raise RuntimeError("Unknown data type {}".format(type(slots[0][0]))) elif isinstance(slot[0], paddle.Tensor):
tmp = layers.stack(slot, axis=0)
outputs.append(tmp)
else:
raise RuntimeError("Unknown data type {}".format(type(slot[0])))
return outputs
class _DatasetKind(object): class _DatasetKind(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册