From 8327accc58673d04f5bf727378c6acf10e14f6df Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Tue, 20 Oct 2020 10:52:51 +0800 Subject: [PATCH] Fix dataloader when stack input data with different type (#27950) * fix dataloader --- .../paddle/fluid/dataloader/dataloader_iter.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index a0ef750da90..7d203b349a1 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -17,6 +17,7 @@ import six import sys import time import signal +import numbers import logging import itertools import threading @@ -81,12 +82,17 @@ def default_collate_fn(batch): else: slots[i].append(item) - if isinstance(slots[0][0], np.ndarray): - return [np.stack(slot, axis=0) for slot in slots] - elif isinstance(slots[0][0], paddle.Tensor): - return [layers.stack(slot, axis=0) for slot in slots] - else: - raise RuntimeError("Unknown data type {}".format(type(slots[0][0]))) + outputs = [] + for slot in slots: + if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)): + tmp = np.stack(slot, axis=0) + outputs.append(tmp) + elif isinstance(slot[0], paddle.Tensor): + tmp = layers.stack(slot, axis=0) + outputs.append(tmp) + else: + raise RuntimeError("Unknown data type {}".format(type(slot[0]))) + return outputs class _DatasetKind(object): -- GitLab