未验证 提交 4bd7aa25 编写于 作者: K Kaipeng Deng 提交者: GitHub

use paddle.get_default_dtype in vision datasets. test=develop (#27426)

上级 fc61efd7
......@@ -19,6 +19,7 @@ import numpy as np
import six
from six.moves import cPickle as pickle
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
......@@ -113,6 +114,8 @@ class Cifar10(Dataset):
# read dataset into memory
self._load_data()
self.dtype = paddle.get_default_dtype()
def _init_url_md5_flag(self):
self.data_url = CIFAR10_URL
self.data_md5 = CIFAR10_MD5
......@@ -142,7 +145,7 @@ class Cifar10(Dataset):
image = np.reshape(image, [3, 32, 32])
if self.transform is not None:
image = self.transform(image)
return image, label
return image.astype(self.dtype), np.array(label).astype('int64')
def __len__(self):
return len(self.data)
......
......@@ -21,6 +21,7 @@ import numpy as np
import scipy.io as scio
from PIL import Image
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
......@@ -104,6 +105,8 @@ class Flowers(Dataset):
# read dataset into memory
self._load_anno()
self.dtype = paddle.get_default_dtype()
def _load_anno(self):
self.name2mem = {}
self.data_tar = tarfile.open(self.data_file)
......@@ -124,7 +127,7 @@ class Flowers(Dataset):
if self.transform is not None:
image = self.transform(image)
return image, label.astype('int64')
return image.astype(self.dtype), label.astype('int64')
def __len__(self):
return len(self.indexes)
......@@ -15,6 +15,7 @@
import os
import sys
import paddle
from paddle.io import Dataset
from paddle.utils import try_import
......@@ -143,6 +144,8 @@ class DatasetFolder(Dataset):
self.samples = samples
self.targets = [s[1] for s in samples]
self.dtype = paddle.get_default_dtype()
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
......
......@@ -19,6 +19,7 @@ import gzip
import struct
import numpy as np
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
......@@ -95,6 +96,8 @@ class MNIST(Dataset):
# read dataset into memory
self._parse_dataset()
self.dtype = paddle.get_default_dtype()
def _parse_dataset(self, buffer_size=100):
self.images = []
self.labels = []
......@@ -145,7 +148,7 @@ class MNIST(Dataset):
image = np.reshape(image, [1, 28, 28])
if self.transform is not None:
image = self.transform(image)
return image, label
return image.astype(self.dtype), label.astype('int64')
def __len__(self):
return len(self.labels)
......@@ -19,6 +19,7 @@ import tarfile
import numpy as np
from PIL import Image
import paddle
from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download
......@@ -96,6 +97,8 @@ class VOC2012(Dataset):
# read dataset into memory
self._load_anno()
self.dtype = paddle.get_default_dtype()
def _load_anno(self):
self.name2mem = {}
self.data_tar = tarfile.open(self.data_file)
......@@ -127,7 +130,7 @@ class VOC2012(Dataset):
label = np.array(label)
if self.transform is not None:
data = self.transform(data)
return data, label
return data.astype(self.dtype), label.astype(self.dtype)
def __len__(self):
return len(self.data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册