未验证 提交 2ad63212 编写于 作者: L Luyang 提交者: GitHub

Dev flow.utils.data part3 (#5644)

* add more datasets

* add more transform funcs

* export interface

* export datasets interface

* auto format by CI

* fix docs

* skip test

* support DistributedSampler

* refine

* add more transform function

* fix err import

* fix comment

* refine

* add more transform test

* refactor dataloader test

* refine

* add ddp test

* refine

* refine

* add ddp test case

* skip test

* add ddp test case

* add test case

* refine

* rm ddp test

* remove ddp test

* auto format by CI

* format

* update api docs

* add utils.rst

* auto format by CI

* fix ddp grad size
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* remove print
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* refine as comments

* refine

* auto format by CI

* auto format by CI

* refine

* add ddp test

* auto format by CI

* rm test case

* fix reshape
Co-authored-by: Noneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: Ndaquexian <daquexian566@gmail.com>
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 3afdd833
...@@ -20,6 +20,8 @@ OneFlow API Reference ...@@ -20,6 +20,8 @@ OneFlow API Reference
linalg linalg
image image
optim optim
utils
Indices and tables Indices and tables
......
oneflow.utils
===================================
Utils
----------------------------------
.. currentmodule:: oneflow.utils
.. automodule:: oneflow.utils.data
:members: DataLoader,
Dataset,
IterableDataset,
TensorDataset,
ConcatDataset,
Subset,
random_split,
Sampler,
SequentialSampler,
RandomSampler,
SubsetRandomSampler,
BatchSampler
.. currentmodule:: oneflow.utils
.. automodule:: oneflow.utils.data.distributed
:members: DistributedSampler
.. currentmodule:: oneflow.utils
.. automodule:: oneflow.utils.vision.datasets
:members: MNIST,
FashionMNIST,
CIFAR10,
CIFAR100,
ImageNet,
CocoCaptions,
CocoDetection,
VOCDetection,
VOCSegmentation,
DatasetFolder,
ImageFolder
.. currentmodule:: oneflow.utils
.. automodule:: oneflow.utils.vision.transforms
:members: Compose,
ToTensor,
PILToTensor,
ConvertImageDtype,
ToPILImage,
Normalize,
Resize,
Scale,
CenterCrop,
Pad,
Lambda,
RandomTransforms,
RandomApply,
RandomOrder,
RandomChoice,
RandomCrop,
RandomHorizontalFlip,
RandomVerticalFlip,
RandomResizedCrop,
RandomSizedCrop,
FiveCrop,
TenCrop,
InterpolationMode
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import oneflow as flow
import oneflow.utils.vision.transforms as transforms
def load_data_cifar10(
    batch_size,
    data_dir="./data-test/cifar10",
    download=True,
    transform=None,
    source_url=None,
    num_workers=0,
):
    """Create the CIFAR10 train and test data loaders.

    Args:
        batch_size: Batch size handed to both ``DataLoader`` instances.
        data_dir: Root directory passed to the ``CIFAR10`` dataset.
        download: Forwarded to the dataset's ``download`` argument.
        transform: Transform applied by both datasets (may be ``None``).
        source_url: Forwarded to the dataset's ``source_url`` argument.
        num_workers: Forwarded to both ``DataLoader`` instances.

    Returns:
        tuple: ``(train_iter, test_iter)``; the train loader shuffles, the
        test loader does not.
    """
    # Everything except the ``train`` flag is shared by the two splits.
    dataset_kwargs = dict(
        root=data_dir, download=download, transform=transform, source_url=source_url,
    )
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

    train_set = flow.utils.vision.datasets.CIFAR10(train=True, **dataset_kwargs)
    test_set = flow.utils.vision.datasets.CIFAR10(train=False, **dataset_kwargs)

    train_loader = flow.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = flow.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
def load_data_mnist(
    batch_size,
    resize=None,
    root="./data/mnist",
    download=True,
    source_url=None,
    num_workers=0,
):
    """Download the MNIST dataset and build train/test data loaders.

    Args:
        batch_size: Batch size handed to both ``DataLoader`` instances.
        resize: If truthy, a ``transforms.Resize(resize)`` step is applied
            before ``ToTensor``.
        root: Dataset root directory; a leading ``~`` is expanded.
        download: Forwarded to the dataset's ``download`` argument.
        source_url: Forwarded to the dataset's ``source_url`` argument.
        num_workers: Forwarded to both ``DataLoader`` instances. Added so this
            loader exposes the same knob as ``load_data_cifar10`` and
            ``load_data_fashion_mnist``; defaults to ``0``.

    Returns:
        tuple: ``(train_iter, test_iter)``; the train loader shuffles, the
        test loader does not.
    """
    root = os.path.expanduser(root)

    steps = []
    if resize:
        steps.append(transforms.Resize(resize))
    steps.append(transforms.ToTensor())
    transformer = transforms.Compose(steps)

    mnist_train = flow.utils.vision.datasets.MNIST(
        root=root,
        train=True,
        transform=transformer,
        download=download,
        source_url=source_url,
    )
    mnist_test = flow.utils.vision.datasets.MNIST(
        root=root,
        train=False,
        transform=transformer,
        download=download,
        source_url=source_url,
    )

    train_iter = flow.utils.data.DataLoader(
        mnist_train, batch_size, shuffle=True, num_workers=num_workers
    )
    test_iter = flow.utils.data.DataLoader(
        mnist_test, batch_size, shuffle=False, num_workers=num_workers
    )
    return train_iter, test_iter
def get_fashion_mnist_dataset(
    resize=None, root="./data-test/fashion-mnist", download=True, source_url=None,
):
    """Build the Fashion-MNIST train and test datasets (no data loaders).

    Args:
        resize: If truthy, a ``transforms.Resize(resize)`` step is applied
            before ``ToTensor``.
        root: Dataset root directory; a leading ``~`` is expanded.
        download: Forwarded to the dataset's ``download`` argument.
        source_url: Forwarded to the dataset's ``source_url`` argument.

    Returns:
        tuple: ``(mnist_train, mnist_test)`` dataset objects.
    """
    expanded_root = os.path.expanduser(root)

    # Optional resize first, tensor conversion always last.
    pipeline = [transforms.Resize(resize)] if resize else []
    pipeline.append(transforms.ToTensor())
    composed = transforms.Compose(pipeline)

    def make_split(is_train):
        return flow.utils.vision.datasets.FashionMNIST(
            root=expanded_root,
            train=is_train,
            transform=composed,
            download=download,
            source_url=source_url,
        )

    train_split = make_split(True)
    test_split = make_split(False)
    return train_split, test_split
# reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch
def load_data_fashion_mnist(
    batch_size,
    resize=None,
    root="./data-test/fashion-mnist",
    download=True,
    source_url=None,
    num_workers=0,
):
    """Download the Fashion-MNIST dataset and build train/test data loaders.

    Dataset construction is delegated to ``get_fashion_mnist_dataset`` so the
    transform pipeline and dataset arguments are defined in one place only.

    Args:
        batch_size: Batch size handed to both ``DataLoader`` instances.
        resize: If truthy, images are resized before conversion to tensors.
        root: Dataset root directory; a leading ``~`` is expanded.
        download: Forwarded to the dataset's ``download`` argument.
        source_url: Forwarded to the dataset's ``source_url`` argument.
        num_workers: Forwarded to both ``DataLoader`` instances.

    Returns:
        tuple: ``(train_iter, test_iter)``; the train loader shuffles, the
        test loader does not.
    """
    mnist_train, mnist_test = get_fashion_mnist_dataset(
        resize=resize, root=root, download=download, source_url=source_url,
    )
    train_iter = flow.utils.data.DataLoader(
        mnist_train, batch_size, shuffle=True, num_workers=num_workers
    )
    test_iter = flow.utils.data.DataLoader(
        mnist_test, batch_size, shuffle=False, num_workers=num_workers
    )
    return train_iter, test_iter
...@@ -20,6 +20,7 @@ import oneflow.unittest ...@@ -20,6 +20,7 @@ import oneflow.unittest
import oneflow as flow import oneflow as flow
import oneflow.nn as nn import oneflow.nn as nn
import oneflow.optim as optim import oneflow.optim as optim
from data_utils import load_data_cifar10
classes = ( classes = (
...@@ -81,21 +82,19 @@ def test(test_case): ...@@ -81,21 +82,19 @@ def test(test_case):
os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "cifar10" os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "cifar10"
) )
trainset = flow.utils.vision.datasets.CIFAR10( train_iter, test_iter = load_data_cifar10(
root=data_dir, batch_size=batch_size,
train=True, data_dir=data_dir,
download=True, download=True,
transform=transform, transform=transform,
source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz", source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz",
) num_workers=0,
trainloader = flow.utils.data.DataLoader(
trainset, batch_size=batch_size, shuffle=False, num_workers=0
) )
final_loss = 0 final_loss = 0
for epoch in range(1, train_epoch + 1): # loop over the dataset multiple times for epoch in range(1, train_epoch + 1): # loop over the dataset multiple times
running_loss = 0.0 running_loss = 0.0
for i, data in enumerate(trainloader, 1): for i, data in enumerate(train_iter, 1):
# get the inputs; data is a list of [inputs, labels] # get the inputs; data is a list of [inputs, labels]
inputs, labels = data inputs, labels = data
inputs = inputs.to(dtype=flow.float32, device=device) inputs = inputs.to(dtype=flow.float32, device=device)
...@@ -130,10 +129,3 @@ class TestCifarDataset(flow.unittest.TestCase): ...@@ -130,10 +129,3 @@ class TestCifarDataset(flow.unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# 1 epoch training log
# epoch: 1 step: 2000 loss: 2.107
# epoch: 1 step: 4000 loss: 1.838
# epoch: 1 step: 6000 loss: 1.644
# epoch: 1 step: 8000 loss: 1.535
# epoch: 1 step: 10000 loss: 1.528
# epoch: 1 step: 12000 loss: 1.476
...@@ -20,42 +20,7 @@ import time ...@@ -20,42 +20,7 @@ import time
import oneflow.unittest import oneflow.unittest
import oneflow as flow import oneflow as flow
import oneflow.nn as nn import oneflow.nn as nn
from data_utils import load_data_fashion_mnist
# reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch
def load_data_fashion_mnist(
batch_size, resize=None, root="./data/fashion-mnist", download=True, source_url=None
):
"""Download the Fashion-MNIST dataset and then load into memory."""
root = os.path.expanduser(root)
transformer = []
if resize:
transformer += [flow.utils.vision.transforms.Resize(resize)]
transformer += [flow.utils.vision.transforms.ToTensor()]
transformer = flow.utils.vision.transforms.Compose(transformer)
mnist_train = flow.utils.vision.datasets.FashionMNIST(
root=root,
train=True,
transform=transformer,
download=download,
source_url=source_url,
)
mnist_test = flow.utils.vision.datasets.FashionMNIST(
root=root,
train=False,
transform=transformer,
download=download,
source_url=source_url,
)
num_workers = 0
train_iter = flow.utils.data.DataLoader(
mnist_train, batch_size, shuffle=True, num_workers=num_workers
)
test_iter = flow.utils.data.DataLoader(
mnist_test, batch_size, shuffle=False, num_workers=num_workers
)
return train_iter, test_iter
def get_fashion_mnist_labels(labels): def get_fashion_mnist_labels(labels):
...@@ -124,7 +89,7 @@ def test(test_case): ...@@ -124,7 +89,7 @@ def test(test_case):
) )
source_url = "https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/" source_url = "https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/Fashion-MNIST/"
train_iter, test_iter = load_data_fashion_mnist( train_iter, test_iter = load_data_fashion_mnist(
batch_size, root=data_dir, download=True, source_url=source_url batch_size, resize=None, root=data_dir, download=True, source_url=source_url
) )
loss = nn.CrossEntropyLoss() loss = nn.CrossEntropyLoss()
loss.to(device) loss.to(device)
...@@ -174,6 +139,3 @@ class TestFashionMnistDataset(flow.unittest.TestCase): ...@@ -174,6 +139,3 @@ class TestFashionMnistDataset(flow.unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# 1 epoch training log
# epoch 1, loss 0.0034, train acc 0.718, test acc 0.771, cost >>>>>>> 158.32699990272522(s)
# epoch 2, loss 0.0022, train acc 0.807, test acc 0.726, cost >>>>>>> 159.64465260505676(s)
...@@ -20,6 +20,7 @@ import unittest ...@@ -20,6 +20,7 @@ import unittest
import oneflow as flow import oneflow as flow
import oneflow.nn as nn import oneflow.nn as nn
import oneflow.unittest import oneflow.unittest
from data_utils import load_data_fashion_mnist
# reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.5_lenet # reference: http://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.5_lenet
...@@ -49,46 +50,6 @@ class LeNet(nn.Module): ...@@ -49,46 +50,6 @@ class LeNet(nn.Module):
return output return output
def load_data_fashion_mnist(
batch_size,
resize=None,
root="./data-test/fashion-mnist",
download=True,
source_url=None,
num_workers=0,
):
"""Download the Fashion-MNIST dataset and then load into memory."""
root = os.path.expanduser(root)
trans = []
if resize:
trans.append(flow.utils.vision.transforms.Resize(resize))
trans.append(flow.utils.vision.transforms.ToTensor())
transform = flow.utils.vision.transforms.Compose(trans)
mnist_train = flow.utils.vision.datasets.FashionMNIST(
root=root,
train=True,
transform=transform,
download=download,
source_url=source_url,
)
mnist_test = flow.utils.vision.datasets.FashionMNIST(
root=root,
train=False,
transform=transform,
download=download,
source_url=source_url,
)
train_iter = flow.utils.data.DataLoader(
mnist_train, batch_size, shuffle=True, num_workers=num_workers
)
test_iter = flow.utils.data.DataLoader(
mnist_test, batch_size, shuffle=False, num_workers=num_workers
)
return train_iter, test_iter
def evaluate_accuracy(data_iter, net, device=None): def evaluate_accuracy(data_iter, net, device=None):
if device is None and isinstance(net, nn.Module): if device is None and isinstance(net, nn.Module):
device = list(net.parameters())[0].device device = list(net.parameters())[0].device
...@@ -176,8 +137,3 @@ class TestLenet(flow.unittest.TestCase): ...@@ -176,8 +137,3 @@ class TestLenet(flow.unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# 1 epoch training log
# epoch 1, loss 1.1473, train acc 0.569, test acc 0.742, time 162.4 sec
# epoch 2, loss 0.5736, train acc 0.784, test acc 0.796, time 158.1 sec
# epoch 3, loss 0.4761, train acc 0.826, test acc 0.821, time 154.0 sec
# epoch 4, loss 0.4215, train acc 0.848, test acc 0.855, time 160.3 sec
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import oneflow.unittest
import oneflow as flow
import oneflow.nn as nn
import oneflow.utils.vision.transforms as transforms
from data_utils import load_data_mnist
# MNIST files live under the test cache directory, which can be overridden
# through the ONEFLOW_TEST_CACHE_DIR environment variable.
data_dir = os.path.join(
    os.environ.get("ONEFLOW_TEST_CACHE_DIR", "./data-test"), "mnist-dataset"
)
# NOTE: the loaders are created at import time and are read as module globals
# by test_train_and_eval below.
train_iter, test_iter = load_data_mnist(
    128,
    root=data_dir,
    download=True,
    source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist/MNIST/",
)
def evaluate_accuracy(data_iter, net, device=None):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    The network is moved to ``device``, switched to eval mode for the pass and
    put back into train mode before returning. Every image batch is flattened
    to ``(-1, 28 * 28)`` before being fed to the network.

    Args:
        data_iter: Iterable yielding ``(images, labels)`` batches.
        net: Module to evaluate.
        device: Device the network and the batches are moved to.

    Returns:
        Number of correct predictions divided by the number of samples seen.
    """
    net.to(device)
    net.eval()

    correct = 0.0
    seen = 0
    with flow.no_grad():
        for batch, targets in data_iter:
            flat = batch.reshape(-1, 28 * 28).to(device=device)
            targets = targets.to(device=device)
            predicted = net(flat).argmax(dim=1).numpy()
            correct += (predicted == targets.numpy()).sum()
            seen += flat.shape[0]

    net.train()
    return correct / seen
class Net(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear -> ReLU -> Linear."""

    def __init__(
        self, input_size=784, hidden_size1=128, hidden_size2=64, num_classes=10
    ):
        """Create the three linear layers and the two ReLU activations.

        Args:
            input_size: Number of input features of the first linear layer.
            hidden_size1: Output width of the first linear layer.
            hidden_size2: Output width of the second linear layer.
            num_classes: Output width of the last linear layer.
        """
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size1)
        self.relu1 = nn.ReLU()
        self.l2 = nn.Linear(hidden_size1, hidden_size2)
        self.relu2 = nn.ReLU()
        self.l3 = nn.Linear(hidden_size2, num_classes)

    def forward(self, x):
        """Run ``x`` through the layers in order and return the last layer's output."""
        hidden = self.relu1(self.l1(x))
        hidden = self.relu2(self.l2(hidden))
        return self.l3(hidden)
def test_train_and_eval(test_case):
    """Train ``Net`` briefly on MNIST and print train/test accuracy.

    Runs on CPU when ``ONEFLOW_TEST_CPU_ONLY`` is set, otherwise on CUDA.
    Each epoch stops once more than 2000 training samples have been consumed.

    Args:
        test_case: The calling ``TestCase``; currently unused because the
            accuracy assertion is disabled.
    """
    device = flow.device("cpu" if os.getenv("ONEFLOW_TEST_CPU_ONLY") else "cuda")

    model = Net()
    model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = flow.optim.SGD(model.parameters(), lr=0.10)

    num_epochs = 1
    for epoch in range(num_epochs):
        train_loss = 0.0
        n_correct = 0.0
        n_samples = 0
        for images, labels in train_iter:
            # Flatten every image before feeding the fully connected net.
            images = images.reshape(-1, 28 * 28).to(device=device)
            labels = labels.to(device=device)

            logits = model(images)
            batch_loss = criterion(logits, labels).sum()

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            train_loss += batch_loss.numpy()
            n_correct += (logits.argmax(dim=1).numpy() == labels.numpy()).sum()
            n_samples += images.shape[0]
            # Keep the test short: stop after a little over 2000 samples.
            if n_samples > 2000:
                break

        test_acc = evaluate_accuracy(test_iter, model, device)
        train_acc = n_correct / n_samples
        print(
            "epoch %d, train loss %.4f, train acc %.3f, test acc %.3f"
            % (epoch + 1, train_loss / n_samples, train_acc, test_acc)
        )
        # test_case.assertLess(0.8, test_acc)
@flow.unittest.skip_unless_1n1d()
class TestMnistDataset(flow.unittest.TestCase):
    """unittest entry point that runs the MNIST train/eval routine."""

    def test_mnist_dataset(test_case):
        """Delegate to ``test_train_and_eval`` with this test case."""
        test_train_and_eval(test_case)


# Allow running this file directly as a script.
if __name__ == "__main__":
    unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import oneflow as flow
import oneflow.nn as nn
import oneflow.optim as optim
import oneflow.utils.vision.transforms as transforms
import oneflow.unittest
from data_utils import load_data_cifar10
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers."""

    def __init__(self):
        """Create the convolution, pooling and linear layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the output of ``fc3`` for the input batch ``x``."""
        relu = flow.F.relu
        features = self.pool(relu(self.conv1(x)))
        features = self.pool(relu(self.conv2(features)))
        # Flatten all dimensions except the batch dimension.
        flat = flow.flatten(features, 1)
        hidden = relu(self.fc1(flat))
        hidden = relu(self.fc2(hidden))
        return self.fc3(hidden)
def test(test_case):
    """Train ``Net`` on CIFAR10 for one epoch and print the running loss.

    Runs on CPU when ``ONEFLOW_TEST_CPU_ONLY`` is set, otherwise on CUDA. The
    mean loss over each window of 2000 mini-batches is printed, and the last
    such value is printed again once training finishes.

    Args:
        test_case: The calling ``TestCase``; currently unused because the loss
            assertion is disabled.
    """
    device = flow.device("cpu" if os.getenv("ONEFLOW_TEST_CPU_ONLY") else "cuda")

    net = Net()
    net.to(device)

    optimizer = optim.SGD(net.parameters(), lr=0.002, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)

    augmentations = [
        transforms.Pad(10),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.CenterCrop(32),
        transforms.Resize([32, 32]),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    transform = flow.utils.vision.transforms.Compose(augmentations)

    train_epoch = 1
    batch_size = 4
    log_interval = 2000  # report the mean loss every 2000 mini-batches

    cache_root = os.getenv("ONEFLOW_TEST_CACHE_DIR", "./data-test")
    data_dir = os.path.join(cache_root, "cifar10")

    train_iter, test_iter = load_data_cifar10(
        batch_size=batch_size,
        data_dir=data_dir,
        download=True,
        transform=transform,
        source_url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/cifar/cifar-10-python.tar.gz",
        num_workers=0,
    )

    final_loss = 0
    for epoch in range(1, train_epoch + 1):
        running_loss = 0.0
        # Each item yielded by the loader is an (inputs, labels) pair.
        for step, (inputs, labels) in enumerate(train_iter, 1):
            inputs = inputs.to(dtype=flow.float32, device=device)
            labels = labels.to(dtype=flow.int64, device=device)

            # Reset gradients, then forward + backward + optimizer step.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.numpy()
            if step % log_interval != 0:
                continue
            final_loss = running_loss / 2000
            print("epoch: %d step: %5d loss: %.3f " % (epoch, step, final_loss))
            running_loss = 0.0

    print("final loss : ", final_loss)
    # test_case.assertLess(final_loss, 1.79)
@flow.unittest.skip_unless_1n1d()
class TestCifarDataset(flow.unittest.TestCase):
    """unittest entry point that runs the CIFAR10 training routine."""

    def test_cifar_dataset(test_case):
        """Delegate to the module-level ``test`` function with this test case."""
        test(test_case)


# Allow running this file directly as a script.
if __name__ == "__main__":
    unittest.main()
...@@ -35,6 +35,7 @@ from oneflow.utils.data.decorator import ( ...@@ -35,6 +35,7 @@ from oneflow.utils.data.decorator import (
guaranteed_datapipes_determinism, guaranteed_datapipes_determinism,
non_deterministic, non_deterministic,
) )
from oneflow.utils.data.distributed import DistributedSampler
__all__ = [ __all__ = [
...@@ -55,4 +56,5 @@ __all__ = [ ...@@ -55,4 +56,5 @@ __all__ = [
"functional_datapipe", "functional_datapipe",
"guaranteed_datapipes_determinism", "guaranteed_datapipes_determinism",
"non_deterministic", "non_deterministic",
"DistributedSampler",
] ]
...@@ -163,9 +163,7 @@ class DataLoader(Generic[T_co]): ...@@ -163,9 +163,7 @@ class DataLoader(Generic[T_co]):
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
cannot be an unpicklable object, e.g., a lambda function. See cannot be an unpicklable object, e.g., a lambda function.
:ref:`multiprocessing-best-practices` on more details related
to multiprocessing in OneFlow.
.. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
When :attr:`dataset` is an :class:`~flow.utils.data.IterableDataset`, When :attr:`dataset` is an :class:`~flow.utils.data.IterableDataset`,
...@@ -181,12 +179,6 @@ class DataLoader(Generic[T_co]): ...@@ -181,12 +179,6 @@ class DataLoader(Generic[T_co]):
dropped when :attr:`drop_last` is set. Unfortunately, OneFlow can not detect such dropped when :attr:`drop_last` is set. Unfortunately, OneFlow can not detect such
cases in general. cases in general.
See `Dataset Types`_ for more details on these two types of datasets and how
:class:`~flow.utils.data.IterableDataset` interacts with
`Multi-process data loading`_.
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
:ref:`data-loading-randomness` notes for random seed related questions.
""" """
dataset: Dataset[T_co] dataset: Dataset[T_co]
batch_size: Optional[int] batch_size: Optional[int]
......
...@@ -195,7 +195,6 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]): ...@@ -195,7 +195,6 @@ class TensorDataset(Dataset[Tuple[Tensor, ...]]):
Args: Args:
*tensors (Tensor): tensors that have the same size of the first dimension. *tensors (Tensor): tensors that have the same size of the first dimension.
""" """
tensors: Tuple[Tensor, ...]
def __init__(self, *tensors: Tensor) -> None: def __init__(self, *tensors: Tensor) -> None:
assert all( assert all(
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
import numpy as np
from typing import TypeVar, Optional, Iterator
import oneflow as flow
import oneflow.distributed as dist
from oneflow.utils.data import Sampler, Dataset
T_co = TypeVar("T_co", covariant=True)
class DistributedSampler(Sampler[T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`flow.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~flow.utils.data.DistributedSampler` instance as a
    :class:`~flow.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.

    Raises:
        RuntimeError: If the multi-client environment is not available.
        ValueError: If ``rank`` is outside ``[0, num_replicas - 1]``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will be always used.

    For example:

    .. code-block:: python

        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None), sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        if not dist.is_multi_client():
            raise RuntimeError("Requires multi-client env to be available")
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1)
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        """Yield the dataset indices assigned to this rank for the current epoch."""
        dataset_len = len(self.dataset)
        if self.shuffle:
            # Deterministically shuffle based on epoch and seed. Every rank must
            # draw the *same* permutation, otherwise the strided sub-sampling
            # below no longer yields disjoint subsets; the process-global NumPy
            # RNG cannot guarantee that, so a locally seeded one is used.
            # TODO: replace with flow.randperm once it can be seeded here.
            rng = np.random.RandomState((self.seed + self.epoch) % (2 ** 32))
            indices = rng.permutation(dataset_len).tolist()
        else:
            indices = list(range(dataset_len))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[
                    :padding_size
                ]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample: this rank takes every ``num_replicas``-th index
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of indices this rank yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
...@@ -15,3 +15,28 @@ limitations under the License. ...@@ -15,3 +15,28 @@ limitations under the License.
""" """
from oneflow.utils.vision import datasets from oneflow.utils.vision import datasets
from oneflow.utils.vision import transforms from oneflow.utils.vision import transforms
# Name of the package currently used to load images.
_image_backend = "PIL"


def set_image_backend(backend):
    """
    Specifies the package used to load images.

    Args:
        backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
            The :mod:`accimage` package uses the Intel IPP library. It is
            generally faster than PIL, but does not support as many operations.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    global _image_backend
    supported = ("PIL", "accimage")
    if backend in supported:
        _image_backend = backend
        return
    raise ValueError(
        "Invalid backend '{}'. Options are 'PIL' and 'accimage'".format(backend)
    )


def get_image_backend():
    """
    Gets the name of the package used to load images
    """
    return _image_backend
...@@ -15,5 +15,21 @@ limitations under the License. ...@@ -15,5 +15,21 @@ limitations under the License.
""" """
from .mnist import MNIST, FashionMNIST from .mnist import MNIST, FashionMNIST
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
from .coco import CocoCaptions, CocoDetection
from .imagenet import ImageNet
from .voc import VOCDetection, VOCSegmentation
from .folder import DatasetFolder, ImageFolder
__all__ = ["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100"] __all__ = [
"MNIST",
"FashionMNIST",
"CIFAR10",
"CIFAR100",
"CocoCaptions",
"CocoDetection",
"ImageNet",
"VOCDetection",
"VOCSegmentation",
"DatasetFolder",
"ImageFolder",
]
...@@ -25,8 +25,10 @@ from .utils import check_integrity, download_and_extract_archive ...@@ -25,8 +25,10 @@ from .utils import check_integrity, download_and_extract_archive
class CIFAR10(VisionDataset): class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. r""" `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True. ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise train (bool, optional): If True, creates dataset from training set, otherwise
...@@ -39,7 +41,6 @@ class CIFAR10(VisionDataset): ...@@ -39,7 +41,6 @@ class CIFAR10(VisionDataset):
puts it in root directory. If dataset is already downloaded, it is not puts it in root directory. If dataset is already downloaded, it is not
downloaded again. downloaded again.
""" """
base_folder = "cifar-10-batches-py" base_folder = "cifar-10-batches-py"
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz" filename = "cifar-10-python.tar.gz"
...@@ -128,6 +129,7 @@ class CIFAR10(VisionDataset): ...@@ -128,6 +129,7 @@ class CIFAR10(VisionDataset):
""" """
Args: Args:
index (int): Index index (int): Index
Returns: Returns:
tuple: (image, target) where target is index of the target class. tuple: (image, target) where target is index of the target class.
""" """
...@@ -170,10 +172,10 @@ class CIFAR10(VisionDataset): ...@@ -170,10 +172,10 @@ class CIFAR10(VisionDataset):
class CIFAR100(CIFAR10): class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. r""" `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
This is a subclass of the `CIFAR10` Dataset. This is a subclass of the `CIFAR10` Dataset.
""" """
base_folder = "cifar-100-python" base_folder = "cifar-100-python"
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz" filename = "cifar-100-python.tar.gz"
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from PIL import Image
import os
import os.path
from typing import Any, Callable, Optional, Tuple, List
from .vision import VisionDataset
class CocoDetection(VisionDataset):
    r"""`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    Images are addressed through the sorted list of image ids found in the
    annotation file; each item is an ``(image, annotations)`` pair.

    Args:
        root (string): Directory that holds the image files.
        annFile (string): Path to the json annotation file.
        transform (callable, optional): Applied to the PIL image,
            e.g. ``transforms.ToTensor``.
        target_transform (callable, optional): Applied to the target.
        transforms (callable, optional): Applied jointly to the
            ``(image, target)`` pair and returns the transformed pair.
    """

    def __init__(
        self,
        root: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)
        # pycocotools is imported lazily so it is only needed when this
        # dataset is actually instantiated.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        # Sorted ids give a deterministic index -> image mapping.
        self.ids = sorted(self.coco.imgs.keys())

    def _load_image(self, id: int) -> Image.Image:
        """Open the image registered under ``id`` and convert it to RGB."""
        image_info = self.coco.loadImgs(id)[0]
        full_path = os.path.join(self.root, image_info["file_name"])
        return Image.open(full_path).convert("RGB")

    def _load_target(self, id) -> List[Any]:
        """Return every annotation record attached to image ``id``."""
        annotation_ids = self.coco.getAnnIds(id)
        return self.coco.loadAnns(annotation_ids)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair stored at position ``index``."""
        image_id = self.ids[index]
        sample = self._load_image(image_id)
        annotations = self._load_target(image_id)

        if self.transforms is None:
            return sample, annotations
        sample, annotations = self.transforms(sample, annotations)
        return sample, annotations

    def __len__(self) -> int:
        """Number of images listed in the annotation file."""
        return len(self.ids)
class CocoCaptions(CocoDetection):
    r"""`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    Same loading logic as :class:`CocoDetection`, but the target of each
    sample is the list of caption strings of the image.

    Args:
        root (string): Directory that holds the image files.
        annFile (string): Path to the json annotation file.
        transform (callable, optional): Applied to the PIL image,
            e.g. ``transforms.ToTensor``.
        target_transform (callable, optional): Applied to the target.
        transforms (callable, optional): Applied jointly to the
            ``(image, target)`` pair and returns the transformed pair.

    Example:

        .. code:: python

            import oneflow.utils.vision.datasets as dset
            import oneflow.utils.vision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.ToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3] # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id) -> List[str]:
        """Return only the ``"caption"`` field of each annotation of ``id``."""
        annotations = super()._load_target(id)
        captions = []
        for annotation in annotations:
            captions.append(annotation["caption"])
        return captions
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import os.path
from PIL import Image
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from .vision import VisionDataset
def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Tell whether ``filename`` ends with one of the given extensions.

    The comparison is done on the lower-cased file name, so ``extensions``
    are expected to be lowercase.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    lowered_name = filename.lower()
    return lowered_name.endswith(extensions)
def is_image_file(filename: str) -> bool:
    """Tell whether ``filename`` carries one of the ``IMG_EXTENSIONS``.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    known_extensions = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, known_extensions)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Collect the class folders found directly inside ``directory``.

    Every sub-directory is treated as one class; class indices follow the
    alphabetical order of the folder names.
    See :class:`DatasetFolder` for details.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.
    """
    class_names = []
    for entry in os.scandir(directory):
        if entry.is_dir():
            class_names.append(entry.name)
    class_names.sort()

    if not class_names:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    index_of_class = dict(zip(class_names, range(len(class_names))))
    return class_names, index_of_class
def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.

    Args:
        directory (str): Root dataset directory (``~`` is expanded).
        class_to_idx (Dict[str, int], optional): Mapping from class folder name
            to class index. When None, it is computed with ``find_classes``.
        extensions (tuple of strings, optional): Allowed file extensions.
        is_valid_file (callable, optional): Predicate that receives the path of
            a candidate file (``os.path.join(root, fname)``) and returns whether
            it should be part of the dataset.

    Raises:
        ValueError: If ``class_to_idx`` is an empty mapping.
        ValueError: If ``extensions`` and ``is_valid_file`` are both None or
            both given.
        FileNotFoundError: If at least one class has no valid file.

    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError(
            "'class_to_index' must have at least one entry to collect any samples."
        )

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError(
            "Both extensions and is_valid_file cannot be None or not None at the same time"
        )

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                # Validate the path, not the bare file name: a validator that
                # inspects the file (e.g. to reject corrupt images) needs a
                # path it can actually open. Extension checks are unaffected
                # because the path ends with the file name.
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = (
            f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        )
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances
class DatasetFolder(VisionDataset):
    r"""A generic data loader for samples grouped in one folder per class.

    The expected directory layout is described in :meth:`find_classes`; it can
    be customized by overriding that method.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        # Both steps go through the (overridable) methods of this class.
        found_classes, found_mapping = self.find_classes(self.root)
        found_samples = self.make_dataset(
            self.root, found_mapping, extensions, is_valid_file
        )

        self.loader = loader
        self.extensions = extensions
        self.classes = found_classes
        self.class_to_idx = found_mapping
        self.samples = found_samples
        self.targets = [class_index for _, class_index in found_samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check of corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.

        Raises:
            ValueError: In case ``class_to_idx`` is None or empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is not None:
            return make_dataset(
                directory,
                class_to_idx,
                extensions=extensions,
                is_valid_file=is_valid_file,
            )
        # With None, the module-level make_dataset() would fall back to the
        # module-level find_classes() and bypass a find_classes() method that
        # a subclass may have overridden with different logic.
        raise ValueError("The class_to_idx parameter cannot be None.")

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows:

        .. code-block:: shell

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``directory`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        sample_path, class_index = self.samples[index]
        loaded = self.loader(sample_path)

        transform = self.transform
        if transform is not None:
            loaded = transform(loaded)

        target_transform = self.target_transform
        if target_transform is not None:
            class_index = target_transform(class_index)

        return loaded, class_index

    def __len__(self) -> int:
        """Number of samples collected by :meth:`make_dataset`."""
        return len(self.samples)
# Lower-case file suffixes treated as images by ``is_image_file`` and
# ``ImageFolder``. The order is kept stable because it is echoed in messages.
IMG_EXTENSIONS = (
    ".jpg", ".jpeg",
    ".png",
    ".ppm", ".bmp", ".pgm",
    ".tif", ".tiff",
    ".webp",
)
def pil_loader(path: str) -> Image.Image:
    """Load the file at ``path`` with PIL and return it converted to RGB."""
    # The file is opened explicitly (instead of passing the path to PIL) to
    # avoid a ResourceWarning, see
    # https://github.com/python-pillow/Pillow/issues/835
    with open(path, "rb") as stream:
        return Image.open(stream).convert("RGB")
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    """Load ``path`` with accimage, falling back to :func:`pil_loader`."""
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image
def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching the configured image backend.

    ``accimage_loader`` is used when the backend is ``"accimage"``; every
    other backend value goes through ``pil_loader``.
    """
    from oneflow.utils.vision import get_image_backend

    use_accimage = get_image_backend() == "accimage"
    loader = accimage_loader if use_accimage else pil_loader
    return loader(path)
class ImageFolder(DatasetFolder):
    r"""A generic data loader where the images are arranged in this way by default:

    .. code-block:: shell

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~vision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Extensions and a custom validator are mutually exclusive: the
        # default image extensions only apply when no validator is supplied.
        if is_valid_file is None:
            allowed_extensions = IMG_EXTENSIONS
        else:
            allowed_extensions = None

        super().__init__(
            root,
            loader,
            allowed_extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # ``imgs`` is an alias of ``samples`` (same list object).
        self.imgs = self.samples
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import warnings
from contextlib import contextmanager
import os
import shutil
import tempfile
from typing import Any, Dict, List, Iterator, Optional, Tuple
import oneflow as flow
from .folder import ImageFolder
from .utils import check_integrity, extract_archive, verify_str_arg
# Default archive file name and its MD5 checksum, keyed by archive role.
ARCHIVE_META = dict(
    train=("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
    val=("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
    devkit=("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
)

# Name of the file, stored in the dataset root, that holds the parsed devkit
# meta information.
META_FILE = "meta.bin"
class ImageNet(ImageFolder):
    r""" `ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (string): Root directory of the ImageNet Dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        download (bool, optional): Deprecated flag. ``True`` raises a
            ``RuntimeError``, ``False`` emits a ``RuntimeWarning``; leave it
            at ``None`` (the default) to avoid both.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        download: Optional[bool] = None,
        **kwargs: Any
    ) -> None:
        # ``download`` is only inspected for identity with True / False, hence
        # the bool annotation; any other value is ignored.
        if download is True:
            msg = (
                "The dataset is no longer publicly accessible. You need to "
                "download the archives externally and place them in the root "
                "directory."
            )
            raise RuntimeError(msg)
        elif download is False:
            msg = (
                "The use of the download flag is deprecated, since the dataset "
                "is no longer publicly accessible."
            )
            warnings.warn(msg, RuntimeWarning)

        # self.root and self.split must exist before parse_archives() and the
        # split_folder property are used below.
        root = self.root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        wnid_to_classes = load_meta_file(self.root)[0]

        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        # The parent constructor was handed the split folder as its root;
        # restore the dataset root afterwards.
        self.root = root

        # The folder names found by the parent are WordNet IDs; expose them
        # separately and replace ``classes`` by the human readable names.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        # Each entry of ``classes`` is a tuple of names; every name of the
        # tuple maps to the same class index.
        self.class_to_idx = {
            cls: idx for idx, clss in enumerate(self.classes) for cls in clss
        }

    def parse_archives(self) -> None:
        """Parse the devkit / image archives that have not been prepared yet."""
        if not check_integrity(os.path.join(self.root, META_FILE)):
            parse_devkit_archive(self.root)

        if not os.path.isdir(self.split_folder):
            if self.split == "train":
                parse_train_archive(self.root)
            elif self.split == "val":
                parse_val_archive(self.root)

    @property
    def split_folder(self) -> str:
        """Directory of the selected split, i.e. ``<root>/<split>``."""
        return os.path.join(self.root, self.split)

    def extra_repr(self) -> str:
        """Return the ``"Split: <split>"`` line describing this dataset."""
        return "Split: {split}".format(**self.__dict__)
def load_meta_file(
    root: str, file: Optional[str] = None
) -> Tuple[Dict[str, str], List[str]]:
    """Load the meta information stored in the dataset root.

    Args:
        root (str): Directory that contains the meta file.
        file (str, optional): Name of the meta file. Defaults to ``META_FILE``.

    Raises:
        RuntimeError: If ``check_integrity`` rejects the meta file.

    Returns:
        The object read from the meta file with ``flow.load``.
    """
    file = os.path.join(root, META_FILE if file is None else file)

    if not check_integrity(file):
        msg = (
            "The meta file {} is not present in the root directory or is corrupted. "
            "This file is automatically created by the ImageNet dataset."
        )
        raise RuntimeError(msg.format(file, root))

    return flow.load(file)
def _verify_archive(root: str, file: str, md5: str) -> None:
    """Raise ``RuntimeError`` unless ``root/file`` passes the MD5 integrity check."""
    archive_path = os.path.join(root, file)
    if check_integrity(archive_path, md5):
        return

    msg = (
        "The archive {} is not present in the root directory or is corrupted. "
        "You need to download it externally and place it in {}."
    )
    raise RuntimeError(msg.format(file, root))
def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

    The archive is verified, extracted into a temporary directory and parsed;
    the resulting ``(wnid_to_classes, val_wnids)`` pair is written to
    ``META_FILE`` inside ``root``.

    Args:
        root (str): Root directory containing the devkit archive
        file (str, optional): Name of devkit archive. Defaults to
            'ILSVRC2012_devkit_t12.tar.gz'
    """
    import scipy.io as sio

    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
        mat_path = os.path.join(devkit_root, "data", "meta.mat")
        synsets = sio.loadmat(mat_path, squeeze_me=True)["synsets"]
        # Field 4 of every record is its number of children; only records
        # without children are kept.
        children_counts = list(zip(*synsets))[4]
        leaves = [
            synsets[position]
            for position, count in enumerate(children_counts)
            if count == 0
        ]
        idcs, wnids, names = list(zip(*leaves))[:3]
        # A record's names arrive as one ", "-separated string.
        name_tuples = [tuple(joined.split(", ")) for joined in names]
        idx_to_wnid = dict(zip(idcs, wnids))
        wnid_to_classes = dict(zip(wnids, name_tuples))
        return idx_to_wnid, wnid_to_classes

    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
        truth_path = os.path.join(
            devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt"
        )
        with open(truth_path, "r") as txtfh:
            return [int(line) for line in txtfh]

    @contextmanager
    def get_tmp_dir() -> Iterator[str]:
        # Scratch directory that is removed again when the block exits.
        scratch = tempfile.mkdtemp()
        try:
            yield scratch
        finally:
            shutil.rmtree(scratch)

    default_name, md5 = ARCHIVE_META["devkit"]
    if file is None:
        file = default_name

    _verify_archive(root, file, md5)

    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)

        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root)
        val_wnids = [
            idx_to_wnid[val_idx] for val_idx in parse_val_groundtruth_txt(devkit_root)
        ]

        flow.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
def parse_train_archive(
    root: str, file: Optional[str] = None, folder: str = "train"
) -> None:
    """Parse the train images archive of the ImageNet2012 classification dataset and
    prepare it for usage with the ImageNet dataset.

    The outer archive is verified and extracted into ``root/folder``; every
    entry found there is then extracted into a directory named after the entry
    without its extension, and the entry itself is removed.

    Args:
        root (str): Root directory containing the train images archive
        file (str, optional): Name of train images archive. Defaults to
            'ILSVRC2012_img_train.tar'
        folder (str, optional): Optional name for train images folder. Defaults to
            'train'
    """
    default_name, md5 = ARCHIVE_META["train"]
    if file is None:
        file = default_name

    _verify_archive(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # The listing is taken once, before any inner archive is unpacked.
    for entry in os.listdir(train_root):
        inner_archive = os.path.join(train_root, entry)
        destination = os.path.splitext(inner_archive)[0]
        extract_archive(inner_archive, destination, remove_finished=True)
def parse_val_archive(
    root: str,
    file: Optional[str] = None,
    wnids: Optional[List[str]] = None,
    folder: str = "val",
) -> None:
    """Parse the validation images archive of the ImageNet2012 classification dataset
    and prepare it for usage with the ImageNet dataset.

    After extraction, the images (sorted by path) are paired one-to-one with
    ``wnids`` and each image is moved into the sub-folder of its WordNet ID.

    Args:
        root (str): Root directory containing the validation images archive
        file (str, optional): Name of validation images archive. Defaults to
            'ILSVRC2012_img_val.tar'
        wnids (list, optional): List of WordNet IDs of the validation images. If None
            is given, the IDs are loaded from the meta file in the root directory
        folder (str, optional): Optional name for validation images folder. Defaults to
            'val'
    """
    default_name, md5 = ARCHIVE_META["val"]
    if file is None:
        file = default_name
    if wnids is None:
        wnids = load_meta_file(root)[1]

    _verify_archive(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    image_paths = [os.path.join(val_root, name) for name in os.listdir(val_root)]
    image_paths.sort()

    # One directory per distinct WordNet ID.
    for unique_wnid in set(wnids):
        os.mkdir(os.path.join(val_root, unique_wnid))

    for wnid, image_path in zip(wnids, image_paths):
        destination = os.path.join(val_root, wnid, os.path.basename(image_path))
        shutil.move(image_path, destination)
...@@ -29,7 +29,8 @@ from oneflow.framework.tensor import Tensor ...@@ -29,7 +29,8 @@ from oneflow.framework.tensor import Tensor
class MNIST(VisionDataset): class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset. r""" `MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where ``MNIST/processed/training.pt`` root (string): Root directory of dataset where ``MNIST/processed/training.pt``
and ``MNIST/processed/test.pt`` exist. and ``MNIST/processed/test.pt`` exist.
...@@ -43,7 +44,6 @@ class MNIST(VisionDataset): ...@@ -43,7 +44,6 @@ class MNIST(VisionDataset):
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
""" """
mirrors = [ mirrors = [
"http://yann.lecun.com/exdb/mnist/", "http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/", "https://ossci-datasets.s3.amazonaws.com/mnist/",
...@@ -222,7 +222,8 @@ class MNIST(VisionDataset): ...@@ -222,7 +222,8 @@ class MNIST(VisionDataset):
class FashionMNIST(MNIST): class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset. r""" `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where ``FashionMNIST/processed/training.pt`` root (string): Root directory of dataset where ``FashionMNIST/processed/training.pt``
and ``FashionMNIST/processed/test.pt`` exist. and ``FashionMNIST/processed/test.pt`` exist.
...@@ -236,7 +237,6 @@ class FashionMNIST(MNIST): ...@@ -236,7 +237,6 @@ class FashionMNIST(MNIST):
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
""" """
mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
resources = [ resources = [
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import warnings
import collections
from xml.etree.ElementTree import Element as ET_Element
try:
from defusedxml.ElementTree import parse as ET_parse
except ImportError:
from xml.etree.ElementTree import parse as ET_parse
from PIL import Image
from typing import Any, Callable, Dict, Optional, Tuple, List
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
# Download metadata per dataset key: source URL, local archive file name,
# MD5 checksum of the archive and the directory (relative to the dataset
# root) that the extracted data lives in.
# Every "filename" is the last component of its "url" so that archives of
# different years never overwrite each other inside a shared root.
DATASET_YEAR_DICT = {
    "2012": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
        "filename": "VOCtrainval_11-May-2012.tar",
        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
    },
    "2011": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
        "filename": "VOCtrainval_25-May-2011.tar",
        "md5": "6c3384ef61512963050cb5d687e5bf1e",
        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
    },
    "2010": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
        "filename": "VOCtrainval_03-May-2010.tar",
        "md5": "da459979d0c395079b5c75ee67908abb",
        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
    },
    "2009": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
        "filename": "VOCtrainval_11-May-2009.tar",
        "md5": "59065e4b188729180974ef6572f6a212",
        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
    },
    "2008": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
        # Previously stored under the 2012 archive name, which made the 2008
        # and 2012 downloads collide on disk.
        "filename": "VOCtrainval_14-Jul-2008.tar",
        "md5": "2629fa636546599198acfcfbfcf1904a",
        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
    },
    "2007": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
        "filename": "VOCtrainval_06-Nov-2007.tar",
        "md5": "c52e279531787c972589f7e41ab4ae64",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
    "2007-test": {
        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
        "filename": "VOCtest_06-Nov-2007.tar",
        "md5": "b6e924de25625d8de591ea690078ad9f",
        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
    },
}
class _VOCBase(VisionDataset):
    """Common part of the Pascal VOC datasets.

    Resolves the archive metadata for the requested year / image set,
    optionally downloads the archive, and builds the parallel lists
    ``self.images`` and ``self.targets`` of file paths. Subclasses supply the
    three class attributes below and implement ``__getitem__``.
    """

    # Sub-directory of ``ImageSets`` holding the split files.
    _SPLITS_DIR: str
    # Directory (below the VOC root) holding the target files.
    _TARGET_DIR: str
    # File extension of the target files, including the dot.
    _TARGET_FILE_EXT: str

    def __init__(
        self,
        root: str,
        year: str = "2012",
        image_set: str = "train",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        # Legacy spelling: year="2007-test" is only accepted together with
        # image_set="test" and is normalized to year="2007".
        if year == "2007-test":
            if image_set != "test":
                raise ValueError(
                    "In the test image set of the year 2007 only image_set='test' is allowed. "
                    "For all other image sets use year='2007' instead."
                )
            warnings.warn(
                "Acessing the test image set of the year 2007 with year='2007-test' is deprecated. "
                "Please use the combination year='2007' and image_set='test' instead."
            )
            year = "2007"
        self.year = year

        # Only the 2007 release offers a "test" image set.
        allowed_sets = ["train", "trainval", "val"]
        if year == "2007":
            allowed_sets.append("test")
        self.image_set = verify_str_arg(image_set, "image_set", allowed_sets)

        # The 2007 test data ships in its own archive.
        key = year
        if year == "2007" and image_set == "test":
            key = "2007-test"
        archive_info = DATASET_YEAR_DICT[key]

        self.url = archive_info["url"]
        self.filename = archive_info["filename"]
        self.md5 = archive_info["md5"]

        voc_root = os.path.join(self.root, archive_info["base_dir"])

        if download:
            download_and_extract_archive(
                self.url, self.root, filename=self.filename, md5=self.md5
            )

        if not os.path.isdir(voc_root):
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        # One sample name per line in the split file.
        split_file = os.path.join(
            voc_root, "ImageSets", self._SPLITS_DIR, image_set.rstrip("\n") + ".txt"
        )
        with open(split_file, "r") as handle:
            sample_names = [line.strip() for line in handle.readlines()]

        image_dir = os.path.join(voc_root, "JPEGImages")
        target_dir = os.path.join(voc_root, self._TARGET_DIR)

        self.images = [os.path.join(image_dir, name + ".jpg") for name in sample_names]
        self.targets = [
            os.path.join(target_dir, name + self._TARGET_FILE_EXT)
            for name in sample_names
        ]

        assert len(self.images) == len(self.targets)

    def __len__(self) -> int:
        """Number of samples listed in the split file."""
        return len(self.images)
class VOCSegmentation(_VOCBase):
    r""" `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    _SPLITS_DIR = "Segmentation"
    _TARGET_DIR = "SegmentationClass"
    _TARGET_FILE_EXT = ".png"

    @property
    def masks(self) -> List[str]:
        """Paths of the segmentation masks (alias of ``self.targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        image = Image.open(self.images[index]).convert("RGB")
        # The mask is returned as opened, without any mode conversion.
        mask = Image.open(self.masks[index])

        if self.transforms is None:
            return image, mask
        image, mask = self.transforms(image, mask)
        return image, mask
class VOCDetection(_VOCBase):
    r""" `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years ``"2007"`` to ``"2012"``.
        image_set (string, optional): Select the image_set to use, ``"train"``, ``"trainval"`` or ``"val"``. If
            ``year=="2007"``, can also be ``"test"``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    _SPLITS_DIR = "Main"
    _TARGET_DIR = "Annotations"
    _TARGET_FILE_EXT = ".xml"

    @property
    def annotations(self) -> List[str]:
        """Paths of the XML annotation files (alias of ``self.targets``)."""
        return self.targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a dictionary of the XML tree.
        """
        image = Image.open(self.images[index]).convert("RGB")
        xml_root = ET_parse(self.annotations[index]).getroot()
        annotation = self.parse_voc_xml(xml_root)

        if self.transforms is None:
            return image, annotation
        image, annotation = self.transforms(image, annotation)
        return image, annotation

    def parse_voc_xml(self, node: ET_Element) -> Dict[str, Any]:
        """Recursively convert an XML element into nested dictionaries.

        A leaf element with non-empty text becomes ``{tag: stripped_text}``.
        An element with children becomes ``{tag: {child_tag: value, ...}}``
        where values of repeated child tags are gathered in a list. The
        ``object`` entry of an ``annotation`` element is always a list.
        """
        parsed: Dict[str, Any] = {}
        child_nodes = list(node)

        if child_nodes:
            grouped: Dict[str, Any] = collections.defaultdict(list)
            for child in child_nodes:
                for tag, value in self.parse_voc_xml(child).items():
                    grouped[tag].append(value)

            if node.tag == "annotation":
                # Wrapping keeps "object" a list even for a single object:
                # the one-element outer list is unwrapped again just below.
                grouped["object"] = [grouped["object"]]

            collapsed = {}
            for tag, values in grouped.items():
                collapsed[tag] = values[0] if len(values) == 1 else values
            parsed = {node.tag: collapsed}

        if node.text and not child_nodes:
            parsed[node.tag] = node.text.strip()

        return parsed
...@@ -13,6 +13,54 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,6 +13,54 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from .transforms import Normalize, Compose, ToTensor, Resize from .transforms import (
Compose,
ToTensor,
PILToTensor,
ConvertImageDtype,
ToPILImage,
Normalize,
Resize,
Scale,
CenterCrop,
Pad,
Lambda,
RandomTransforms,
RandomApply,
RandomOrder,
RandomChoice,
RandomCrop,
RandomHorizontalFlip,
RandomVerticalFlip,
RandomResizedCrop,
RandomSizedCrop,
FiveCrop,
TenCrop,
InterpolationMode,
)
__all__ = ["Normalize", "Compose", "ToTensor", "Resize"] __all__ = [
"Compose",
"ToTensor",
"PILToTensor",
"ConvertImageDtype",
"ToPILImage",
"Normalize",
"Resize",
"Scale",
"CenterCrop",
"Pad",
"Lambda",
"RandomTransforms",
"RandomApply",
"RandomOrder",
"RandomChoice",
"RandomCrop",
"RandomHorizontalFlip",
"RandomVerticalFlip",
"RandomResizedCrop",
"RandomSizedCrop",
"FiveCrop",
"TenCrop",
"InterpolationMode",
]
...@@ -13,9 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,9 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import numbers
from typing import Any, List, Sequence from typing import Any, List, Sequence
import numpy as np
from PIL import Image, ImageOps
from PIL import Image import oneflow as flow
try: try:
import accimage import accimage
...@@ -30,6 +33,126 @@ def _is_pil_image(img: Any) -> bool: ...@@ -30,6 +33,126 @@ def _is_pil_image(img: Any) -> bool:
return isinstance(img, Image.Image) return isinstance(img, Image.Image)
def _get_image_size(img: Any) -> List[int]:
    """Return ``img.size`` for a PIL image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if not _is_pil_image(img):
        raise TypeError("Unexpected type {}".format(type(img)))
    return img.size
def _get_image_num_channels(img: Any) -> int:
    """Return the channel count of a PIL image.

    Mode ``"L"`` is reported as 1 channel; every other mode is reported as 3.

    Raises:
        TypeError: If ``img`` is not a PIL image. Without this the function
            would fall through and return ``None``, contradicting the ``int``
            return annotation and the behaviour of ``_get_image_size``.
    """
    if _is_pil_image(img):
        return 1 if img.mode == "L" else 3
    raise TypeError("Unexpected type {}".format(type(img)))
def hflip(img):
    """Return ``img`` transposed with ``Image.FLIP_LEFT_RIGHT``.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError("img should be PIL Image. Got {}".format(type(img)))
def vflip(img):
    """Return ``img`` transposed with ``Image.FLIP_TOP_BOTTOM``.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    raise TypeError("img should be PIL Image. Got {}".format(type(img)))
def pad(img, padding, fill=0, padding_mode="constant"):
    """Pad a PIL image on its borders.

    Args:
        img: PIL image to pad.
        padding: A number, or a tuple/list of 1, 2 or 4 elements. Two elements
            are read as (left/right, top/bottom); four as
            (left, top, right, bottom).
        fill: Fill value, used only when ``padding_mode`` is ``"constant"``.
        padding_mode: One of ``"constant"``, ``"edge"``, ``"reflect"`` or
            ``"symmetric"``.

    Returns:
        The padded PIL image. In the non-constant modes, negative padding
        values crop the corresponding side instead.

    Raises:
        TypeError: If an argument has an unsupported type.
        ValueError: If ``padding`` has an unsupported length or
            ``padding_mode`` is unknown.
    """
    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))
    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)
    if isinstance(padding, tuple):
        if len(padding) not in [1, 2, 4]:
            raise ValueError(
                "Padding must be an int or a 1, 2, or 4 element tuple, not a "
                + "{} element tuple".format(len(padding))
            )
        if len(padding) == 1:
            # Compatibility with `functional_tensor.pad`
            padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError(
            "Padding mode should be either constant, edge, reflect or symmetric"
        )

    if padding_mode == "constant":
        opts = _parse_fill(fill, img, name="fill")
        if img.mode != "P":
            return ImageOps.expand(img, border=padding, **opts)
        # Palette images: carry the palette over to the expanded image.
        palette = img.getpalette()
        expanded = ImageOps.expand(img, border=padding, **opts)
        expanded.putpalette(palette)
        return expanded

    # Non-constant modes are implemented with numpy.
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif isinstance(padding, tuple) and len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    elif isinstance(padding, tuple) and len(padding) == 4:
        pad_left, pad_top, pad_right, pad_bottom = padding

    ltrb = [pad_left, pad_top, pad_right, pad_bottom]

    # Negative entries mean "crop that many pixels from this side".
    cropping = -np.minimum(ltrb, 0)
    if cropping.any():
        crop_left, crop_top, crop_right, crop_bottom = cropping
        img = img.crop(
            (crop_left, crop_top, img.width - crop_right, img.height - crop_bottom)
        )

    pad_left, pad_top, pad_right, pad_bottom = np.maximum(ltrb, 0)
    spatial_pad = ((pad_top, pad_bottom), (pad_left, pad_right))

    if img.mode == "P":
        palette = img.getpalette()
        padded = Image.fromarray(np.pad(np.asarray(img), spatial_pad, padding_mode))
        padded.putpalette(palette)
        return padded

    pixels = np.asarray(img)
    if pixels.ndim == 3:
        # RGB image: leave the channel axis unpadded.
        pixels = np.pad(pixels, spatial_pad + ((0, 0),), padding_mode)
    elif pixels.ndim == 2:
        # Grayscale image
        pixels = np.pad(pixels, spatial_pad, padding_mode)
    return Image.fromarray(pixels)
def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image:
    """Crop a PIL image to the box starting at (``left``, ``top``).

    Args:
        img: PIL image to crop.
        top: Upper edge of the box.
        left: Left edge of the box.
        height: Box height.
        width: Box width.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

    # PIL expects the box as (left, upper, right, lower).
    box = (left, top, left + width, top + height)
    return img.crop(box)
def resize(img, size, interpolation=Image.BILINEAR): def resize(img, size, interpolation=Image.BILINEAR):
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img))) raise TypeError("img should be PIL Image. Got {}".format(type(img)))
...@@ -54,3 +177,31 @@ def resize(img, size, interpolation=Image.BILINEAR): ...@@ -54,3 +177,31 @@ def resize(img, size, interpolation=Image.BILINEAR):
return img.resize((ow, oh), interpolation) return img.resize((ow, oh), interpolation)
else: else:
return img.resize(size[::-1], interpolation) return img.resize(size[::-1], interpolation)
def _parse_fill(fill, img, name="fillcolor"):
    """Normalise a fill colour into a single-entry keyword-argument dict.

    Args:
        fill: ``None`` (treated as 0), a number, or a list/tuple with one
            entry per image band.
        img: Object exposing ``getbands()``; its band count decides how
            ``fill`` is expanded and validated.
        name: Key under which the fill value is returned.

    Returns:
        ``{name: fill}``. A scalar ``fill`` is repeated once per band when
        the image has more than one band; a list/tuple is returned as a tuple.

    Raises:
        ValueError: If a list/tuple ``fill`` does not have one entry per band.
    """
    band_count = len(img.getbands())

    if fill is None:
        fill = 0

    # Multi-band images need one fill component per band.
    if band_count > 1 and isinstance(fill, (int, float)):
        fill = (fill,) * band_count

    if isinstance(fill, (list, tuple)):
        if len(fill) != band_count:
            raise ValueError(
                "The number of elements in 'fill' does not match the number of "
                "bands of the image ({} != {})".format(len(fill), band_count)
            )
        fill = tuple(fill)

    return {name: fill}
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
    """Rotate a PIL image via ``img.rotate``.

    Args:
        img: PIL image to rotate.
        angle: Rotation angle, forwarded to ``img.rotate``.
        interpolation: Resampling filter, forwarded to ``img.rotate``.
        expand: Forwarded to ``img.rotate``.
        center: Forwarded to ``img.rotate``.
        fill: Fill colour, normalised by ``_parse_fill`` and passed as the
            ``fillcolor`` keyword.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        fill_kwargs = _parse_fill(fill, img)
        return img.rotate(angle, interpolation, expand, center, **fill_kwargs)
    raise TypeError("img should be PIL Image. Got {}".format(type(img)))
...@@ -13,7 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,7 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from typing import Tuple, List import warnings
from typing import Optional, Tuple, List
from oneflow.framework.tensor import Tensor from oneflow.framework.tensor import Tensor
import oneflow as flow import oneflow as flow
...@@ -43,6 +44,24 @@ def _get_image_num_channels(img: Tensor) -> int: ...@@ -43,6 +44,24 @@ def _get_image_num_channels(img: Tensor) -> int:
raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim)) raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))
def _max_value(dtype: flow.dtype) -> float:
    """Probe the largest ``2 ** k - 1`` value representable in ``dtype``.

    The bit width is doubled until ``2 ** (bits - signed) - 1`` stops
    growing; the last growing value is returned as a Python scalar.
    """
    two = flow.tensor(2, dtype=dtype)
    # TODO:Tensor.is_signed()
    # signed = 1 if flow.tensor(0, dtype=dtype).is_signed() else 0
    # Until then every dtype is treated as signed.
    # NOTE(review): for unsigned dtypes this presumably under-reports the
    # maximum — confirm.
    signed = 1
    bits = 1
    best = flow.tensor(-signed, dtype=flow.long)
    while True:
        candidate = two.pow(bits - signed).sub(1)
        if not (candidate > best):
            break
        best = candidate
        bits *= 2
    return best.item()
def _cast_squeeze_in( def _cast_squeeze_in(
img: Tensor, req_dtypes: List[flow.dtype] img: Tensor, req_dtypes: List[flow.dtype]
) -> Tuple[Tensor, bool, bool, flow.dtype]: ) -> Tuple[Tensor, bool, bool, flow.dtype]:
...@@ -76,6 +95,191 @@ def _cast_squeeze_out( ...@@ -76,6 +95,191 @@ def _cast_squeeze_out(
return img return img
def convert_image_dtype(
    image: flow.Tensor, dtype: flow.dtype = flow.float
) -> flow.Tensor:
    """Convert ``image`` to ``dtype``, rescaling values between integer and float ranges.

    Args:
        image: Tensor to convert.
        dtype: Target dtype.

    Returns:
        ``image`` itself when it already has ``dtype``; otherwise a converted
        tensor. Float-to-float is a plain cast, float-to-int multiplies by
        roughly the target maximum, int-to-float divides by the source
        maximum, and int-to-int scales by the ratio of the two ranges.

    Raises:
        RuntimeError: For float32 -> int32/int64 and float64 -> int64, which
            cannot be performed safely.
    """
    if image.dtype == dtype:
        return image

    if image.is_floating_point():
        # TODO:Tensor.is_floating_point()
        if flow.tensor(0, dtype=dtype).is_floating_point():
            return image.to(dtype)

        # float to int
        unsafe_from_f32 = image.dtype == flow.float32 and dtype in (
            flow.int32,
            flow.int64,
        )
        unsafe_from_f64 = image.dtype == flow.float64 and dtype == flow.int64
        if unsafe_from_f32 or unsafe_from_f64:
            msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
            raise RuntimeError(msg)

        # https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # For data in the range 0-1, (float * 255).to(uint) is only 255
        # when float is exactly 1.0.
        # `max + 1 - epsilon` provides more evenly distributed mapping of
        # ranges of floats to ints.
        eps = 1e-3
        scale = _max_value(dtype) + 1.0 - eps
        return image.mul(scale).to(dtype)

    input_max = _max_value(image.dtype)

    # int to float
    if flow.tensor(0, dtype=dtype).is_floating_point():
        return image.to(dtype) / input_max

    # int to int
    output_max = _max_value(dtype)
    if input_max > output_max:
        # Narrowing: divide in the source dtype first, then cast.
        factor = int((input_max + 1) // (output_max + 1))
        narrowed = flow.div(image, factor, rounding_mode="floor")
        return narrowed.to(dtype)

    # Widening: cast first, then multiply.
    factor = int((output_max + 1) // (input_max + 1))
    return image.to(dtype) * factor
def vflip(img: Tensor) -> Tensor:
    """Return ``img`` reversed along its second-to-last dimension."""
    _assert_image_tensor(img)
    flipped = img.flip(-2)
    return flipped
def hflip(img: Tensor) -> Tensor:
    """Return ``img`` reversed along its last dimension."""
    _assert_image_tensor(img)
    flipped = img.flip(-1)
    return flipped
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    """Crop the last two dimensions of ``img`` to a ``height`` x ``width`` window.

    Args:
        img: Image tensor; the last two dimensions are indexed as rows, columns.
        top: First row of the window.
        left: First column of the window.
        height: Window height.
        width: Window width.

    Returns:
        The cropped tensor. If the window extends past the image, the part
        inside the image is taken and zero-padded via ``pad``.
    """
    _assert_image_tensor(img)

    w, h = _get_image_size(img)
    right, bottom = left + width, top + height

    fits_inside = left >= 0 and top >= 0 and right <= w and bottom <= h
    if fits_inside:
        return img[..., top:bottom, left:right]

    # Amount by which the window overshoots each side, in left/top/right/bottom order.
    padding_ltrb = [
        max(-left, 0),
        max(-top, 0),
        max(right - w, 0),
        max(bottom - h, 0),
    ]
    visible = img[..., max(top, 0) : bottom, max(left, 0) : right]
    return pad(visible, padding_ltrb, fill=0)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    """Pad the last two dimensions of ``img`` by mirroring its border pixels.

    Args:
        img: 3-D or 4-D tensor.
        padding: Amounts in (left, right, top, bottom) order. Negative values
            crop that side instead of padding it.

    Raises:
        RuntimeError: If ``img`` is neither 3-D nor 4-D.
    """
    # padding is left, right, top, bottom
    # crop if needed
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        crop_left, crop_right, crop_top, crop_bottom = [-min(x, 0) for x in padding]
        img = img[
            ...,
            crop_top : img.shape[-2] - crop_bottom,
            crop_left : img.shape[-1] - crop_right,
        ]
        padding = [max(x, 0) for x in padding]

    def _mirrored(length, before, after):
        # e.g. before=4 -> [3, 2, 1, 0]; after=3 -> [-1, -2, -3]
        leading = list(range(before - 1, -1, -1))
        trailing = [-(i + 1) for i in range(after)]
        return leading + list(range(length)) + trailing

    in_sizes = img.size()
    x_indices = flow.tensor(
        _mirrored(in_sizes[-1], padding[0], padding[1]), device=img.device
    )
    y_indices = flow.tensor(
        _mirrored(in_sizes[-2], padding[2], padding[3]), device=img.device
    )

    ndim = img.ndim
    if ndim not in (3, 4):
        raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")

    rows = y_indices[:, None]
    cols = x_indices[None, :]
    if ndim == 3:
        return img[:, rows, cols]
    return img[:, :, rows, cols]
def pad(
    img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
) -> Tensor:
    """Pad the last two dimensions of an image tensor.

    Args:
        img: Image tensor.
        padding: An int, or a tuple/list of 1, 2 or 4 ints. Two elements are
            read as (left/right, top/bottom); four as
            (left, top, right, bottom).
        fill: Value passed to ``flow.F.pad`` as ``value``.
        padding_mode: One of ``"constant"``, ``"edge"``, ``"reflect"`` or
            ``"symmetric"``.

    Returns:
        The padded tensor, with the same dtype as the input.

    Raises:
        TypeError: If an argument has an unsupported type.
        ValueError: If ``padding`` has an unsupported length or
            ``padding_mode`` is unknown.
    """
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)
    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(
            "Padding must be an int or a 1, 2, or 4 element tuple, not a "
            + "{} element tuple".format(len(padding))
        )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError(
            "Padding mode should be either constant, edge, reflect or symmetric"
        )

    # A bare int behaves like a one-element list.
    if isinstance(padding, int):
        padding = [padding]
    if len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left, pad_top, pad_right, pad_bottom = (
            padding[0],
            padding[1],
            padding[2],
            padding[3],
        )

    # Reordered to (left, right, top, bottom) for the calls below.
    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)
    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"

    need_squeeze = img.ndim < 4
    if need_squeeze:
        img = img.unsqueeze(dim=0)

    out_dtype = img.dtype
    # Here we temporary cast input tensor to float
    # until pytorch issue is resolved :
    # https://github.com/pytorch/pytorch/issues/40763
    need_cast = padding_mode != "constant" and img.dtype not in (
        flow.float32,
        flow.float64,
    )
    if need_cast:
        img = img.to(flow.float32)

    img = flow.F.pad(img, pad=p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)
    if need_cast:
        img = img.to(out_dtype)
    return img
def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor: def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor:
_assert_image_tensor(img) _assert_image_tensor(img)
...@@ -121,7 +325,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten ...@@ -121,7 +325,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
# Define align_corners to avoid warnings # Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = flow.F.interpolate( img = flow.nn.functional.interpolate(
img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners
) )
...@@ -133,3 +337,48 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten ...@@ -133,3 +337,48 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
) )
return img return img
def _assert_grid_transform_inputs(
    img: Tensor,
    matrix: Optional[List[float]],
    interpolation: str,
    fill: Optional[List[float]],
    supported_interpolation_modes: List[str],
    coeffs: Optional[List[float]] = None,
):
    """Validate the arguments shared by grid-based tensor transforms.

    Args:
        img: Image tensor.
        matrix: ``None`` or a list of 6 values.
        interpolation: Mode name; must be in ``supported_interpolation_modes``.
        fill: ``None``, a number, or a tuple/list with either one entry or
            one entry per image channel.
        supported_interpolation_modes: Allowed values for ``interpolation``.
        coeffs: ``None`` or a sequence of 8 values.

    Raises:
        TypeError: If ``img`` is not a tensor or ``matrix`` is not a list.
        ValueError: If ``matrix``, ``coeffs``, ``fill`` or ``interpolation``
            fails its check. A ``fill`` of an unexpected type only triggers
            a warning.
    """
    if not isinstance(img, flow.Tensor):
        raise TypeError("Input img should be Tensor")
    _assert_image_tensor(img)

    if matrix is not None:
        if not isinstance(matrix, list):
            raise TypeError("Argument matrix should be a list")
        if len(matrix) != 6:
            raise ValueError("Argument matrix should have 6 float values")

    if coeffs is not None and len(coeffs) != 8:
        raise ValueError("Argument coeffs should have 8 float values")

    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
        warnings.warn("Argument fill should be either int, float, tuple or list")

    # Check fill
    num_channels = _get_image_num_channels(img)
    if isinstance(fill, (tuple, list)):
        fill_len = len(fill)
        if fill_len > 1 and fill_len != num_channels:
            msg = (
                "The number of elements in 'fill' cannot broadcast to match the number of "
                "channels of the image ({} != {})"
            )
            raise ValueError(msg.format(fill_len, num_channels))

    if interpolation not in supported_interpolation_modes:
        raise ValueError(
            "Interpolation mode '{}' is unsupported with Tensor input".format(
                interpolation
            )
        )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册