未验证 提交 4bd7aa25 编写于 作者: K Kaipeng Deng 提交者: GitHub

use paddle.get_default_dtype in vision datasets. test=develop (#27426)

上级 fc61efd7
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import six import six
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
...@@ -113,6 +114,8 @@ class Cifar10(Dataset): ...@@ -113,6 +114,8 @@ class Cifar10(Dataset):
# read dataset into memory # read dataset into memory
self._load_data() self._load_data()
self.dtype = paddle.get_default_dtype()
def _init_url_md5_flag(self): def _init_url_md5_flag(self):
self.data_url = CIFAR10_URL self.data_url = CIFAR10_URL
self.data_md5 = CIFAR10_MD5 self.data_md5 = CIFAR10_MD5
...@@ -142,7 +145,7 @@ class Cifar10(Dataset): ...@@ -142,7 +145,7 @@ class Cifar10(Dataset):
image = np.reshape(image, [3, 32, 32]) image = np.reshape(image, [3, 32, 32])
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
return image, label return image.astype(self.dtype), np.array(label).astype('int64')
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
......
...@@ -21,6 +21,7 @@ import numpy as np ...@@ -21,6 +21,7 @@ import numpy as np
import scipy.io as scio import scipy.io as scio
from PIL import Image from PIL import Image
import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
...@@ -104,6 +105,8 @@ class Flowers(Dataset): ...@@ -104,6 +105,8 @@ class Flowers(Dataset):
# read dataset into memory # read dataset into memory
self._load_anno() self._load_anno()
self.dtype = paddle.get_default_dtype()
def _load_anno(self): def _load_anno(self):
self.name2mem = {} self.name2mem = {}
self.data_tar = tarfile.open(self.data_file) self.data_tar = tarfile.open(self.data_file)
...@@ -124,7 +127,7 @@ class Flowers(Dataset): ...@@ -124,7 +127,7 @@ class Flowers(Dataset):
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
return image, label.astype('int64') return image.astype(self.dtype), label.astype('int64')
def __len__(self): def __len__(self):
return len(self.indexes) return len(self.indexes)
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import sys import sys
import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.utils import try_import from paddle.utils import try_import
...@@ -143,6 +144,8 @@ class DatasetFolder(Dataset): ...@@ -143,6 +144,8 @@ class DatasetFolder(Dataset):
self.samples = samples self.samples = samples
self.targets = [s[1] for s in samples] self.targets = [s[1] for s in samples]
self.dtype = paddle.get_default_dtype()
def _find_classes(self, dir): def _find_classes(self, dir):
""" """
Finds the class folders in a dataset. Finds the class folders in a dataset.
......
...@@ -19,6 +19,7 @@ import gzip ...@@ -19,6 +19,7 @@ import gzip
import struct import struct
import numpy as np import numpy as np
import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
...@@ -95,6 +96,8 @@ class MNIST(Dataset): ...@@ -95,6 +96,8 @@ class MNIST(Dataset):
# read dataset into memory # read dataset into memory
self._parse_dataset() self._parse_dataset()
self.dtype = paddle.get_default_dtype()
def _parse_dataset(self, buffer_size=100): def _parse_dataset(self, buffer_size=100):
self.images = [] self.images = []
self.labels = [] self.labels = []
...@@ -145,7 +148,7 @@ class MNIST(Dataset): ...@@ -145,7 +148,7 @@ class MNIST(Dataset):
image = np.reshape(image, [1, 28, 28]) image = np.reshape(image, [1, 28, 28])
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
return image, label return image.astype(self.dtype), label.astype('int64')
def __len__(self): def __len__(self):
return len(self.labels) return len(self.labels)
...@@ -19,6 +19,7 @@ import tarfile ...@@ -19,6 +19,7 @@ import tarfile
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import paddle
from paddle.io import Dataset from paddle.io import Dataset
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
...@@ -96,6 +97,8 @@ class VOC2012(Dataset): ...@@ -96,6 +97,8 @@ class VOC2012(Dataset):
# read dataset into memory # read dataset into memory
self._load_anno() self._load_anno()
self.dtype = paddle.get_default_dtype()
def _load_anno(self): def _load_anno(self):
self.name2mem = {} self.name2mem = {}
self.data_tar = tarfile.open(self.data_file) self.data_tar = tarfile.open(self.data_file)
...@@ -127,7 +130,7 @@ class VOC2012(Dataset): ...@@ -127,7 +130,7 @@ class VOC2012(Dataset):
label = np.array(label) label = np.array(label)
if self.transform is not None: if self.transform is not None:
data = self.transform(data) data = self.transform(data)
return data, label return data.astype(self.dtype), label.astype(self.dtype)
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册