diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index a0ef750da90a34342b87b775fa2058c9e778a897..7d203b349a130e5d1d2e007bcd64dd47caea886d 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -17,6 +17,7 @@ import six import sys import time import signal +import numbers import logging import itertools import threading @@ -81,12 +82,17 @@ def default_collate_fn(batch): else: slots[i].append(item) - if isinstance(slots[0][0], np.ndarray): - return [np.stack(slot, axis=0) for slot in slots] - elif isinstance(slots[0][0], paddle.Tensor): - return [layers.stack(slot, axis=0) for slot in slots] - else: - raise RuntimeError("Unknown data type {}".format(type(slots[0][0]))) + outputs = [] + for slot in slots: + if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)): + tmp = np.stack(slot, axis=0) + outputs.append(tmp) + elif isinstance(slot[0], paddle.Tensor): + tmp = layers.stack(slot, axis=0) + outputs.append(tmp) + else: + raise RuntimeError("Unknown data type {}".format(type(slot[0]))) + return outputs class _DatasetKind(object):