未验证 提交 fb1e0c93 编写于 作者: L LielinJiang 提交者: GitHub

Make vision datasets return PIL.Image as default (#28264)

* return pil image as default according backend
上级 26ede6e0
......@@ -296,12 +296,17 @@ class ProgBarLogger(Callback):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
lenet = paddle.vision.LeNet()
model = paddle.Model(lenet,
......@@ -432,12 +437,17 @@ class ModelCheckpoint(Callback):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
lenet = paddle.vision.LeNet()
model = paddle.Model(lenet,
......@@ -484,13 +494,18 @@ class VisualDL(Callback):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)
......
......@@ -837,6 +837,7 @@ class Model(object):
import paddle
import paddle.nn as nn
import paddle.vision.transforms as T
from paddle.static import InputSpec
device = paddle.set_device('cpu') # or 'gpu'
......@@ -858,7 +859,11 @@ class Model(object):
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy())
data = paddle.vision.datasets.MNIST(mode='train')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
data = paddle.vision.datasets.MNIST(mode='train', transform=transform)
model.fit(data, epochs=2, batch_size=32, verbose=1)
"""
......@@ -1067,6 +1072,7 @@ class Model(object):
import paddle
import paddle.nn as nn
import paddle.vision.transforms as T
from paddle.static import InputSpec
class Mnist(nn.Layer):
......@@ -1093,7 +1099,13 @@ class Model(object):
optim = paddle.optimizer.SGD(learning_rate=1e-3,
parameters=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss())
data = paddle.vision.datasets.MNIST(mode='train')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
data = paddle.vision.datasets.MNIST(mode='train', transform=transform)
model.fit(data, epochs=1, batch_size=32, verbose=0)
model.save('checkpoint/test') # save for training
model.save('inference_model', False) # save for inference
......@@ -1353,14 +1365,19 @@ class Model(object):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec
dynamic = True
device = paddle.set_device('cpu') # or 'gpu'
paddle.disable_static(device) if dynamic else None
train_dataset = paddle.vision.datasets.MNIST(mode='train')
val_dataset = paddle.vision.datasets.MNIST(mode='test')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
input = InputSpec([None, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label')
......@@ -1386,16 +1403,21 @@ class Model(object):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec
dynamic = True
device = paddle.set_device('cpu') # or 'gpu'
paddle.disable_static(device) if dynamic else None
train_dataset = paddle.vision.datasets.MNIST(mode='train')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(train_dataset,
places=device, batch_size=64)
val_dataset = paddle.vision.datasets.MNIST(mode='test')
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
val_loader = paddle.io.DataLoader(val_dataset,
places=device, batch_size=64)
......@@ -1522,10 +1544,15 @@ class Model(object):
.. code-block:: python
import paddle
import paddle.vision.transforms as T
from paddle.static import InputSpec
# declarative mode
val_dataset = paddle.vision.datasets.MNIST(mode='test')
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
input = InputSpec([-1, 1, 28, 28], 'float32', 'image')
label = InputSpec([None, 1], 'int64', 'label')
......
......@@ -24,6 +24,7 @@ from paddle import Model
from paddle.static import InputSpec
from paddle.vision.models import LeNet
from paddle.hapi.callbacks import config_callbacks
import paddle.vision.transforms as T
class TestCallbacks(unittest.TestCase):
......@@ -112,8 +113,11 @@ class TestCallbacks(unittest.TestCase):
inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(
mode='test', transform=transform)
net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)
......
......@@ -27,10 +27,11 @@ class TestCifar10Train(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 50000)
data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 9)
......@@ -43,12 +44,30 @@ class TestCifar10Test(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend
cifar = Cifar10(mode='test', backend='cv2')
self.assertTrue(len(cifar) == 10000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
cifar = Cifar10(mode='test', backend=1)
class TestCifar100Train(unittest.TestCase):
def test_main(self):
......@@ -59,10 +78,11 @@ class TestCifar100Train(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 50000)
data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99)
......@@ -75,12 +95,30 @@ class TestCifar100Test(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
data = np.array(data)
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[0] == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[2] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99)
# test cv2 backend
cifar = Cifar100(mode='test', backend='cv2')
self.assertTrue(len(cifar) == 10000)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 10000)
data, label = cifar[idx]
self.assertTrue(len(data.shape) == 3)
self.assertTrue(data.shape[2] == 3)
self.assertTrue(data.shape[1] == 32)
self.assertTrue(data.shape[0] == 32)
self.assertTrue(0 <= int(label) <= 99)
with self.assertRaises(ValueError):
cifar = Cifar100(mode='test', backend=1)
if __name__ == '__main__':
unittest.main()
......@@ -32,6 +32,9 @@ class TestVOC2012Train(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 3)
image, label = voc2012[idx]
image = np.array(image)
label = np.array(label)
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
......@@ -45,6 +48,9 @@ class TestVOC2012Valid(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 1)
image, label = voc2012[idx]
image = np.array(image)
label = np.array(label)
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
......@@ -58,9 +64,27 @@ class TestVOC2012Test(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 1)
image, label = voc2012[idx]
image = np.array(image)
label = np.array(label)
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
# test cv2 backend
voc2012 = VOC2012(mode='test', backend='cv2')
self.assertTrue(len(voc2012) == 2)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1)
image, label = voc2012[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(len(label.shape) == 2)
with self.assertRaises(ValueError):
voc2012 = VOC2012(mode='test', backend=1)
if __name__ == '__main__':
unittest.main()
......@@ -19,6 +19,7 @@ import tempfile
import shutil
import cv2
import paddle.vision.transforms as T
from paddle.vision.datasets import *
from paddle.dataset.common import _check_exists_and_download
......@@ -89,7 +90,8 @@ class TestFolderDatasets(unittest.TestCase):
class TestMNISTTest(unittest.TestCase):
def test_main(self):
mnist = MNIST(mode='test')
transform = T.Transpose()
mnist = MNIST(mode='test', transform=transform)
self.assertTrue(len(mnist) == 10000)
for i in range(len(mnist)):
......@@ -103,7 +105,8 @@ class TestMNISTTest(unittest.TestCase):
class TestMNISTTrain(unittest.TestCase):
def test_main(self):
mnist = MNIST(mode='train')
transform = T.Transpose()
mnist = MNIST(mode='train', transform=transform)
self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)):
......@@ -114,6 +117,22 @@ class TestMNISTTrain(unittest.TestCase):
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
# test cv2 backend
mnist = MNIST(mode='train', transform=transform, backend='cv2')
self.assertTrue(len(mnist) == 60000)
for i in range(len(mnist)):
image, label = mnist[i]
self.assertTrue(image.shape[0] == 1)
self.assertTrue(image.shape[1] == 28)
self.assertTrue(image.shape[2] == 28)
self.assertTrue(label.shape[0] == 1)
self.assertTrue(0 <= int(label) <= 9)
break
with self.assertRaises(ValueError):
mnist = MNIST(mode='train', transform=transform, backend=1)
class TestFlowersTrain(unittest.TestCase):
def test_main(self):
......@@ -124,6 +143,7 @@ class TestFlowersTrain(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 6149)
image, label = flowers[idx]
image = np.array(image)
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
......@@ -138,6 +158,7 @@ class TestFlowersValid(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 1020)
image, label = flowers[idx]
image = np.array(image)
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
......@@ -152,10 +173,27 @@ class TestFlowersTest(unittest.TestCase):
# long time, randomly check 1 sample
idx = np.random.randint(0, 1020)
image, label = flowers[idx]
image = np.array(image)
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
# test cv2 backend
flowers = Flowers(mode='test', backend='cv2')
self.assertTrue(len(flowers) == 1020)
# traversal whole dataset may cost a
# long time, randomly check 1 sample
idx = np.random.randint(0, 1020)
image, label = flowers[idx]
self.assertTrue(len(image.shape) == 3)
self.assertTrue(image.shape[2] == 3)
self.assertTrue(label.shape[0] == 1)
with self.assertRaises(ValueError):
flowers = Flowers(mode='test', backend=1)
if __name__ == '__main__':
unittest.main()
......@@ -17,6 +17,7 @@ from __future__ import print_function
import tarfile
import numpy as np
import six
from PIL import Image
from six.moves import cPickle as pickle
import paddle
......@@ -51,6 +52,10 @@ class Cifar10(Dataset):
transform(callable): transform to perform on image, None for on transform.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Returns:
Dataset: instance of cifar-10 dataset
......@@ -72,13 +77,14 @@ class Cifar10(Dataset):
nn.Softmax())
def forward(self, image, label):
image = paddle.reshape(image, (3, -1))
image = paddle.reshape(image, (1, -1))
return self.fc(image), label
paddle.disable_static()
normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
std=[0.5, 0.5, 0.5],
data_format='HWC')
cifar10 = Cifar10(mode='train', transform=normalize)
for i in range(10):
......@@ -96,11 +102,20 @@ class Cifar10(Dataset):
data_file=None,
mode='train',
transform=None,
download=True):
download=True,
backend=None):
assert mode.lower() in ['train', 'test', 'train', 'test'], \
"mode should be 'train10', 'test10', 'train100' or 'test100', but got {}".format(mode)
self.mode = mode.lower()
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self._init_url_md5_flag()
self.data_file = data_file
......@@ -143,8 +158,16 @@ class Cifar10(Dataset):
def __getitem__(self, idx):
image, label = self.data[idx]
image = np.reshape(image, [3, 32, 32])
image = image.transpose([1, 2, 0])
if self.backend == 'pil':
image = Image.fromarray(image)
if self.transform is not None:
image = self.transform(image)
if self.backend == 'pil':
return image, np.array(label).astype('int64')
return image.astype(self.dtype), np.array(label).astype('int64')
def __len__(self):
......@@ -163,6 +186,10 @@ class Cifar100(Cifar10):
transform(callable): transform to perform on image, None for on transform.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Returns:
Dataset: instance of cifar-100 dataset
......@@ -184,13 +211,14 @@ class Cifar100(Cifar10):
nn.Softmax())
def forward(self, image, label):
image = paddle.reshape(image, (3, -1))
image = paddle.reshape(image, (1, -1))
return self.fc(image), label
paddle.disable_static()
normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
std=[0.5, 0.5, 0.5],
data_format='HWC')
cifar100 = Cifar100(mode='train', transform=normalize)
for i in range(10):
......@@ -208,8 +236,10 @@ class Cifar100(Cifar10):
data_file=None,
mode='train',
transform=None,
download=True):
super(Cifar100, self).__init__(data_file, mode, transform, download)
download=True,
backend=None):
super(Cifar100, self).__init__(data_file, mode, transform, download,
backend)
def _init_url_md5_flag(self):
self.data_url = CIFAR100_URL
......
......@@ -56,6 +56,10 @@ class Flowers(Dataset):
transform(callable): transform to perform on image, None for on transform.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Examples:
......@@ -67,7 +71,7 @@ class Flowers(Dataset):
for i in range(len(flowers)):
sample = flowers[i]
print(sample[0].shape, sample[1])
print(sample[0].size, sample[1])
"""
......@@ -77,9 +81,19 @@ class Flowers(Dataset):
setid_file=None,
mode='train',
transform=None,
download=True):
download=True,
backend=None):
assert mode.lower() in ['train', 'valid', 'test'], \
"mode should be 'train', 'valid' or 'test', but got {}".format(mode)
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self.flag = MODE_FLAG_MAP[mode.lower()]
self.data_file = data_file
......@@ -122,11 +136,18 @@ class Flowers(Dataset):
img_name = "jpg/image_%05d.jpg" % index
img_ele = self.name2mem[img_name]
image = self.data_tar.extractfile(img_ele).read()
image = np.array(Image.open(io.BytesIO(image)))
if self.backend == 'pil':
image = Image.open(io.BytesIO(image))
elif self.backend == 'cv2':
image = np.array(Image.open(io.BytesIO(image)))
if self.transform is not None:
image = self.transform(image)
if self.backend == 'pil':
return image, label.astype('int64')
return image.astype(self.dtype), label.astype('int64')
def __len__(self):
......
......@@ -18,6 +18,7 @@ import os
import gzip
import struct
import numpy as np
from PIL import Image
import paddle
from paddle.io import Dataset
......@@ -48,7 +49,11 @@ class MNIST(Dataset):
mode(str): 'train' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if
:attr:`image_path` :attr:`label_path` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Returns:
Dataset: MNIST Dataset.
......@@ -62,7 +67,7 @@ class MNIST(Dataset):
for i in range(len(mnist)):
sample = mnist[i]
print(sample[0].shape, sample[1])
print(sample[0].size, sample[1])
"""
......@@ -71,9 +76,19 @@ class MNIST(Dataset):
label_path=None,
mode='train',
transform=None,
download=True):
download=True,
backend=None):
assert mode.lower() in ['train', 'test'], \
"mode should be 'train' or 'test', but got {}".format(mode)
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self.mode = mode.lower()
self.image_path = image_path
if self.image_path is None:
......@@ -145,9 +160,17 @@ class MNIST(Dataset):
def __getitem__(self, idx):
image, label = self.images[idx], self.labels[idx]
image = np.reshape(image, [1, 28, 28])
image = np.reshape(image, [28, 28])
if self.backend == 'pil':
image = Image.fromarray(image, mode='L')
if self.transform is not None:
image = self.transform(image)
if self.backend == 'pil':
return image, label.astype('int64')
return image.astype(self.dtype), label.astype('int64')
def __len__(self):
......
......@@ -48,6 +48,10 @@ class VOC2012(Dataset):
mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
download(bool): whether to download dataset automatically if
:attr:`data_file` is not set. Default True
backend(str, optional): Specifies which type of image to be returned:
PIL.Image or numpy.ndarray. Should be one of {'pil', 'cv2'}.
If this option is not set, will get backend from ``paddle.vsion.get_image_backend`` ,
default backend is 'pil'. Default: None.
Examples:
......@@ -55,6 +59,7 @@ class VOC2012(Dataset):
import paddle
from paddle.vision.datasets import VOC2012
from paddle.vision.transforms import Normalize
class SimpleNet(paddle.nn.Layer):
def __init__(self):
......@@ -65,7 +70,10 @@ class VOC2012(Dataset):
paddle.disable_static()
voc2012 = VOC2012(mode='train')
normalize = Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
data_format='HWC')
voc2012 = VOC2012(mode='train', transform=normalize, backend='cv2')
for i in range(10):
image, label= voc2012[i]
......@@ -82,9 +90,19 @@ class VOC2012(Dataset):
data_file=None,
mode='train',
transform=None,
download=True):
download=True,
backend=None):
assert mode.lower() in ['train', 'valid', 'test'], \
"mode should be 'train', 'valid' or 'test', but got {}".format(mode)
if backend is None:
backend = paddle.vision.get_image_backend()
if backend not in ['pil', 'cv2']:
raise ValueError(
"Expected backend are one of ['pil', 'cv2'], but got {}"
.format(backend))
self.backend = backend
self.flag = MODE_FLAG_MAP[mode.lower()]
self.data_file = data_file
......@@ -126,11 +144,18 @@ class VOC2012(Dataset):
label = self.data_tar.extractfile(self.name2mem[label_file]).read()
data = Image.open(io.BytesIO(data))
label = Image.open(io.BytesIO(label))
data = np.array(data)
label = np.array(label)
if self.backend == 'cv2':
data = np.array(data)
label = np.array(label)
if self.transform is not None:
data = self.transform(data)
return data.astype(self.dtype), label.astype(self.dtype)
if self.backend == 'cv2':
return data.astype(self.dtype), label.astype(self.dtype)
return data, label
def __len__(self):
return len(self.data)
......
......@@ -686,6 +686,8 @@ class Transpose(BaseTransform):
if F._is_pil_image(img):
img = np.asarray(img)
if len(img.shape) == 2:
img = img[..., np.newaxis]
return img.transpose(self.order)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册