From d4160941f32a727aebad84dce4fe4e6d38c670ff Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Tue, 20 Oct 2020 14:34:55 +0800 Subject: [PATCH] fix dataloader (#28105) --- .../paddle/fluid/dataloader/dataloader_iter.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index 1ef0d494e07..7a4e3cb6ec8 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -17,6 +17,7 @@ import six import sys import time import signal +import numbers import logging import itertools import threading @@ -81,12 +82,17 @@ def default_collate_fn(batch): else: slots[i].append(item) - if isinstance(slots[0][0], np.ndarray): - return [np.stack(slot, axis=0) for slot in slots] - elif isinstance(slots[0][0], paddle.Tensor): - return [layers.stack(slot, axis=0) for slot in slots] - else: - raise RuntimeError("Unknown data type {}".format(type(slots[0][0]))) + outputs = [] + for slot in slots: + if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)): + tmp = np.stack(slot, axis=0) + outputs.append(tmp) + elif isinstance(slot[0], paddle.Tensor): + tmp = layers.stack(slot, axis=0) + outputs.append(tmp) + else: + raise RuntimeError("Unknown data type {}".format(type(slot[0]))) + return outputs class _DatasetKind(object): -- GitLab