未验证 提交 fb1e0c93 编写于 作者: L LielinJiang 提交者: GitHub

Make vision datasets return PIL.Image as default (#28264)

* return pil image as default according backend
上级 26ede6e0
...@@ -296,12 +296,17 @@ class ProgBarLogger(Callback): ...@@ -296,12 +296,17 @@ class ProgBarLogger(Callback):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')] inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')] labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
lenet = paddle.vision.LeNet() lenet = paddle.vision.LeNet()
model = paddle.Model(lenet, model = paddle.Model(lenet,
...@@ -432,12 +437,17 @@ class ModelCheckpoint(Callback): ...@@ -432,12 +437,17 @@ class ModelCheckpoint(Callback):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')] inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')] labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
lenet = paddle.vision.LeNet() lenet = paddle.vision.LeNet()
model = paddle.Model(lenet, model = paddle.Model(lenet,
...@@ -484,13 +494,18 @@ class VisualDL(Callback): ...@@ -484,13 +494,18 @@ class VisualDL(Callback):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')] inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')] labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([
eval_dataset = paddle.vision.datasets.MNIST(mode='test') T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
net = paddle.vision.LeNet() net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels) model = paddle.Model(net, inputs, labels)
......
...@@ -837,6 +837,7 @@ class Model(object): ...@@ -837,6 +837,7 @@ class Model(object):
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
device = paddle.set_device('cpu') # or 'gpu' device = paddle.set_device('cpu') # or 'gpu'
...@@ -858,7 +859,11 @@ class Model(object): ...@@ -858,7 +859,11 @@ class Model(object):
paddle.nn.CrossEntropyLoss(), paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy()) paddle.metric.Accuracy())
data = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
data = paddle.vision.datasets.MNIST(mode='train', transform=transform)
model.fit(data, epochs=2, batch_size=32, verbose=1) model.fit(data, epochs=2, batch_size=32, verbose=1)
""" """
...@@ -1067,6 +1072,7 @@ class Model(object): ...@@ -1067,6 +1072,7 @@ class Model(object):
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
class Mnist(nn.Layer): class Mnist(nn.Layer):
...@@ -1093,7 +1099,13 @@ class Model(object): ...@@ -1093,7 +1099,13 @@ class Model(object):
optim = paddle.optimizer.SGD(learning_rate=1e-3, optim = paddle.optimizer.SGD(learning_rate=1e-3,
parameters=model.parameters()) parameters=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss()) model.prepare(optim, paddle.nn.CrossEntropyLoss())
data = paddle.vision.datasets.MNIST(mode='train')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
data = paddle.vision.datasets.MNIST(mode='train', transform=transform)
model.fit(data, epochs=1, batch_size=32, verbose=0) model.fit(data, epochs=1, batch_size=32, verbose=0)
model.save('checkpoint/test') # save for training model.save('checkpoint/test') # save for training
model.save('inference_model', False) # save for inference model.save('inference_model', False) # save for inference
...@@ -1353,14 +1365,19 @@ class Model(object): ...@@ -1353,14 +1365,19 @@ class Model(object):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
dynamic = True dynamic = True
device = paddle.set_device('cpu') # or 'gpu' device = paddle.set_device('cpu') # or 'gpu'
paddle.disable_static(device) if dynamic else None paddle.disable_static(device) if dynamic else None
train_dataset = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([
val_dataset = paddle.vision.datasets.MNIST(mode='test') T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
input = InputSpec([None, 1, 28, 28], 'float32', 'image') input = InputSpec([None, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label') label = InputSpec([None, 1], 'int64', 'label')
...@@ -1386,16 +1403,21 @@ class Model(object): ...@@ -1386,16 +1403,21 @@ class Model(object):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
dynamic = True dynamic = True
device = paddle.set_device('cpu') # or 'gpu' device = paddle.set_device('cpu') # or 'gpu'
paddle.disable_static(device) if dynamic else None paddle.disable_static(device) if dynamic else None
train_dataset = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(train_dataset, train_loader = paddle.io.DataLoader(train_dataset,
places=device, batch_size=64) places=device, batch_size=64)
val_dataset = paddle.vision.datasets.MNIST(mode='test') val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
val_loader = paddle.io.DataLoader(val_dataset, val_loader = paddle.io.DataLoader(val_dataset,
places=device, batch_size=64) places=device, batch_size=64)
...@@ -1522,10 +1544,15 @@ class Model(object): ...@@ -1522,10 +1544,15 @@ class Model(object):
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec from paddle.static import InputSpec
# declarative mode # declarative mode
val_dataset = paddle.vision.datasets.MNIST(mode='test') transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
input = InputSpec([-1, 1, 28, 28], 'float32', 'image') input = InputSpec([-1, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label') label = InputSpec([None, 1], 'int64', 'label')
......
...@@ -24,6 +24,7 @@ from paddle import Model ...@@ -24,6 +24,7 @@ from paddle import Model
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.vision.models import LeNet from paddle.vision.models import LeNet
from paddle.hapi.callbacks import config_callbacks from paddle.hapi.callbacks import config_callbacks
import paddle.vision.transforms as T
class TestCallbacks(unittest.TestCase): class TestCallbacks(unittest.TestCase):
...@@ -112,8 +113,11 @@ class TestCallbacks(unittest.TestCase): ...@@ -112,8 +113,11 @@ class TestCallbacks(unittest.TestCase):
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')] inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')] labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
eval_dataset = paddle.vision.datasets.MNIST(mode='test') train_dataset = paddle.vision.datasets.MNIST(
mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(
mode='test', transform=transform)
net = paddle.vision.LeNet() net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels) model = paddle.Model(net, inputs, labels)
......
...@@ -27,10 +27,11 @@ class TestCifar10Train(unittest.TestCase): ...@@ -27,10 +27,11 @@ class TestCifar10Train(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 50000) idx = np.random.randint(0, 50000)
data, label = cifar[idx] data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
...@@ -43,12 +44,30 @@ class TestCifar10Test(unittest.TestCase): ...@@ -43,12 +44,30 @@ class TestCifar10Test(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 10000) idx = np.random.randint(0, 10000)
data, label = cifar[idx] data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend
cifar = Cifar10(mode='test', backend='cv2')
self.assertTrue(len(cifar) == 10000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
cifar = Cifar10(mode='test', backend=1)
class TestCifar100Train(unittest.TestCase): class TestCifar100Train(unittest.TestCase):
def test_main(self): def test_main(self):
...@@ -59,10 +78,11 @@ class TestCifar100Train(unittest.TestCase): ...@@ -59,10 +78,11 @@ class TestCifar100Train(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 50000) idx = np.random.randint(0, 50000)
data, label = cifar[idx] data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
...@@ -75,12 +95,30 @@ class TestCifar100Test(unittest.TestCase): ...@@ -75,12 +95,30 @@ class TestCifar100Test(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 10000) idx = np.random.randint(0, 10000)
data, label = cifar[idx] data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3) self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3) self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32) self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32) self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99) self.assertTrue(0 <= int(label) <= 99)
# test cv2 backend
cifar = Cifar100(mode='test', backend='cv2')
self.assertTrue(len(cifar) == 10000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
cifar = Cifar100(mode='test', backend=1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -32,6 +32,9 @@ class TestVOC2012Train(unittest.TestCase): ...@@ -32,6 +32,9 @@ class TestVOC2012Train(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 3) idx = np.random.randint(0, 3)
image, label = voc2012[idx] image, label = voc2012[idx]
image = np.array(image)
label = np.array(label)
self.assertTrue(len(image.shape) == 3) self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2) self.assertTrue(len(label.shape) == 2)
...@@ -45,6 +48,9 @@ class TestVOC2012Valid(unittest.TestCase): ...@@ -45,6 +48,9 @@ class TestVOC2012Valid(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 1) idx = np.random.randint(0, 1)
image, label = voc2012[idx] image, label = voc2012[idx]
image = np.array(image)
label = np.array(label)
self.assertTrue(len(image.shape) == 3) self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2) self.assertTrue(len(label.shape) == 2)
...@@ -58,9 +64,27 @@ class TestVOC2012Test(unittest.TestCase): ...@@ -58,9 +64,27 @@ class TestVOC2012Test(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 1) idx = np.random.randint(0, 1)
image, label = voc2012[idx] image, label = voc2012[idx]
image = np.array(image)
label = np.array(label)
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
# test cv2 backend
voc2012 = VOC2012(mode='test', backend='cv2')
self.assertTrue(len(voc2012) == 2)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1)
image, label = voc2012[idx]
self.assertTrue(len(image.shape) == 3) self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2) self.assertTrue(len(label.shape) == 2)
with self.assertRaises(ValueError):
voc2012 = VOC2012(mode='test', backend=1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,6 +19,7 @@ import tempfile ...@@ -19,6 +19,7 @@ import tempfile
import shutil import shutil
import cv2 import cv2
import paddle.vision.transforms as T
from paddle.vision.datasets import * from paddle.vision.datasets import *
from paddle.dataset.common import _check_exists_and_download from paddle.dataset.common import _check_exists_and_download
...@@ -89,7 +90,8 @@ class TestFolderDatasets(unittest.TestCase): ...@@ -89,7 +90,8 @@ class TestFolderDatasets(unittest.TestCase):
class TestMNISTTest(unittest.TestCase): class TestMNISTTest(unittest.TestCase):
def test_main(self): def test_main(self):
mnist = MNIST(mode='test') transform = T.Transpose()
mnist = MNIST(mode='test', transform=transform)
self.assertTrue(len(mnist) == 10000) self.assertTrue(len(mnist) == 10000)
for i in range(len(mnist)): for i in range(len(mnist)):
...@@ -103,7 +105,8 @@ class TestMNISTTest(unittest.TestCase): ...@@ -103,7 +105,8 @@ class TestMNISTTest(unittest.TestCase):
class TestMNISTTrain(unittest.TestCase): class TestMNISTTrain(unittest.TestCase):
def test_main(self): def test_main(self):
mnist = MNIST(mode='train') transform = T.Transpose()
mnist = MNIST(mode='train', transform=transform)
self.assertTrue(len(mnist) == 60000) self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)): for i in range(len(mnist)):
...@@ -114,6 +117,22 @@ class TestMNISTTrain(unittest.TestCase): ...@@ -114,6 +117,22 @@ class TestMNISTTrain(unittest.TestCase):
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9) self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend
mnist = MNIST(mode='train', transform=transform, backend='cv2')
self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)):
image, label = mnist[i]
self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
break
with self.assertRaises(ValueError):
mnist = MNIST(mode='train', transform=transform, backend=1)
class TestFlowersTrain(unittest.TestCase): class TestFlowersTrain(unittest.TestCase):
def test_main(self): def test_main(self):
...@@ -124,6 +143,7 @@ class TestFlowersTrain(unittest.TestCase): ...@@ -124,6 +143,7 @@ class TestFlowersTrain(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 6149) idx = np.random.randint(0, 6149)
image, label = flowers[idx] image, label = flowers[idx]
image = np.array(image)
self.assertTrue(len(image.shape) == 3) self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3) self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
...@@ -138,6 +158,7 @@ class TestFlowersValid(unittest.TestCase): ...@@ -138,6 +158,7 @@ class TestFlowersValid(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 1020) idx = np.random.randint(0, 1020)
image, label = flowers[idx] image, label = flowers[idx]
image = np.array(image)
self.assertTrue(len(image.shape) == 3) self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3) self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
...@@ -152,10 +173,27 @@ class TestFlowersTest(unittest.TestCase): ...@@ -152,10 +173,27 @@ class TestFlowersTest(unittest.TestCase):
# long time, randomly check 1 sample # long time, randomly check 1 sample
idx = np.random.randint(0, 1020) idx = np.random.randint(0, 1020)
image, label = flowers[idx] image, label = flowers[idx]
image = np.array(image)
self.assertTrue(len(image.shape) == 3) self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3) self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1) self.assertTrue(label.shape[0] == 1)
# test cv2 backend
flowers = Flowers(mode='test', backend='cv2')
self.assertTrue(len(flowers) == 1020)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1020)
image, label = flowers[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
with self.assertRaises(ValueError):
flowers = Flowers(mode='test', backend=1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import tarfile import tarfile
import numpy as np import numpy as np
import six import six
from PIL import Image
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
import paddle import paddle
...@@ -51,6 +52,10 @@ class Cifar10(Dataset): ...@@ -51,6 +52,10 @@ class Cifar10(Dataset):
transform(callable): transform to perform on image, None for on transform. transform(callable): transform to perform on image, None for on transform.
download(bool): whether to download dataset automatically if download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True :attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Returns: Returns:
Dataset: instance of cifar-10 dataset Dataset: instance of cifar-10 dataset
...@@ -72,13 +77,14 @@ class Cifar10(Dataset): ...@@ -72,13 +77,14 @@ class Cifar10(Dataset):
nn.Softmax()) nn.Softmax())
def forward(self, image, label): def forward(self, image, label):
image = paddle.reshape(image, (3, -1)) image = paddle.reshape(image, (1, -1))
return self.fc(image), label return self.fc(image), label
paddle.disable_static() paddle.disable_static()
normalize = Normalize(mean=[0.5, 0.5, 0.5], normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]) std=[0.5, 0.5, 0.5],
data_format='HWC')
cifar10 = Cifar10(mode='train', transform=normalize) cifar10 = Cifar10(mode='train', transform=normalize)
for i in range(10): for i in range(10):
...@@ -96,11 +102,20 @@ class Cifar10(Dataset): ...@@ -96,11 +102,20 @@ class Cifar10(Dataset):
data_file=None, data_file=None,
mode='train', mode='train',
transform=None, transform=None,
download=True): download=True,
backend=None):
assert mode.lower() in ['train', 'test', 'train', 'test'], \ assert mode.lower() in ['train', 'test', 'train', 'test'], \
"mode should be 'train10', 'test10', 'train100' or 'test100', but got {}".format(mode) "mode should be 'train10', 'test10', 'train100' or 'test100', but got {}".format(mode)
self.mode = mode.lower() self.mode = mode.lower()
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self._init_url_md5_flag() self._init_url_md5_flag()
self.data_file = data_file self.data_file = data_file
...@@ -143,8 +158,16 @@ class Cifar10(Dataset): ...@@ -143,8 +158,16 @@ class Cifar10(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
image, label = self.data[idx] image, label = self.data[idx]
image = np.reshape(image, [3, 32, 32]) image = np.reshape(image, [3, 32, 32])
image = image.transpose([1, 2, 0])
if self.backend == 'pil':
image = Image.fromarray(image)
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
if self.backend == 'pil':
return image, np.array(label).astype('int64')
return image.astype(self.dtype), np.array(label).astype('int64') return image.astype(self.dtype), np.array(label).astype('int64')
def __len__(self): def __len__(self):
...@@ -163,6 +186,10 @@ class Cifar100(Cifar10): ...@@ -163,6 +186,10 @@ class Cifar100(Cifar10):
transform(callable): transform to perform on image, None for on transform. transform(callable): transform to perform on image, None for on transform.
download(bool): whether to download dataset automatically if download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True :attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Returns: Returns:
Dataset: instance of cifar-100 dataset Dataset: instance of cifar-100 dataset
...@@ -184,13 +211,14 @@ class Cifar100(Cifar10): ...@@ -184,13 +211,14 @@ class Cifar100(Cifar10):
nn.Softmax()) nn.Softmax())
def forward(self, image, label): def forward(self, image, label):
image = paddle.reshape(image, (3, -1)) image = paddle.reshape(image, (1, -1))
return self.fc(image), label return self.fc(image), label
paddle.disable_static() paddle.disable_static()
normalize = Normalize(mean=[0.5, 0.5, 0.5], normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]) std=[0.5, 0.5, 0.5],
data_format='HWC')
cifar100 = Cifar100(mode='train', transform=normalize) cifar100 = Cifar100(mode='train', transform=normalize)
for i in range(10): for i in range(10):
...@@ -208,8 +236,10 @@ class Cifar100(Cifar10): ...@@ -208,8 +236,10 @@ class Cifar100(Cifar10):
data_file=None, data_file=None,
mode='train', mode='train',
transform=None, transform=None,
download=True): download=True,
super(Cifar100, self).__init__(data_file, mode, transform, download) backend=None):
super(Cifar100, self).__init__(data_file, mode, transform, download,
backend)
def _init_url_md5_flag(self): def _init_url_md5_flag(self):
self.data_url = CIFAR100_URL self.data_url = CIFAR100_URL
......
...@@ -56,6 +56,10 @@ class Flowers(Dataset): ...@@ -56,6 +56,10 @@ class Flowers(Dataset):
transform(callable): transform to perform on image, None for on transform. transform(callable): transform to perform on image, None for on transform.
download(bool): whether to download dataset automatically if download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True :attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Examples: Examples:
...@@ -67,7 +71,7 @@ class Flowers(Dataset): ...@@ -67,7 +71,7 @@ class Flowers(Dataset):
for i in range(len(flowers)): for i in range(len(flowers)):
sample = flowers[i] sample = flowers[i]
print(sample[0].shape, sample[1]) print(sample[0].size, sample[1])
""" """
...@@ -77,9 +81,19 @@ class Flowers(Dataset): ...@@ -77,9 +81,19 @@ class Flowers(Dataset):
setid_file=None, setid_file=None,
mode='train', mode='train',
transform=None, transform=None,
download=True): download=True,
backend=None):
assert mode.lower() in ['train', 'valid', 'test'], \ assert mode.lower() in ['train', 'valid', 'test'], \
"mode should be 'train', 'valid' or 'test', but got {}".format(mode) "mode should be 'train', 'valid' or 'test', but got {}".format(mode)
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self.flag = MODE_FLAG_MAP[mode.lower()] self.flag = MODE_FLAG_MAP[mode.lower()]
self.data_file = data_file self.data_file = data_file
...@@ -122,11 +136,18 @@ class Flowers(Dataset): ...@@ -122,11 +136,18 @@ class Flowers(Dataset):
img_name = "jpg/image_%05d.jpg" % index img_name = "jpg/image_%05d.jpg" % index
img_ele = self.name2mem[img_name] img_ele = self.name2mem[img_name]
image = self.data_tar.extractfile(img_ele).read() image = self.data_tar.extractfile(img_ele).read()
if self.backend == 'pil':
image = Image.open(io.BytesIO(image))
elif self.backend == 'cv2':
image = np.array(Image.open(io.BytesIO(image))) image = np.array(Image.open(io.BytesIO(image)))
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
if self.backend == 'pil':
return image, label.astype('int64')
return image.astype(self.dtype), label.astype('int64') return image.astype(self.dtype), label.astype('int64')
def __len__(self): def __len__(self):
......
...@@ -18,6 +18,7 @@ import os ...@@ -18,6 +18,7 @@ import os
import gzip import gzip
import struct import struct
import numpy as np import numpy as np
from PIL import Image
import paddle import paddle
from paddle.io import Dataset from paddle.io import Dataset
...@@ -48,6 +49,10 @@ class MNIST(Dataset): ...@@ -48,6 +49,10 @@ class MNIST(Dataset):
mode(str): 'train' or 'test' mode. Default 'train'. mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if download(bool): whether to download dataset automatically if
:attr:`image_path` :attr:`label_path` is not set. Default True :attr:`image_path` :attr:`label_path` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Returns: Returns:
Dataset: MNIST Dataset. Dataset: MNIST Dataset.
...@@ -62,7 +67,7 @@ class MNIST(Dataset): ...@@ -62,7 +67,7 @@ class MNIST(Dataset):
for i in range(len(mnist)): for i in range(len(mnist)):
sample = mnist[i] sample = mnist[i]
print(sample[0].shape, sample[1]) print(sample[0].size, sample[1])
""" """
...@@ -71,9 +76,19 @@ class MNIST(Dataset): ...@@ -71,9 +76,19 @@ class MNIST(Dataset):
label_path=None, label_path=None,
mode='train', mode='train',
transform=None, transform=None,
download=True): download=True,
backend=None):
assert mode.lower() in ['train', 'test'], \ assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode) "mode should be 'train' or 'test', but got {}".format(mode)
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self.mode = mode.lower() self.mode = mode.lower()
self.image_path = image_path self.image_path = image_path
if self.image_path is None: if self.image_path is None:
...@@ -145,9 +160,17 @@ class MNIST(Dataset): ...@@ -145,9 +160,17 @@ class MNIST(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx] image, label = self.images[idx], self.labels[idx]
image = np.reshape(image, [1, 28, 28]) image = np.reshape(image, [28, 28])
if self.backend == 'pil':
image = Image.fromarray(image, mode='L')
if self.transform is not None: if self.transform is not None:
image = self.transform(image) image = self.transform(image)
if self.backend == 'pil':
return image, label.astype('int64')
return image.astype(self.dtype), label.astype('int64') return image.astype(self.dtype), label.astype('int64')
def __len__(self): def __len__(self):
......
...@@ -48,6 +48,10 @@ class VOC2012(Dataset): ...@@ -48,6 +48,10 @@ class VOC2012(Dataset):
mode(str): 'train', 'valid' or 'test' mode. Default 'train'. mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True :attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Examples: Examples:
...@@ -55,6 +59,7 @@ class VOC2012(Dataset): ...@@ -55,6 +59,7 @@ class VOC2012(Dataset):
import paddle import paddle
from paddle.vision.datasets import VOC2012 from paddle.vision.datasets import VOC2012
from paddle.vision.transforms import Normalize
class SimpleNet(paddle.nn.Layer): class SimpleNet(paddle.nn.Layer):
def __init__(self): def __init__(self):
...@@ -65,7 +70,10 @@ class VOC2012(Dataset): ...@@ -65,7 +70,10 @@ class VOC2012(Dataset):
paddle.disable_static() paddle.disable_static()
voc2012 = VOC2012(mode='train') normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
data_format='HWC')
voc2012 = VOC2012(mode='train', transform=normalize, backend='cv2')
for i in range(10): for i in range(10):
image, label= voc2012[i] image, label= voc2012[i]
...@@ -82,9 +90,19 @@ class VOC2012(Dataset): ...@@ -82,9 +90,19 @@ class VOC2012(Dataset):
data_file=None, data_file=None,
mode='train', mode='train',
transform=None, transform=None,
download=True): download=True,
backend=None):
assert mode.lower() in ['train', 'valid', 'test'], \ assert mode.lower() in ['train', 'valid', 'test'], \
"mode should be 'train', 'valid' or 'test', but got {}".format(mode) "mode should be 'train', 'valid' or 'test', but got {}".format(mode)
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self.flag = MODE_FLAG_MAP[mode.lower()] self.flag = MODE_FLAG_MAP[mode.lower()]
self.data_file = data_file self.data_file = data_file
...@@ -126,12 +144,19 @@ class VOC2012(Dataset): ...@@ -126,12 +144,19 @@ class VOC2012(Dataset):
label = self.data_tar.extractfile(self.name2mem[label_file]).read() label = self.data_tar.extractfile(self.name2mem[label_file]).read()
data = Image.open(io.BytesIO(data)) data = Image.open(io.BytesIO(data))
label = Image.open(io.BytesIO(label)) label = Image.open(io.BytesIO(label))
if self.backend == 'cv2':
data = np.array(data) data = np.array(data)
label = np.array(label) label = np.array(label)
if self.transform is not None: if self.transform is not None:
data = self.transform(data) data = self.transform(data)
if self.backend == 'cv2':
return data.astype(self.dtype), label.astype(self.dtype) return data.astype(self.dtype), label.astype(self.dtype)
return data, label
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
......
...@@ -686,6 +686,8 @@ class Transpose(BaseTransform): ...@@ -686,6 +686,8 @@ class Transpose(BaseTransform):
if F._is_pil_image(img): if F._is_pil_image(img):
img = np.asarray(img) img = np.asarray(img)
if len(img.shape) == 2:
img = img[..., np.newaxis]
return img.transpose(self.order) return img.transpose(self.order)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册