diff --git a/python/paddle/vision/datasets/cifar.py b/python/paddle/vision/datasets/cifar.py index 631892ee4dcbf0382bc79ae4279d895872c68ef0..c531f3d0e4e3d276d9831b2ac868af9b0761107d 100644 --- a/python/paddle/vision/datasets/cifar.py +++ b/python/paddle/vision/datasets/cifar.py @@ -19,6 +19,7 @@ import numpy as np import six from six.moves import cPickle as pickle +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -113,6 +114,8 @@ class Cifar10(Dataset): # read dataset into memory self._load_data() + self.dtype = paddle.get_default_dtype() + def _init_url_md5_flag(self): self.data_url = CIFAR10_URL self.data_md5 = CIFAR10_MD5 @@ -142,7 +145,7 @@ class Cifar10(Dataset): image = np.reshape(image, [3, 32, 32]) if self.transform is not None: image = self.transform(image) - return image, label + return image.astype(self.dtype), np.array(label).astype('int64') def __len__(self): return len(self.data) diff --git a/python/paddle/vision/datasets/flowers.py b/python/paddle/vision/datasets/flowers.py index 1c0f41123e2313d9db6f5e846d133ecdebc7f1af..2251333fd8d281bd07402fbbf3a05fea47a69cce 100644 --- a/python/paddle/vision/datasets/flowers.py +++ b/python/paddle/vision/datasets/flowers.py @@ -21,6 +21,7 @@ import numpy as np import scipy.io as scio from PIL import Image +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -104,6 +105,8 @@ class Flowers(Dataset): # read dataset into memory self._load_anno() + self.dtype = paddle.get_default_dtype() + def _load_anno(self): self.name2mem = {} self.data_tar = tarfile.open(self.data_file) @@ -124,7 +127,7 @@ class Flowers(Dataset): if self.transform is not None: image = self.transform(image) - return image, label.astype('int64') + return image.astype(self.dtype), label.astype('int64') def __len__(self): return len(self.indexes) diff --git a/python/paddle/vision/datasets/folder.py b/python/paddle/vision/datasets/folder.py index 8a3053abefc1b28ba36150a1ff68a4dd4c3469c9..19d913504bdf7b09de9d888c0caa5cc1c049ac57 100644 --- a/python/paddle/vision/datasets/folder.py +++ b/python/paddle/vision/datasets/folder.py @@ -15,6 +15,7 @@ import os import sys +import paddle from paddle.io import Dataset from paddle.utils import try_import @@ -143,6 +144,8 @@ class DatasetFolder(Dataset): self.samples = samples self.targets = [s[1] for s in samples] + self.dtype = paddle.get_default_dtype() + def _find_classes(self, dir): """ Finds the class folders in a dataset. diff --git a/python/paddle/vision/datasets/mnist.py b/python/paddle/vision/datasets/mnist.py index 597d4046441ddb5b04ad7ceafd83e28e409c674c..16c39e56ef0d65ba89bb611c62e0e957b840a826 100644 --- a/python/paddle/vision/datasets/mnist.py +++ b/python/paddle/vision/datasets/mnist.py @@ -19,6 +19,7 @@ import gzip import struct import numpy as np +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -95,6 +96,8 @@ class MNIST(Dataset): # read dataset into memory self._parse_dataset() + self.dtype = paddle.get_default_dtype() + def _parse_dataset(self, buffer_size=100): self.images = [] self.labels = [] @@ -145,7 +148,7 @@ class MNIST(Dataset): image = np.reshape(image, [1, 28, 28]) if self.transform is not None: image = self.transform(image) - return image, label + return image.astype(self.dtype), label.astype('int64') def __len__(self): return len(self.labels) diff --git a/python/paddle/vision/datasets/voc2012.py b/python/paddle/vision/datasets/voc2012.py index ae14ea3016363c828d17ba34aca8e1a6663ecf76..5fc9d7c38153e5d8c10da5275f3bb11164b12e54 100644 --- a/python/paddle/vision/datasets/voc2012.py +++ b/python/paddle/vision/datasets/voc2012.py @@ -19,6 +19,7 @@ import tarfile import numpy as np from PIL import Image +import paddle from paddle.io import Dataset from paddle.dataset.common import _check_exists_and_download @@ -96,6 +97,8 @@ class VOC2012(Dataset): # read dataset into memory self._load_anno() + self.dtype = paddle.get_default_dtype() + def _load_anno(self): self.name2mem = {} self.data_tar = tarfile.open(self.data_file) @@ -127,7 +130,7 @@ class VOC2012(Dataset): label = np.array(label) if self.transform is not None: data = self.transform(data) - return data, label + return data.astype(self.dtype), label.astype(self.dtype) def __len__(self): return len(self.data)