未验证 提交 c7ba8c44 编写于 作者: S shiyutang 提交者: GitHub

[Feature] Add module test process of mobilenetv3 (#5442)

* add_readme

* update

* update_dir

* add_falsely_delete_README

* update
上级 fbccf996
import torch
import paddle
import numpy as np
from reprod_log import ReprodDiffHelper
from reprod_log import ReprodLogger
from mobilenetv3_paddle.paddlevision.models import mobilenet_v3_small as mv3_small_paddle
from mobilenetv3_ref.torchvision.models import mobilenet_v3_small as mv3_small_torch
def test_forward():
# load paddle model
paddle_model = mv3_small_paddle()
paddle_model.eval()
paddle_state_dict = paddle.load("./data/mv3_small_paddle.pdparams")
paddle_model.set_dict(paddle_state_dict)
# load torch model
torch_model = mv3_small_torch()
torch_model.eval()
torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth")
torch_model.load_state_dict(torch_state_dict)
# load data
inputs = np.load("./data/fake_data.npy")
# save the paddle output
reprod_logger = ReprodLogger()
paddle_out = paddle_model(paddle.to_tensor(inputs, dtype="float32"))
reprod_logger.add("logits", paddle_out.cpu().detach().numpy())
reprod_logger.save("./result/forward_paddle.npy")
# save the torch output
torch_out = torch_model(torch.tensor(inputs, dtype=torch.float32))
reprod_logger.add("logits", torch_out.cpu().detach().numpy())
reprod_logger.save("./result/forward_ref.npy")
if __name__ == "__main__":
test_forward()
# load data
diff_helper = ReprodDiffHelper()
torch_info = diff_helper.load_info("./result/forward_ref.npy")
paddle_info = diff_helper.load_info("./result/forward_paddle.npy")
# compare result and produce log
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/forward_diff.log")
import os
import sys
import torch
import paddle
import numpy as np
from PIL import Image
from reprod_log import ReprodLogger, ReprodDiffHelper
import mobilenetv3_paddle.presets as presets_paddle
import mobilenetv3_paddle.paddlevision as paddlevision
import mobilenetv3_ref.presets as presets_torch
import mobilenetv3_ref.torchvision as torchvision
def build_paddle_data_pipeline():
# dataset & data_loader
dataset_test = paddlevision.datasets.ImageFolder(
"./lite_data/val/",
presets_paddle.ClassificationPresetEval(
crop_size=224, resize_size=256))
test_sampler = paddle.io.SequenceSampler(dataset_test)
test_batch_sampler = paddle.io.BatchSampler(
sampler=test_sampler, batch_size=4)
data_loader_test = paddle.io.DataLoader(
dataset_test, batch_sampler=test_batch_sampler, num_workers=0)
return dataset_test, data_loader_test
def build_torch_data_pipeline():
dataset_test = torchvision.datasets.ImageFolder(
"./lite_data/val/",
presets_torch.ClassificationPresetEval(
crop_size=224, resize_size=256),
is_valid_file=None)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=4,
sampler=test_sampler,
num_workers=0,
pin_memory=True)
return dataset_test, data_loader_test
def test_data_pipeline():
paddle_dataset, paddle_dataloader = build_paddle_data_pipeline()
torch_dataset, torch_dataloader = build_torch_data_pipeline()
logger_paddle_data = ReprodLogger()
logger_torch_data = ReprodLogger()
logger_paddle_data.add("length", np.array(len(paddle_dataset)))
logger_torch_data.add("length", np.array(len(torch_dataset)))
for idx, (paddle_batch, torch_batch
) in enumerate(zip(paddle_dataloader, torch_dataloader)):
if idx >= 5:
break
logger_paddle_data.add(f"dataloader_{idx}", paddle_batch[0].numpy())
logger_torch_data.add(f"dataloader_{idx}",
torch_batch[0].detach().cpu().numpy())
logger_paddle_data.save("./result/data_paddle.npy")
logger_torch_data.save("./result/data_ref.npy")
if __name__ == "__main__":
test_data_pipeline()
# load data
diff_helper = ReprodDiffHelper()
torch_info = diff_helper.load_info("./result/data_ref.npy")
paddle_info = diff_helper.load_info("./result/data_paddle.npy")
# compare result and produce log
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/data_diff.log")
# add test metric code paddle vs torch
import torch
import paddle
import numpy as np
from reprod_log import ReprodLogger
from reprod_log import ReprodDiffHelper
from mobilenetv3_paddle.paddlevision.models import mobilenet_v3_small as mv3_small_paddle
from mobilenetv3_ref.torchvision.models import mobilenet_v3_small as mv3_small_torch
from mobilenetv3_ref import accuracy_torch
from mobilenetv3_paddle import accuracy_paddle
def evaluate(image, labels, model, acc, tag, reprod_logger):
model.eval()
output = model(image)
accracy = acc(output, labels, topk=(1, 5))
reprod_logger.add("acc_top1", np.array(accracy[0]))
reprod_logger.add("acc_top5", np.array(accracy[1]))
reprod_logger.save("./result/metric_{}.npy".format(tag))
def test_forward():
# load paddle model
paddle_model = mv3_small_paddle()
paddle_model.eval()
paddle_state_dict = paddle.load("./data/mv3_small_paddle.pdparams")
paddle_model.set_dict(paddle_state_dict)
# load torch model
torch_model = mv3_small_torch()
torch_model.eval()
torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth")
torch_model.load_state_dict(torch_state_dict)
# prepare logger & load data
reprod_logger = ReprodLogger()
inputs = np.load("./data/fake_data.npy")
labels = np.load("./data/fake_label.npy")
image = paddle.to_tensor(inputs, dtype="float32")
target = paddle.to_tensor(labels, dtype="int64")
evaluate(
paddle.to_tensor(
inputs, dtype="float32"),
paddle.to_tensor(
labels, dtype="int64"),
paddle_model,
accuracy_paddle,
'paddle',
reprod_logger)
evaluate(
torch.tensor(
inputs, dtype=torch.float32),
torch.tensor(
labels, dtype=torch.int64),
torch_model,
accuracy_torch,
'ref',
reprod_logger)
if __name__ == "__main__":
test_forward()
# load data
diff_helper = ReprodDiffHelper()
torch_info = diff_helper.load_info("./result/metric_ref.npy")
paddle_info = diff_helper.load_info("./result/metric_paddle.npy")
# compare result and produce log
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/metric_diff.log")
# add loss comparing code
import torch
import paddle
import numpy as np
from reprod_log import ReprodLogger
from reprod_log import ReprodDiffHelper
from mobilenetv3_paddle.paddlevision.models import mobilenet_v3_small as mv3_small_paddle
from mobilenetv3_ref.torchvision.models import mobilenet_v3_small as mv3_small_torch
def test_forward():
# init loss
criterion_paddle = paddle.nn.CrossEntropyLoss()
criterion_torch = torch.nn.CrossEntropyLoss()
# load paddle model
paddle_model = mv3_small_paddle()
paddle_model.eval()
paddle_state_dict = paddle.load("./data/mv3_small_paddle.pdparams")
paddle_model.set_dict(paddle_state_dict)
# load torch model
torch_model = mv3_small_torch()
torch_model.eval()
torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth")
torch_model.load_state_dict(torch_state_dict)
# prepare logger & load data
reprod_logger = ReprodLogger()
inputs = np.load("./data/fake_data.npy")
labels = np.load("./data/fake_label.npy")
# save the paddle output
paddle_out = paddle_model(paddle.to_tensor(inputs, dtype="float32"))
loss_paddle = criterion_paddle(
paddle_out, paddle.to_tensor(
labels, dtype="int64"))
reprod_logger.add("loss", loss_paddle.cpu().detach().numpy())
reprod_logger.save("./result/loss_paddle.npy")
# save the torch output
torch_out = torch_model(torch.tensor(inputs, dtype=torch.float32))
loss_torch = criterion_torch(
torch_out, torch.tensor(
labels, dtype=torch.int64))
reprod_logger.add("loss", loss_torch.cpu().detach().numpy())
reprod_logger.save("./result/loss_ref.npy")
if __name__ == "__main__":
test_forward()
# load data
diff_helper = ReprodDiffHelper()
torch_info = diff_helper.load_info("./result/loss_ref.npy")
paddle_info = diff_helper.load_info("./result/loss_paddle.npy")
# compare result and produce log
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/loss_diff.log")
import paddle
import numpy as np
import torch
import torch.optim.lr_scheduler as lr_scheduler
from reprod_log import ReprodLogger
from reprod_log import ReprodDiffHelper
from mobilenetv3_paddle.paddlevision.models import mobilenet_v3_small as mv3_small_paddle
from mobilenetv3_ref.torchvision.models import mobilenet_v3_small as mv3_small_torch
def train_one_epoch_paddle(inputs, labels, model, criterion, optimizer,
lr_scheduler, max_iter, reprod_logger):
for idx in range(max_iter):
image = paddle.to_tensor(inputs, dtype="float32")
target = paddle.to_tensor(labels, dtype="int64")
# import pdb; pdb.set_trace()
output = model(image)
loss = criterion(output, target)
reprod_logger.add("loss_{}".format(idx), loss.cpu().detach().numpy())
reprod_logger.add("lr_{}".format(idx), np.array(lr_scheduler.get_lr()))
optimizer.clear_grad()
loss.backward()
optimizer.step()
# lr_scheduler.step()
reprod_logger.save("./result/losses_paddle.npy")
def train_one_epoch_torch(inputs, labels, model, criterion, optimizer,
lr_scheduler, max_iter, reprod_logger):
for idx in range(max_iter):
image = torch.tensor(inputs, dtype=torch.float32).cuda()
target = torch.tensor(labels, dtype=torch.int64).cuda()
model = model.cuda()
output = model(image)
loss = criterion(output, target)
reprod_logger.add("loss_{}".format(idx), loss.cpu().detach().numpy())
reprod_logger.add("lr_{}".format(idx),
np.array(lr_scheduler.get_last_lr()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# lr_scheduler.step()
reprod_logger.save("./result/losses_ref.npy")
def test_backward():
max_iter = 3
lr = 1e-3
momentum = 0.9
lr_gamma = 0.1
# set determinnistic flag
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
FLAGS_cudnn_deterministic = True
# load paddle model
paddle.set_device("gpu")
paddle_model = mv3_small_paddle(dropout=0.0)
paddle_model.eval()
paddle_state_dict = paddle.load("./data/mv3_small_paddle.pdparams")
paddle_model.set_dict(paddle_state_dict)
# load torch model
torch_model = mv3_small_torch(dropout=0.0)
torch_model.eval()
torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth")
torch_model.load_state_dict(torch_state_dict, strict=False)
# init loss
criterion_paddle = paddle.nn.CrossEntropyLoss()
criterion_torch = torch.nn.CrossEntropyLoss()
# init optimizer
lr_scheduler_paddle = paddle.optimizer.lr.StepDecay(
lr, step_size=max_iter // 3, gamma=lr_gamma)
opt_paddle = paddle.optimizer.Momentum(
learning_rate=lr,
momentum=momentum,
parameters=paddle_model.parameters())
opt_torch = torch.optim.SGD(torch_model.parameters(),
lr=lr,
momentum=momentum)
lr_scheduler_torch = lr_scheduler.StepLR(
opt_torch, step_size=max_iter // 3, gamma=lr_gamma)
# prepare logger & load data
reprod_logger = ReprodLogger()
inputs = np.load("./data/fake_data.npy")
labels = np.load("./data/fake_label.npy")
train_one_epoch_paddle(inputs, labels, paddle_model, criterion_paddle,
opt_paddle, lr_scheduler_paddle, max_iter,
reprod_logger)
train_one_epoch_torch(inputs, labels, torch_model, criterion_torch,
opt_torch, lr_scheduler_torch, max_iter,
reprod_logger)
if __name__ == "__main__":
test_backward()
# load data
diff_helper = ReprodDiffHelper()
torch_info = diff_helper.load_info("./result/losses_ref.npy")
paddle_info = diff_helper.load_info("./result/losses_paddle.npy")
# compare result and produce log
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/backward_diff.log")
# MobileNetV3
## 目录
- [1. 简介](#1)
- [2. 复现流程](#2)
- [2.1 reprod_log简介](#2.1)
- [3. 准备数据与环境](#3)
- [3.1 准备环境](#3.1)
- [3.2 生成伪数据](#3.2)
- [3.3 准备模型](#3.3)
- [4. 开始使用](#4)
- [4.1 模型前向对齐](#4.1)
- [4.2 数据加载对齐](#4.2)
- [4.3 评估指标对齐](#4.3)
- [4.4 损失对齐](#4.4)
- [4.5 反向梯度对齐](#4.5)
- [4.6 训练对齐](#4.6)
<a name="1"></a>
## 1. 简介
本部分内容包含基于 [MobileNetV3](https://arxiv.org/abs/1905.02244) 的复现对齐过程,可以结合[论文复现指南]()进行学习。
<a name="2"></a>
## 2. 复现流程
在论文复现中我们可以根据网络训练的流程,将对齐流程划分为数据加载对齐、模型前向对齐、评估指标对齐、反向梯度对齐和训练对齐。其中不同对齐部分我们会在下方详细介绍。
在对齐验证的流程中,我们依靠 reprod_log 日志工具查看 paddle 和官方同样输入下的输出是否相同,这样的查看方式具有标准统一,比较过程方便等优势。
<a name="2.1"></a>
### 2.1 reprod_log 简介
Reprod_log 是一个用于 numpy 数据记录和对比工具,通过传入需要对比的两个 numpy 数组就可以在指定的规则下得到数据之差是否满足期望的结论。其主要接口的说明可以看它的 [github 主页](https://github.com/WenmuZhou/reprod_log)
<a name="3"></a>
## 3. 准备数据和环境
在进行我们的对齐验证之前,我们需要准备运行环境、用于输入的伪数据、paddle 模型参数和官方模型权重参数。
<a name="3.1"></a>
### 3.1 准备环境
* 克隆本项目
```bash
git clone https://github.com/PaddlePaddle/models.git
cd model/tutorials/mobilenetv3_prod/
```
* 安装paddlepaddle
```bash
# 需要安装2.2及以上版本的Paddle,如果
# 安装GPU版本的Paddle
pip install paddlepaddle-gpu==2.2.0
# 安装CPU版本的Paddle
pip install paddlepaddle==2.2.0
```
更多版本或者环境下的安装可以参考:[Paddle安装指南](https://www.paddlepaddle.org.cn/)
* 安装requirements
```bash
pip install -r requirements.txt
```
<a name="3.2"></a>
### 3.2 生成伪数据
为了保证模型对齐不会受到输入数据的影响,我们生成一组数据作为两个模型的输入。
伪数据可以通过如下代码生成,我们在本地目录下也提供了好的伪数据(./data/fake_*.npy)。
```python
def gen_fake_data():
fake_data = np.random.rand(1, 3, 224, 224).astype(np.float32) - 0.5
fake_label = np.arange(1).astype(np.int64)
np.save("fake_data.npy", fake_data)
np.save("fake_label.npy", fake_label)
```
<a name="3.3"></a>
### 3.3 准备模型
为了保证模型前向对齐不受到模型参数不一致的影响,我们使用相同的权重参数对模型进行初始化。
生成相同权重参数分为以下 2 步:
1. 随机初始化官方模型参数并保存成 mobilenet_v3_small-047dcff4.pth
2. 将 model.pth 通过 ./torch2paddle.py 生成mv3_small_paddle.pdparams
转换模型时,torch 和 paddle 存在参数需要转换的部分,主要是bn层、全连接层、num_batches_tracked等,可以参见转换脚本(./torch2paddle.py)。
<a name="4"></a>
## 4. 开始使用
准备好数据之后,我们通过下面对应训练流程的拆解步骤进行复现对齐。
<a name="4.1"></a>
### 4.1 模型前向对齐
论文复现中,最重要的来到前向对齐的验证,验证流程如下图所示:
<div align="center">
<img src="./images/forward.png" width=500">
</div>
这里,为了判断判断模型组网部分能获得和原论文同样的输出,我们将两个模型参数固定,并输入相同伪数据,观察 paddle 模型产出的 logit 是否和官方模型一致。
我们的示例代码如下所示:
```python
def test_forward():
# load paddle model
paddle_model = mv3_small_paddle()
paddle_model.eval()
paddle_state_dict = paddle.load("./data/mv3_small_paddle.pdparams")
paddle_model.set_dict(paddle_state_dict)
# load torch model
torch_model = mv3_small_torch()
torch_model.eval()
torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth")
torch_model.load_state_dict(torch_state_dict)
# load data
inputs = np.load("./data/fake_data.npy")
# save the paddle output
reprod_logger = ReprodLogger()
paddle_out = paddle_model(paddle.to_tensor(inputs, dtype="float32"))
reprod_logger.add("logits", paddle_out.cpu().detach().numpy())
reprod_logger.save("./result/forward_paddle.npy")
# save the torch output
torch_out = torch_model(torch.tensor(inputs, dtype=torch.float32))
reprod_logger.add("logits", torch_out.cpu().detach().numpy())
reprod_logger.save("./result/forward_torch.npy")
```
可以看到,我们在代码中加载准备的相同的模型参数、并固定输入,从而获得两个模型的输出。输出结果使用相同的 key 值存到 numpy 文件中,随后使用下列代码加载并比较:
```python
# load data
diff_helper = ReprodDiffHelper()
torch_info = diff_helper.load_info("./result/forward_torch.npy")
paddle_info = diff_helper.load_info("./result/forward_paddle.npy")
# compare result and produce log
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/forward_diff.log")
```
在代码示例中也可以学习到 reprod_log的主要接口,包含add、save、load_infor、compare_infor、report的用法。
**运行文件**
通过运行以下代码,我们验证前向对齐效果。
```bash
cd models/tutorials/mobilenetv3_prod/
python 01_test_forward.py
```
**获得结果**
根据示例代码可以看到,我们将结果保存在`result/log/forward_diff.log`中,打开对应文件或者直接观察命令行输出,就会有下列结果:
```bash
[2021/12/21 15:00:38] root INFO: logits:
[2021/12/21 15:00:38] root INFO: mean diff: check passed: False, value: 2.308018565599923e-06
[2021/12/21 15:00:38] root INFO: diff check failed
```
这里我们发现在`reprod_log`默认的平均差异小于1e-6的标准下,当前前向对齐是不符合条件的,但是这是由于前向 op 计算导致的微小的差异。
一般说来前向误差在 1e-5 左右都是可以接受的,到这里我们就验证了网络的前向是对齐的,完成了第一个打卡点。
<a name="4.2"></a>
### 4.2 数据加载对齐
在验证了模型的前向对齐之后,我们验证数据读取部分,这一部分,我们比较从数据读取到模型传入之间我们进行的操作是否和参考操作一致。
主要代码如下所示,我们读取相同的输入,比较数据增强后输出之间的差异,即可知道我们的数据增强是否和参考实现保持一致:
```python
def build_torch_data_pipeline():
dataset_test = torchvision.datasets.ImageFolder(
"./lite_data/val/",
presets_torch.ClassificationPresetEval(
crop_size=224, resize_size=256), is_valid_file=None)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=4,
sampler=test_sampler,
num_workers=0,
pin_memory=True)
return dataset_test, data_loader_test
def test_data_pipeline():
paddle_dataset, paddle_dataloader = build_paddle_data_pipeline()
torch_dataset, torch_dataloader = build_torch_data_pipeline()
logger_paddle_data = ReprodLogger()
logger_torch_data = ReprodLogger()
logger_paddle_data.add("length", np.array(len(paddle_dataset)))
logger_torch_data.add("length", np.array(len(torch_dataset)))
for idx, (paddle_batch, torch_batch
) in enumerate(zip(paddle_dataloader, torch_dataloader)):
if idx >= 5:
break
logger_paddle_data.add(f"dataloader_{idx}", paddle_batch[0].numpy())
logger_torch_data.add(f"dataloader_{idx}",
torch_batch[0].detach().cpu().numpy())
logger_paddle_data.save("./result/data_paddle.npy")
logger_torch_data.save("./result/data_ref.npy")
```
**运行文件**
通过运行以下指令,我们进行测试,测试数据可以解压我们准备的 [lite_data.tar](https://github.com/PaddlePaddle/models/blob/release%2F2.2/tutorials/mobilenetv3_prod/Step6/test_images/lite_data.tar) 获得,对于自身的数据,也可以抽取几张 validationset 的图片用作验证。
```python
cd models/tutorials/mobilenetv3_prod/
tar -xvf lite_data.rar
python 02_test_data.py
```
**获得结果**
运行文件之后,我们获得以下命令行输出,可以发现我们的验证结果满足预期,数据加载部分验证通过:
```bash
[2021/12/23 17:21:22] root INFO: length:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_0:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_1:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_2:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_3:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: diff check passed
```
<a name="4.3"></a>
### 4.3 评估指标对齐
随后我们来到评估指标对齐,对齐流程如图所示:
<div align="center">
<img src="./images/metric.png" width=500">
</div>
这部分的对齐流程主要差异在于我们在模型基础上添加了对应参考代码实现 metric,并导入到测试文件中。在论文复现中,我们尽量将模型的不同部分封装起来,之后就可以通过我们这样导入的方式进行验证。
这部分的参考代码如下:
```python
def evaluate(image, labels, model, acc, tag, reprod_logger):
model.eval()
output = model(image)
accracy = acc(output, labels, topk=(1, 5))
reprod_logger.add("acc_top1", np.array(accracy[0]))
reprod_logger.add("acc_top5", np.array(accracy[1]))
reprod_logger.save("./result/metric_{}.npy".format(tag))
def test_forward():
# load model & data
evaluate(
paddle.to_tensor(
inputs, dtype="float32"),
paddle.to_tensor(
labels, dtype="int64"),
paddle_model,
accuracy_paddle,
'paddle', reprod_logger)
evaluate(
torch.tensor(
inputs, dtype=torch.float32),
torch.tensor(
labels, dtype=torch.int64),
torch_model,
accuracy_torch,
'ref', reprod_logger)
```
这部分模型和输入的导入均和之前一致,只是在之前的基础上增加了模型计算评估指标的部分。
由于我们之前验证了模型的输出一致。那么也就是评估指标的输入相同,我们只需要对比输出是否一致,即可确定评估指标的实现是否正确。
**运行文件**
通过运行以下代码,我们验证评估指标对齐效果。
```bash
cd models/tutorials/mobilenetv3_prod/
python 03_test_metric.py
```
**获得结果**
进入`result/log/metric_diff.log`中,就会有下列结果,而结果说明我们评估指标的实现正确, 从而完成第二个打卡点:
```bash
[2021/12/21 19:28:49] root INFO: acc_top1:
[2021/12/21 19:28:49] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/21 19:28:49] root INFO: acc_top5:
[2021/12/21 19:28:49] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/21 19:28:49] root INFO: diff check passed
```
<a name="4.4"></a>
### 4.4 损失对齐
进一步,我们验证损失实现的正确性,验证流程如下:
<div align="center">
<img src="./images/losses.png" width=500">
</div>
这部分的对齐流程主要差异在于我们在模型基础上添加了对应参考代码实现的 loss。这部分的参考代码如下:
```python
def test_forward():
# init loss
criterion_paddle = paddle.nn.CrossEntropyLoss()
criterion_torch = torch.nn.CrossEntropyLoss()
# load model & data
# save the paddle output
paddle_out = paddle_model(paddle.to_tensor(inputs, dtype="float32"))
loss_paddle = criterion_paddle(
paddle_out, paddle.to_tensor(
labels, dtype="int64"))
reprod_logger.add("loss", loss_paddle.cpu().detach().numpy())
reprod_logger.save("./result/loss_paddle.npy")
# save the torch output
torch_out = torch_model(torch.tensor(inputs, dtype=torch.float32))
loss_torch = criterion_torch(
torch_out, torch.tensor(
labels, dtype=torch.int64))
reprod_logger.add("loss", loss_torch.cpu().detach().numpy())
reprod_logger.save("./result/loss_ref.npy")
```
这部分代码进一步增加损失导入的部分,由于我们之前验证了模型的输出一致,也就是损失的输入相同,我们只需要对比输出是否一致,即可确定损失的实现是否正确。
**运行文件**
通过运行以下代码,我们验证评估指标对齐效果。
```bash
cd models/tutorials/mobilenetv3_prod/
python 04_test_loss.py
```
**获得结果**
进入`result/log/loss_diff.log`中,就会有下列结果,而结果说明我们评估指标的实现正确,完成第三个打卡点:
```bash
[2021/12/22 20:13:41] root INFO: loss:
[2021/12/22 20:13:41] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/22 20:13:41] root INFO: diff check passed
```
<a name="4.5"></a>
### 4.5 反向梯度对齐
结合模型和损失,我们就可以验证反向过程,反向梯度传导的是否正确包含了优化器,学习率以及梯度的计算,而验证过程只需要多观察几轮损失即可明确反向是否正确传导,主要验证流程如下所示:
<div align="center">
<img src="./images/backward.png" width=500">
</div>
以上参考流程可以使用以下代码实现:
```python
def train_one_epoch_torch(inputs, labels, model, criterion, optimizer,
max_iter, reprod_logger):
for idx in range(max_iter):
image = torch.tensor(inputs, dtype=torch.float32).cuda()
target = torch.tensor(labels, dtype=torch.int64).cuda()
model = model.cuda()
output = model(image)
loss = criterion(output, target)
reprod_logger.add("loss_{}".format(idx), loss.cpu().detach().numpy())
optimizer.zero_grad()
loss.backward()
optimizer.step()
reprod_logger.save("./result/losses_ref.npy")
def test_backward():
max_iter = 3
lr = 1e-3
momentum = 0.9
# load model, loss, data
# init optimizer
opt_paddle = paddle.optimizer.Momentum(
learning_rate=lr,
momentum=momentum,
parameters=paddle_model.parameters())
opt_torch = torch.optim.SGD(torch_model.parameters(), lr=lr, momentum=momentum)
train_one_epoch_paddle(inputs, labels, paddle_model, criterion_paddle,
opt_paddle, max_iter, reprod_logger)
train_one_epoch_torch(inputs, labels, torch_model, criterion_torch,
opt_torch, max_iter, reprod_logger)
```
代码中增加了optimizer用于迭代网络参数,其他则基本一致。
**运行文件**
通过运行以下代码,我们验证反向传播对齐效果。
```bash
cd models/tutorials/mobilenetv3_prod/
python 05_test_backward.py
```
**获得结果**
进入`result/log/loss_diff.log`中,就会有下列结果,结果表示三轮损失的差异在 1e-6 附近,说明我们反向传播的实现对齐, 完成第四个打卡点:
```bash
[2021/12/23 15:51:16] root INFO: loss_0:
[2021/12/23 15:51:16] root INFO: mean diff: check passed: False, value: 1.9073486328125e-06
[2021/12/23 15:51:16] root INFO: lr_0:
[2021/12/23 15:51:16] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 15:51:16] root INFO: loss_1:
[2021/12/23 15:51:16] root INFO: mean diff: check passed: False, value: 2.384185791015625e-06
[2021/12/23 15:51:16] root INFO: lr_1:
[2021/12/23 15:51:16] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 15:51:16] root INFO: loss_2:
[2021/12/23 15:51:16] root INFO: mean diff: check passed: False, value: 1.1920928955078125e-05
[2021/12/23 15:51:16] root INFO: lr_2:
[2021/12/23 15:51:16] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 15:51:16] root INFO: diff check failed
```
<a name="4.6"></a>
### 4.6 训练对齐
通过以上步骤,我们验证了模型、数据、评估指标、损失、反向传播的正确性,也就为我们的训练对齐打下了良好的基础。
接下来,我们按照以下流程验证训练对齐结果,即对网络进行训练,并在训练后验证精度是否达到指标:
<div align="center">
<img src="./images/train.png" width=500">
</div>
我们可以使用reprd logger对比精度,也可以直接肉眼观察结果对比:
```python
if paddle.distributed.get_rank() == 0:
reprod_logger = ReprodLogger()
reprod_logger.add("top1", np.array([top1]))
reprod_logger.save("train_align_paddle.npy")
```
**运行文件**
```bash
cd models/tutorials/mobilenetv3_prod/Checkpoint6
python train.py
```
**获得结果**
最终训练精度超过原模型精度,我们的复现到这里就圆满结束,如果还有任何问题,欢迎随时向我们[提问](https://github.com/PaddlePaddle/Paddle/issues)
from .metric import accuracy_paddle
from .presets import *
\ No newline at end of file
import paddle
def accuracy_paddle(output, target, topk=(1, )):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with paddle.no_grad():
maxk = max(topk)
batch_size = target.shape[0]
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.equal(target)
res = []
for k in topk:
correct_k = correct.astype(paddle.int32)[:k].flatten().sum(
dtype='float32')
res.append(correct_k / batch_size)
return res
from .datasets import *
from .models import *
from .transforms import *
\ No newline at end of file
from .folder import ImageFolder, DatasetFolder
from .vision import VisionDataset
__all__ = ('ImageFolder', 'DatasetFolder', 'VisionDataset')
\ No newline at end of file
from .vision import VisionDataset
from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
def has_file_allowed_extension(filename: str,
extensions: Tuple[str, ...]) -> bool:
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def is_image_file(filename: str) -> bool:
"""Checks if a file is an allowed image extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset.
See :class:`DatasetFolder` for details.
"""
classes = sorted(
entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(
f"Couldn't find any class folder in {directory}.")
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def make_dataset(
directory: str,
class_to_idx: Optional[Dict[str, int]]=None,
extensions: Optional[Tuple[str, ...]]=None,
is_valid_file: Optional[Callable[[str], bool]]=None, ) -> List[Tuple[
str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
See :class:`DatasetFolder` for details.
Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
by default.
"""
directory = os.path.expanduser(directory)
if class_to_idx is None:
_, class_to_idx = find_classes(directory)
elif not class_to_idx:
raise ValueError(
"'class_to_index' must have at least one entry to collect any samples."
)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(
x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
instances = []
available_classes = set()
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
if is_valid_file(fname):
path = os.path.join(root, fname)
item = path, class_index
instances.append(item)
if target_class not in available_classes:
available_classes.add(target_class)
# print(fname)
# exit()
# empty_classes = set(class_to_idx.keys()) - available_classes
# if empty_classes:
# msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
# if extensions is not None:
# msg += f"Supported extensions are: {', '.join(extensions)}"
# raise FileNotFoundError(msg)
return instances
class DatasetFolder(VisionDataset):
"""A generic data loader.
This default directory structure can be customized by overriding the
:meth:`find_classes` method.
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]]=None,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None,
is_valid_file: Optional[Callable[[str], bool]]=None, ) -> None:
super(DatasetFolder, self).__init__(
root, transform=transform, target_transform=target_transform)
classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions,
is_valid_file)
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
@staticmethod
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]]=None,
is_valid_file: Optional[Callable[[str], bool]]=None, ) -> List[
Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
Args:
directory (str): root dataset directory, corresponding to ``self.root``.
class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
(used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
FileNotFoundError: In case no valid file was found for any class.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
if class_to_idx is None:
# prevent potential bug since make_dataset() would use the class_to_idx logic of the
# find_classes() function, instead of using that of the find_classes() method, which
# is potentially overridden and thus could have a different logic.
raise ValueError("The class_to_idx parameter cannot be None.")
return make_dataset(
directory,
class_to_idx,
extensions=extensions,
is_valid_file=is_valid_file)
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Find the class folders in a dataset structured as follows::
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
This method can be overridden to only consider
a subset of classes, or to adapt to a different dataset directory structure.
Args:
directory(str): Root directory path, corresponding to ``self.root``
Raises:
FileNotFoundError: If ``dir`` has no class folders.
Returns:
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
"""
return find_classes(directory)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self) -> int:
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def default_loader(path: str) -> Any:
return pil_loader(path)
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way by default: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(
self,
root: str,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None,
loader: Callable[[str], Any]=default_loader,
is_valid_file: Optional[Callable[[str], bool]]=None, ):
super(ImageFolder, self).__init__(
root,
loader,
IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples
import os
import paddle
from typing import Any, Callable, List, Optional, Tuple
class VisionDataset(paddle.io.Dataset):
"""
Base Class For making datasets which are compatible with our model.
It is necessary to override the ``__getitem__`` and ``__len__`` method.
Args:
root (string): Root directory of dataset.
transforms (callable, optional): A function/transforms that takes in
an image and a label and returns the transformed versions of both.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
.. note::
:attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
"""
_repr_indent = 4
def __init__(
self,
root: str,
transforms: Optional[Callable]=None,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None, ) -> None:
if isinstance(root, (str, bytes())):
root = os.path.expanduser(root)
self.root = root
has_transforms = transforms is not None
has_separate_transform = transform is not None or target_transform is not None
if has_transforms and has_separate_transform:
raise ValueError(
"Only transforms or transform/target_transform can "
"be passed as argument")
# for backwards-compatibility
self.transform = transform
self.target_transform = target_transform
if has_separate_transform:
transforms = StandardTransform(transform, target_transform)
self.transforms = transforms
def __getitem__(self, index: int) -> Any:
"""
Args:
index (int): Index
Returns:
(Any): Sample and meta data, optionally transformed by the respective transforms.
"""
raise NotImplementedError
def __len__(self) -> int:
raise NotImplementedError
def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__
body = ["Number of datapoints: {}".format(self.__len__())]
if self.root is not None:
body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines()
if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines)
def _format_transform_repr(self, transform: Callable,
head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])
def extra_repr(self) -> str:
return ""
class StandardTransform(object):
def __init__(self,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None) -> None:
self.transform = transform
self.target_transform = target_transform
def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
if self.transform is not None:
input = self.transform(input)
if self.target_transform is not None:
target = self.target_transform(target)
return input, target
def _format_transform_repr(self, transform: Callable,
head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])
def __repr__(self) -> str:
body = [self.__class__.__name__]
if self.transform is not None:
body += self._format_transform_repr(self.transform, "Transform: ")
if self.target_transform is not None:
body += self._format_transform_repr(self.target_transform,
"Target transform: ")
return '\n'.join(body)
from .mobilenet_v3_paddle import mobilenet_v3_large, mobilenet_v3_small
from typing import Any, Callable, List, Optional, Sequence
import paddle
import paddle.nn as nn
class ConvNormActivation(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int=3,
stride: int=1,
padding: Optional[int]=None,
groups: int=1,
norm_layer: Optional[Callable[..., nn.Layer]]=nn.BatchNorm2D,
activation_layer: Optional[Callable[..., nn.Layer]]=nn.ReLU,
dilation: int=1,
bias: Optional[bool]=None, ) -> None:
if padding is None:
padding = (kernel_size - 1) // 2 * dilation
if bias is None:
bias = norm_layer is None
layers = [
nn.Conv2D(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=dilation,
groups=groups,
bias_attr=bias, )
]
if norm_layer is not None:
layers.append(norm_layer(out_channels))
if activation_layer is not None:
layers.append(activation_layer())
super().__init__(*layers)
self.out_channels = out_channels
class SqueezeExcitation(nn.Layer):
def __init__(
self,
input_channels: int,
squeeze_channels: int,
activation: Callable[..., nn.Layer]=nn.ReLU,
scale_activation: Callable[..., nn.Layer]=nn.Sigmoid, ) -> None:
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2D(1)
self.fc1 = nn.Conv2D(input_channels, squeeze_channels, 1)
self.fc2 = nn.Conv2D(squeeze_channels, input_channels, 1)
self.activation = activation()
self.scale_activation = scale_activation()
def _scale(self, input: paddle.Tensor) -> paddle.Tensor:
scale = self.avgpool(input)
scale = self.fc1(scale)
scale = self.activation(scale)
scale = self.fc2(scale)
return self.scale_activation(scale)
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
scale = self._scale(input)
return scale * input
import warnings
from functools import partial
from typing import Any, Callable, List, Optional, Sequence
import paddle
import paddle.nn as nn
from .misc_paddle import ConvNormActivation, SqueezeExcitation as SElayer
__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
def _make_divisible(v: float, divisor: int,
min_value: Optional[int]=None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class SqueezeExcitation(SElayer):
def __init__(self, input_channels: int, squeeze_factor: int=4):
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
super().__init__(
input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
self.relu = self.activation
delattr(self, "activation")
class InvertedResidualConfig:
# Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
def __init__(
self,
input_channels: int,
kernel: int,
expanded_channels: int,
out_channels: int,
use_se: bool,
activation: str,
stride: int,
dilation: int,
width_mult: float, ):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels,
width_mult)
self.out_channels = self.adjust_channels(out_channels, width_mult)
self.use_se = use_se
self.use_hs = activation == "HS"
self.stride = stride
self.dilation = dilation
@staticmethod
def adjust_channels(channels: int, width_mult: float):
return _make_divisible(channels * width_mult, 8)
class InvertedResidual(nn.Layer):
# Implemented as described at section 5 of MobileNetV3 paper
def __init__(
self,
cnf: InvertedResidualConfig,
norm_layer: Callable[..., nn.Layer],
se_layer: Callable[..., nn.Layer]=partial(
SElayer, scale_activation=nn.Hardsigmoid), ):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
layers: List[nn.Layer] = []
activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(
ConvNormActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer, ))
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(
ConvNormActivation(
cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
stride=stride,
dilation=cnf.dilation,
groups=cnf.expanded_channels,
norm_layer=norm_layer,
activation_layer=activation_layer, ))
if cnf.use_se:
squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
# project
layers.append(
ConvNormActivation(
cnf.expanded_channels,
cnf.out_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=None))
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self._is_cn = cnf.stride > 1
def forward(self, input: paddle.Tensor) -> paddle.Tensor:
result = self.block(input)
if self.use_res_connect:
result += input
return result
class MobileNetV3(nn.Layer):
def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
num_classes: int=1000,
block: Optional[Callable[..., nn.Layer]]=None,
norm_layer: Optional[Callable[..., nn.Layer]]=None,
dropout: float=0.2,
**kwargs: Any, ) -> None:
"""
MobileNet V3 main class
Args:
inverted_residual_setting (List[InvertedResidualConfig]): Network structure
last_channel (int): The number of channels on the penultimate layer
num_classes (int): Number of classes
block (Optional[Callable[..., nn.Layer]]): Module specifying inverted residual building block for mobilenet
norm_layer (Optional[Callable[..., nn.Layer]]): Module specifying the normalization layer to use
dropout (float): The droupout probability
"""
super().__init__()
if not inverted_residual_setting:
raise ValueError(
"The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and all([
isinstance(s, InvertedResidualConfig)
for s in inverted_residual_setting
])):
raise TypeError(
"The inverted_residual_setting should be List[InvertedResidualConfig]"
)
if block is None:
block = InvertedResidual
if norm_layer is None:
norm_layer = partial(nn.BatchNorm2D, epsilon=0.001, momentum=0.01)
layers: List[nn.Layer] = []
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(
ConvNormActivation(
3,
firstconv_output_channels,
kernel_size=3,
stride=2,
norm_layer=norm_layer,
activation_layer=nn.Hardswish, ))
# building inverted residual blocks
for cnf in inverted_residual_setting:
layers.append(block(cnf, norm_layer))
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(
ConvNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.Hardswish, ))
self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2D(1)
self.classifier = nn.Sequential(
nn.Linear(lastconv_output_channels, last_channel),
nn.Hardswish(),
nn.Dropout(p=dropout),
nn.Linear(last_channel, num_classes), )
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self.features(x)
x = self.avgpool(x)
x = paddle.flatten(x, 1)
x = self.classifier(x)
return x
def _mobilenet_v3_conf(arch: str,
width_mult: float=1.0,
reduced_tail: bool=False,
dilated: bool=False,
**kwargs: Any):
reduce_divider = 2 if reduced_tail else 1
dilation = 2 if dilated else 1
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(
InvertedResidualConfig.adjust_channels, width_mult=width_mult)
if arch == "mobilenet_v3_large":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2,
dilation), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider,
160 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider,
160 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1280 // reduce_divider) # C5
elif arch == "mobilenet_v3_small":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2,
dilation), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider,
96 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider,
96 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1024 // reduce_divider) # C5
else:
raise ValueError(f"Unsupported model type {arch}")
return inverted_residual_setting, last_channel
def _mobilenet_v3(
arch: str,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any, ):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
state_dict = paddle.load(pretrained)
model.set_dict(state_dict)
return model
def mobilenet_v3_large(pretrained: bool=False,
progress: bool=True,
**kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_large"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch,
**kwargs)
return _mobilenet_v3(arch, inverted_residual_setting, last_channel,
pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool=False,
progress: bool=True,
**kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_small"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch,
**kwargs)
return _mobilenet_v3(arch, inverted_residual_setting, last_channel,
pretrained, progress, **kwargs)
import math
import paddle
from enum import Enum
from paddle import Tensor
from typing import List, Tuple, Optional
from . import functional as f
from .functional import InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment"]
class AutoAugmentPolicy(Enum):
"""AutoAugment policies learned on different datasets.
Available policies are IMAGENET, CIFAR10 and SVHN.
"""
IMAGENET = "imagenet"
CIFAR10 = "cifar10"
SVHN = "svhn"
def _get_transforms(policy: AutoAugmentPolicy):
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
def _get_magnitudes():
_BINS = 10
return {
# name: (magnitudes, signed)
"ShearX": (paddle.linspace(0.0, 0.3, _BINS), True),
"ShearY": (paddle.linspace(0.0, 0.3, _BINS), True),
"TranslateX": (paddle.linspace(0.0, 150.0 / 331.0, _BINS), True),
"TranslateY": (paddle.linspace(0.0, 150.0 / 331.0, _BINS), True),
"Rotate": (paddle.linspace(0.0, 30.0, _BINS), True),
"Brightness": (paddle.linspace(0.0, 0.9, _BINS), True),
"Color": (paddle.linspace(0.0, 0.9, _BINS), True),
"Contrast": (paddle.linspace(0.0, 0.9, _BINS), True),
"Sharpness": (paddle.linspace(0.0, 0.9, _BINS), True),
"Posterize": (paddle.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False),
"Solarize": (paddle.linspace(256.0, 0.0, _BINS), False),
"AutoContrast": (None, None),
"Equalize": (None, None),
"Invert": (None, None),
}
class AutoAugment(paddle.nn.Layer):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is paddle Tensor, it should be of type paddle.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def __init__(self,
policy: AutoAugmentPolicy=AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode=InterpolationMode.NEAREST,
fill: Optional[List[float]]=None):
super().__init__()
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.transforms = _get_transforms(policy)
if self.transforms is None:
raise ValueError(
"The provided policy {} is not recognized.".format(policy))
self._op_meta = _get_magnitudes()
@staticmethod
def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
"""Get parameters for autoaugment transformation
Returns:
params required by the autoaugment transformation
"""
policy_id = int(paddle.randint(low=0, high=transform_num, shape=(1, )))
probs = paddle.rand((2, ))
signs = paddle.randint(low=0, high=2, shape=(2, ))
return policy_id, probs, signs
def _get_op_meta(self,
name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name]
def forward(self, img: Tensor):
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: AutoAugmented image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]
transform_id, probs, signs = self.get_params(len(self.transforms))
for i, (op_name, p,
magnitude_id) in enumerate(self.transforms[transform_id]):
if probs[i] <= p:
magnitudes, signed = self._get_op_meta(op_name)
magnitude = float(magnitudes[magnitude_id].item()) \
if magnitudes is not None and magnitude_id is not None else 0.0
if signed is not None and signed and signs[i] == 0:
magnitude *= -1.0
if op_name == "ShearX":
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
interpolation=self.interpolation,
fill=fill)
elif op_name == "ShearY":
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
interpolation=self.interpolation,
fill=fill)
elif op_name == "TranslateX":
img = F.affine(
img,
angle=0.0,
translate=[
int(F._get_image_size(img)[0] * magnitude), 0
],
scale=1.0,
interpolation=self.interpolation,
shear=[0.0, 0.0],
fill=fill)
elif op_name == "TranslateY":
img = F.affine(
img,
angle=0.0,
translate=[
0, int(F._get_image_size(img)[1] * magnitude)
],
scale=1.0,
interpolation=self.interpolation,
shear=[0.0, 0.0],
fill=fill)
elif op_name == "Rotate":
img = F.rotate(
img,
magnitude,
interpolation=self.interpolation,
fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError(
"The provided operator {} is not recognized.".format(
op_name))
return img
def __repr__(self):
return self.__class__.__name__ + '(policy={}, fill={})'.format(
self.policy, self.fill)
import numbers
import warnings
from enum import Enum
import numpy as np
import paddle
from paddle import Tensor
from typing import List, Tuple, Any, Optional
try:
import accimage
except ImportError:
accimage = None
from . import functional_pil as F_pil
from . import functional_tensor as F_t
class InterpolationMode(Enum):
"""Interpolation modes
Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
"""
NEAREST = "nearest"
BILINEAR = "bilinear"
BICUBIC = "bicubic"
# For PIL compatibility
BOX = "box"
HAMMING = "hamming"
LANCZOS = "lanczos"
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
inverse_modes_mapping = {
0: InterpolationMode.NEAREST,
2: InterpolationMode.BILINEAR,
3: InterpolationMode.BICUBIC,
4: InterpolationMode.BOX,
5: InterpolationMode.HAMMING,
1: InterpolationMode.LANCZOS,
}
return inverse_modes_mapping[i]
pil_modes_mapping = {
InterpolationMode.NEAREST: 0,
InterpolationMode.BILINEAR: 2,
InterpolationMode.BICUBIC: 3,
InterpolationMode.BOX: 4,
InterpolationMode.HAMMING: 5,
InterpolationMode.LANCZOS: 1,
}
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}
def to_tensor(pic):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
See :class:`~paddlevision.transforms.ToTensor` for more details.
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(
type(pic)))
if _is_numpy(pic) and not _is_numpy_image(pic):
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.
format(pic.ndim))
default_float_dtype = paddle.get_default_dtype()
if isinstance(pic, np.ndarray):
# handle numpy array
if pic.ndim == 2:
pic = pic[:, :, None]
img = paddle.to_tensor(pic.transpose((2, 0, 1)))
# backward compatibility
if not img.dtype == default_float_dtype:
img = img.astype(dtype=default_float_dtype)
return img.divide(paddle.full_like(img, 255))
else:
return img
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros(
[pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return paddle.to_tensor(nppic).astype(dtype=default_float_dtype)
# handle PIL Image
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = paddle.to_tensor(
np.array(
pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
if pic.mode == '1':
img = 255 * img
img = img.reshape([pic.size[1], pic.size[0], len(pic.getbands())])
if not img.dtype == default_float_dtype:
img = img.astype(dtype=default_float_dtype)
# put it from HWC to CHW format
img = img.transpose((2, 0, 1))
return img.divide(paddle.full_like(img, 255))
else:
# put it from HWC to CHW format
img = img.transpose((2, 0, 1))
return img
def normalize(tensor: Tensor,
mean: List[float],
std: List[float],
inplace: bool=False) -> Tensor:
"""Normalize a float tensor image with mean and standard deviation.
This transform does not support PIL Image.
.. note::
This transform acts out of place by default, i.e., it does not mutates the input tensor.
See :class:`~paddlevision.transforms.Normalize` for more details.
Args:
tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace.
Returns:
Tensor: Normalized Tensor image.
"""
if not isinstance(tensor, paddle.Tensor):
raise TypeError('Input tensor should be a paddle tensor. Got {}.'.
format(type(tensor)))
if not tensor.dtype in (paddle.float16, paddle.float32, paddle.float64):
raise TypeError('Input tensor should be a float tensor. Got {}.'.
format(tensor.dtype))
if tensor.ndim < 3:
raise ValueError(
'Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.shape() = '
'{}.'.format(tensor.shape))
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = paddle.to_tensor(mean, dtype=dtype, place=tensor.place)
std = paddle.to_tensor(std, dtype=dtype, place=tensor.place)
if (std == 0).any():
raise ValueError('std evaluated to zero, leading to division by zero.')
if mean.ndim == 1:
mean = mean.reshape((-1, 1, 1))
if std.ndim == 1:
std = std.reshape((-1, 1, 1))
tensor = tensor.subtract(mean).divide(std)
return tensor
def resize(img: Tensor,
size: List[int],
interpolation: InterpolationMode=InterpolationMode.BILINEAR,
max_size: Optional[int]=None,
antialias: Optional[bool]=None) -> Tensor:
r"""Resize the input image to the given size.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
.. warning::
The output image might be different depending on its type: when downsampling, the interpolation of PIL images
and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
closer.
Args:
img (PIL Image or Tensor): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaining
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`paddlevision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e the
smaller edge may be shorter than ``size``.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
``InterpolationMode.BILINEAR`` only mode. This can help making the output for PIL images and tensors
closer.
.. warning::
There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
Returns:
PIL Image or Tensor: Resized image.
"""
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if not isinstance(img, paddle.Tensor):
if antialias is not None and not antialias:
warnings.warn(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(
img, size=size, interpolation=pil_interpolation, max_size=max_size)
return F_t.resize(
img,
size=size,
interpolation=interpolation.value,
max_size=max_size,
antialias=antialias)
def _get_image_size(img: Tensor) -> List[int]:
"""Returns image size as [w, h]
"""
if isinstance(img, paddle.Tensor):
return F_t._get_image_size(img)
return F_pil._get_image_size(img)
def pad(img: Tensor,
padding: List[int],
fill: int=0,
padding_mode: str="constant") -> Tensor:
r"""Pad the given image on all sides with the given "pad" value.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
at most 3 leading dimensions for mode edge,
and an arbitrary number of leading dimensions for mode constant
Args:
img (PIL Image or Tensor): Image to be padded.
padding (int or sequence): Padding on each border. If a single int is provided this
is used to pad all borders. If sequence of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a sequence of length 4 is provided
this is the padding for the left, top, right and bottom borders respectively.
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
If a tuple of length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Only number is supported for paddle Tensor.
Only int or str or tuple value is supported for PIL Image.
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
If input a 5D paddle Tensor, the last 3 dimensions will be padded instead of the last 2
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
Returns:
PIL Image or Tensor: Padded image.
"""
if not isinstance(img, paddle.Tensor):
return F_pil.pad(img,
padding=padding,
fill=fill,
padding_mode=padding_mode)
return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
"""Crop the given image at specified location and output size.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then cropped.
Args:
img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
Returns:
PIL Image or Tensor: Cropped image.
"""
if not isinstance(img, paddle.Tensor):
return F_pil.crop(img, top, left, height, width)
return F_t.crop(img, top, left, height, width)
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
"""Crops the given image at the center.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
it is used for both directions.
Returns:
PIL Image or Tensor: Cropped image.
"""
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])
image_width, image_height = _get_image_size(img)
crop_height, crop_width = output_size
if crop_width > image_width or crop_height > image_height:
padding_ltrb = [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2
if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2
if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2
if crop_height > image_height else 0,
]
img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0
image_width, image_height = _get_image_size(img)
if crop_width == image_width and crop_height == image_height:
return img
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
return crop(img, crop_top, crop_left, crop_height, crop_width)
def resized_crop(
img: Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode=InterpolationMode.BILINEAR) -> Tensor:
"""Crop the given image and resize it to desired size.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
Args:
img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
size (sequence or int): Desired output size. Same semantics as ``resize``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`paddlevision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
Returns:
PIL Image or Tensor: Cropped image.
"""
img = crop(img, top, left, height, width)
img = resize(img, size, interpolation)
return img
import numbers
from typing import Any, List, Sequence
import numpy as np
from PIL import Image, ImageOps, ImageEnhance
try:
import accimage
except ImportError:
accimage = None
def _is_pil_image(img: Any) -> bool:
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
def _get_image_size(img: Any) -> List[int]:
if _is_pil_image(img):
return img.size
raise TypeError("Unexpected type {}".format(type(img)))
def _get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
raise TypeError("Unexpected type {}".format(type(img)))
def hflip(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_LEFT_RIGHT)
def vflip(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_TOP_BOTTOM)
def adjust_brightness(img, brightness_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
def adjust_contrast(img, contrast_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
def adjust_saturation(img, saturation_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
def adjust_hue(img, hue_factor):
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(
hue_factor))
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img
h, s, v = img.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
def adjust_gamma(img, gamma, gain=1):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
input_mode = img.mode
img = img.convert('RGB')
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma)
for ele in range(256)] * 3
img = img.point(
gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
return img
def pad(img, padding, fill=0, padding_mode="constant"):
if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, list):
padding = tuple(padding)
if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
if isinstance(padding, tuple) and len(padding) == 1:
# Compatibility with `functional_tensor.pad`
padding = padding[0]
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
"Padding mode should be either constant, edge, reflect or symmetric"
)
if padding_mode == "constant":
opts = _parse_fill(fill, img, name="fill")
if img.mode == "P":
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, **opts)
image.putpalette(palette)
return image
return ImageOps.expand(img, border=padding, **opts)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, tuple) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, tuple) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_top, pad_right, pad_bottom]
cropping = -np.minimum(p, 0)
if cropping.any():
crop_left, crop_top, crop_right, crop_bottom = cropping
img = img.crop((crop_left, crop_top, img.width - crop_right,
img.height - crop_bottom))
pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
if img.mode == 'P':
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)),
padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img
img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right),
(0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)),
padding_mode)
return Image.fromarray(img)
def crop(img: Image.Image, top: int, left: int, height: int,
width: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.crop((left, top, left + width, top + height))
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or
(isinstance(size, Sequence) and len(size) in (1, 2))):
raise TypeError('Got inappropriate size arg: {}'.format(size))
if isinstance(size, Sequence) and len(size) == 1:
size = size[0]
if isinstance(size, int):
w, h = img.size
short, long = (w, h) if w <= h else (h, w)
if short == size:
return img
new_short, new_long = size, int(size * long / short)
if max_size is not None:
if max_size <= size:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}")
if new_long > max_size:
new_short, new_long = int(max_size * new_short /
new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long,
new_short)
return img.resize((new_w, new_h), interpolation)
else:
if max_size is not None:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in deploy mode."
)
return img.resize(size[::-1], interpolation)
def _parse_fill(fill, img, name="fillcolor"):
# Process fill color for affine transforms
num_bands = len(img.getbands())
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if isinstance(fill, (list, tuple)):
if len(fill) != num_bands:
msg = (
"The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))
fill = tuple(fill)
return {name: fill}
def affine(img, matrix, interpolation=0, fill=None):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
output_size = img.size
opts = _parse_fill(fill, img)
return img.transform(output_size, Image.AFFINE, matrix, interpolation,
**opts)
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
opts = _parse_fill(fill, img)
return img.rotate(angle, interpolation, expand, center, **opts)
def perspective(img,
perspective_coeffs,
interpolation=Image.BICUBIC,
fill=None):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
opts = _parse_fill(fill, img)
return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs,
interpolation, **opts)
def to_grayscale(img, num_output_channels):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')
return img
def invert(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.invert(img)
def posterize(img, bits):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.posterize(img, bits)
def solarize(img, threshold):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.solarize(img, threshold)
def adjust_sharpness(img, sharpness_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
return img
def autocontrast(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img)
def equalize(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.equalize(img)
import warnings
import paddle
from paddle import Tensor
from paddle.nn.functional import grid_sample, conv2d, interpolate, pad as paddle_pad
from typing import Optional, Tuple, List
def _is_tensor_a_paddle_image(x: Tensor) -> bool:
return x.ndim >= 2
def _assert_image_tensor(img):
if not _is_tensor_a_paddle_image(img):
raise TypeError("Tensor is not a paddle image.")
def _get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]
def _cast_squeeze_in(img: Tensor, req_dtypes: List[paddle.dtype]) -> Tuple[
Tensor, bool, bool, paddle.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.as_type(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor,
need_cast: bool,
need_squeeze: bool,
out_dtype: paddle.dtype):
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (paddle.uint8, paddle.int8, paddle.int16, paddle.int32,
paddle.int64):
# it is better to round before cast
img = paddle.round(img)
img = img.as_type(out_dtype)
return img
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
crop_left, crop_right, crop_top, crop_bottom = [
-min(x, 0) for x in padding
]
img = img[..., crop_top:img.shape[-2] - crop_bottom, crop_left:
img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)
] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = paddle.to_tensor(
left_indices + x_indices + right_indices, device=img.device)
y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = paddle.to_tensor(
top_indices + y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError(
"Symmetric padding of N-D tensors are not supported yet")
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)
w, h = _get_image_size(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [
max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)
]
return pad(img[..., max(top, 0):bottom, max(left, 0):right],
padding_ltrb,
fill=0)
return img[..., top:bottom, left:right]
def pad(img: Tensor,
padding: List[int],
fill: int=0,
padding_mode: str="constant") -> Tensor:
_assert_image_tensor(img)
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
"Padding mode should be either constant, edge, reflect or symmetric"
)
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_right, pad_top, pad_bottom]
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (paddle.float32,
paddle.float64):
# Here we temporary cast input tensor to float
need_cast = True
img = img.as_type(paddle.float32)
img = paddle_pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(axis=0)
if need_cast:
img = img.as_type(out_dtype)
return img
def resize(img: Tensor,
size: List[int],
interpolation: str="bilinear",
max_size: Optional[int]=None,
antialias: Optional[bool]=None) -> Tensor:
_assert_image_tensor(img)
if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
if not isinstance(interpolation, str):
raise TypeError("Got inappropriate interpolation arg")
if interpolation not in ["nearest", "bilinear", "bicubic"]:
raise ValueError(
"This interpolation mode is unsupported with Tensor input")
if isinstance(size, tuple):
size = list(size)
if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError(
"Size must be an int or a 1 or 2 element tuple/list, not a "
"{} element tuple/list".format(len(size)))
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge."
)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError(
"Antialias option is supported for bilinear and bicubic interpolation modes only"
)
w, h = _get_image_size(img)
if isinstance(size, int) or len(
size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]
if short == requested_new_short:
return img
new_short, new_long = requested_new_short, int(requested_new_short *
long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}")
if new_long > max_size:
new_short, new_long = int(max_size * new_short /
new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long,
new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img, [paddle.float32, paddle.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(
img,
size=[new_h, new_w],
mode=interpolation,
align_corners=align_corners)
if interpolation == "bicubic" and out_dtype == paddle.uint8:
img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(
img,
need_cast=need_cast,
need_squeeze=need_squeeze,
out_dtype=out_dtype)
return img
import math
import numbers
import warnings
from collections.abc import Sequence
from typing import Tuple, List
import paddle
from paddle import Tensor
try:
import accimage
except ImportError:
accimage = None
from . import functional as F
from .functional import InterpolationMode, _interpolation_modes_from_int
__all__ = [
"Compose", "ToTensor", "Normalize", "Resize", "CenterCrop",
"RandomResizedCrop"
]
class Compose:
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a paddle tensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
return F.to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
class Normalize(paddle.nn.Layer):
"""Normalize a tensor image with mean and standard deviation.
This transform does not support PIL Image.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
``paddle.*Tensor`` i.e.,
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts out of place, i.e., it does not mutate the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation in-place.
"""
def __init__(self, mean, std, inplace=False):
super().__init__()
self.mean = mean
self.std = std
self.inplace = inplace
def forward(self, tensor: Tensor) -> Tensor:
"""
Args:
tensor (Tensor): Tensor image to be normalized.
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(tensor, self.mean, self.std, self.inplace)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(
self.mean, self.std)
class Resize(paddle.nn.Layer):
"""Resize the input image to the given size.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
.. warning::
The output image might be different depending on its type: when downsampling, the interpolation of PIL images
and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
closer.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size).
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`paddlevision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e the
smaller edge may be shorter than ``size``.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
``InterpolationMode.BILINEAR`` only mode. This can help making the output for PIL images and tensors
closer.
.. warning::
There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
"""
def __init__(self,
size,
interpolation=InterpolationMode.BILINEAR,
max_size=None,
antialias=None):
super().__init__()
if not isinstance(size, (int, Sequence)):
raise TypeError("Size should be int or sequence. Got {}".format(
type(size)))
if isinstance(size, Sequence) and len(size) not in (1, 2):
raise ValueError(
"If size is a sequence, it should have 1 or 2 values")
self.size = size
self.max_size = max_size
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation
self.antialias = antialias
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be scaled.
Returns:
PIL Image or Tensor: Rescaled image.
"""
return F.resize(img, self.size, self.interpolation, self.max_size,
self.antialias)
def __repr__(self):
interpolate_str = self.interpolation.value
return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format(
self.size, interpolate_str, self.max_size, self.antialias)
class CenterCrop(paddle.nn.Layer):
"""Crops the given image at the center.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
"""
def __init__(self, size):
super().__init__()
self.size = _setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size.")
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
PIL Image or Tensor: Cropped image.
"""
return F.center_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class RandomResizedCrop(paddle.nn.Layer):
"""Crop a random portion of image and resize it to a given size.
If the image is paddle Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
A crop of the original image is made: the crop has a random area (H * W)
and a random aspect ratio. This crop is finally resized to the given
size. This is popularly used to train the Inception networks.
Args:
size (int or sequence): expected output size of the crop, for each edge. If size is an
int instead of sequence like (h, w), a square output size ``(size, size)`` is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
before resizing. The scale is defined with respect to the area of the original image.
ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
resizing.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`paddlevision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
"""
def __init__(self,
size,
scale=(0.08, 1.0),
ratio=(3. / 4., 4. / 3.),
interpolation=InterpolationMode.BILINEAR):
super().__init__()
self.size = _setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img: Tensor, scale: List[float],
ratio: List[float]) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image or Tensor): Input image.
scale (list): range of scale of the origin size cropped
ratio (list): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
width, height = F._get_image_size(img)
area = height * width
log_ratio = paddle.log(paddle.to_tensor(ratio))
for _ in range(10):
target_area = area * paddle.uniform(
shape=[1], min=scale[0], max=scale[1]).numpy().item()
aspect_ratio = paddle.exp(
paddle.uniform(
shape=[1], min=log_ratio[0], max=log_ratio[1])).numpy(
).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = paddle.randint(
0, height - h + 1, shape=(1, )).numpy().item()
j = paddle.randint(
0, width - w + 1, shape=(1, )).numpy().item()
return i, j, h, w
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(ratio):
w = width
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = height
w = int(round(h * max(ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return i, j, h, w
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped and resized.
Returns:
PIL Image or Tensor: Randomly cropped and resized image.
"""
i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self):
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(
tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(
tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
import os
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, '../')))
from paddlevision.transforms import autoaugment, transforms
class ClassificationPresetTrain:
def __init__(self,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
hflip_prob=0.5,
auto_augment_policy=None,
random_erase_prob=0.0):
trans = [transforms.RandomResizedCrop(crop_size)]
# if hflip_prob > 0:
# trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
transforms.Normalize(
mean=mean, std=std),
])
# if random_erase_prob > 0:
# trans.append(transforms.RandomErasing(p=random_erase_prob))
self.transforms = transforms.Compose(trans)
def __call__(self, img):
return self.transforms(img)
class ClassificationPresetEval:
def __init__(self,
crop_size,
resize_size=256,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)):
self.transforms = transforms.Compose([
transforms.Resize(resize_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize(
mean=mean, std=std),
])
def __call__(self, img):
return self.transforms(img)
import datetime
import os
import sys
import time
import paddle
from paddle import nn
import paddlevision
import presets
import utils
import numpy as np
import random
apex = None
import numpy as np
from reprod_log import ReprodLogger
def train_one_epoch(
model,
criterion,
optimizer,
data_loader,
device,
epoch,
print_freq, ):
model.train()
# training log
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
acc1 = 0.0
acc5 = 0.0
reader_start = time.time()
batch_past = 0
for batch_idx, (image, target) in enumerate(data_loader):
train_reader_cost += time.time() - reader_start
train_start = time.time()
output = model(image)
loss = criterion(output, target)
loss.backward()
optimizer.step()
optimizer.clear_grad()
train_run_cost += time.time() - train_start
acc = utils.accuracy(output, target, topk=(1, 5))
acc1 += acc[0].item()
acc5 += acc[1].item()
total_samples += image.shape[0]
batch_past += 1
if batch_idx > 0 and batch_idx % print_freq == 0:
msg = "[Epoch {}, iter: {}] top1: {:.5f}, top5: {:.5f}, lr: {:.5f}, loss: {:.5f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {}, avg_ips: {:.5f} images/sec.".format(
epoch, batch_idx, acc1 / batch_past, acc5 / batch_past,
optimizer.get_lr(),
loss.item(), train_reader_cost / batch_past,
(train_reader_cost + train_run_cost) / batch_past,
total_samples / batch_past,
total_samples / (train_reader_cost + train_run_cost))
if paddle.distributed.get_rank() <= 0:
print(msg)
sys.stdout.flush()
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
acc1 = 0.0
acc5 = 0.0
batch_past = 0
reader_start = time.time()
def evaluate(model, criterion, data_loader, device, print_freq=100):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
with paddle.no_grad():
for image, target in metric_logger.log_every(data_loader, print_freq,
header):
output = model(image)
loss = criterion(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
batch_size = image.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'.format(
top1=metric_logger.acc1, top5=metric_logger.acc5))
return metric_logger.acc1.global_avg
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (
256, 224)
print("Loading training data")
st = time.time()
auto_augment_policy = getattr(args, "auto_augment", None)
random_erase_prob = getattr(args, "random_erase", 0.0)
dataset = paddlevision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=crop_size,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob))
print("Took", time.time() - st)
print("Loading validation data")
dataset_test = paddlevision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(
crop_size=crop_size, resize_size=resize_size))
print("Creating data loaders")
train_sampler = paddle.io.DistributedBatchSampler(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=False)
test_sampler = paddle.io.SequenceSampler(dataset_test)
return dataset, dataset_test, train_sampler, test_sampler
def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)
print(args)
device = paddle.set_device(args.device)
# multi cards
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(
train_dir, val_dir, args)
train_batch_sampler = train_sampler
data_loader = paddle.io.DataLoader(
dataset=dataset,
num_workers=args.workers,
return_list=True,
batch_sampler=train_batch_sampler)
test_batch_sampler = paddle.io.BatchSampler(
sampler=test_sampler, batch_size=args.batch_size)
data_loader_test = paddle.io.DataLoader(
dataset_test,
batch_sampler=test_batch_sampler,
num_workers=args.workers)
print("Creating model")
model = paddlevision.models.__dict__[args.model](
pretrained=args.pretrained)
criterion = nn.CrossEntropyLoss()
lr_scheduler = paddle.optimizer.lr.StepDecay(
args.lr, step_size=args.lr_step_size, gamma=args.lr_gamma)
opt_name = args.opt.lower()
if opt_name == 'sgd':
optimizer = paddle.optimizer.Momentum(
learning_rate=lr_scheduler,
momentum=args.momentum,
parameters=model.parameters(),
weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = paddle.optimizer.RMSprop(
learning_rate=lr_scheduler,
momentum=args.momentum,
parameters=model.parameters(),
weight_decay=args.weight_decay,
eps=0.0316,
alpha=0.9)
else:
raise RuntimeError(
"Invalid optimizer {}. Only SGD and RMSprop are supported.".format(
args.opt))
if args.resume:
layer_state_dict = paddle.load(os.path.join(args.resume, '.pdparams'))
model.set_state_dict(layer_state_dict)
opt_state_dict = paddle.load(os.path.join(args.resume, '.pdopt'))
optimizer.load_state_dict(opt_state_dict)
# multi cards
if paddle.distributed.get_world_size() > 1:
model = paddle.DataParallel(model)
if args.test_only and paddle.distributed.get_rank() == 0:
top1 = evaluate(model, criterion, data_loader_test, device=device)
return top1
print("Start training")
start_time = time.time()
best_top1 = 0.0
for epoch in range(args.start_epoch, args.epochs):
train_one_epoch(model, criterion, optimizer, data_loader, device,
epoch, args.print_freq)
lr_scheduler.step()
if paddle.distributed.get_rank() == 0:
top1 = evaluate(model, criterion, data_loader_test, device=device)
best_top1 = max(best_top1, top1)
if args.output_dir:
paddle.save(model.state_dict(),
os.path.join(args.output_dir,
'model_{}.pdparams'.format(epoch)))
paddle.save(optimizer.state_dict(),
os.path.join(args.output_dir,
'model_{}.pdopt'.format(epoch)))
paddle.save(model.state_dict(),
os.path.join(args.output_dir, 'latest.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(args.output_dir, 'latest.pdopt'))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
return best_top1
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(
description='PaddlePaddle Classification Training', add_help=add_help)
parser.add_argument('--data-path', default='../data', help='dataset')
parser.add_argument('--model', default='alexnet', help='model')
parser.add_argument('--device', default='gpu', help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int)
parser.add_argument(
'--epochs',
default=90,
type=int,
metavar='N',
help='number of total epochs to run')
parser.add_argument(
'-j',
'--workers',
default=8,
type=int,
metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument(
'--lr', default=0.00125, type=float, help='initial learning rate')
parser.add_argument(
'--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument(
'--wd',
'--weight-decay',
default=1e-4,
type=float,
metavar='W',
help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument(
'--lr-step-size',
default=30,
type=int,
help='decrease lr every step-size epochs')
parser.add_argument(
'--lr-gamma',
default=0.1,
type=float,
help='decrease lr by a factor of lr-gamma')
parser.add_argument(
'--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument(
'--start-epoch', default=0, type=int, metavar='N', help='start epoch')
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true", )
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true", )
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo")
parser.add_argument(
'--auto-augment',
default=None,
help='auto augment policy (default: None)')
parser.add_argument(
'--random-erase',
default=0.0,
type=float,
help='random erasing probability (default: 0.0)')
# Mixed precision training parameters
parser.add_argument(
'--apex',
action='store_true',
help='Use apex for mixed precision training')
parser.add_argument(
'--apex-opt-level',
default='O1',
type=str,
help='For apex mixed precision training'
'O0 for FP32 training, O1 for mixed precision training.'
'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
)
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
top1 = main(args)
if paddle.distributed.get_rank() == 0:
reprod_logger = ReprodLogger()
reprod_logger.add("top1", np.array([top1]))
reprod_logger.save("train_align_paddle.npy")
from collections import defaultdict, deque, OrderedDict
import copy
import datetime
import hashlib
import time
import paddle
import paddle.distributed as dist
import errno
import os
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = paddle.to_tensor([self.count, self.total], dtype='float64')
t = t.numpy().tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = paddle.to_tensor(list(self.deque))
return d.median().numpy().item()
@property
def avg(self):
d = paddle.to_tensor(list(self.deque), dtype='float32')
return d.mean().numpy().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, paddle.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = self.delimiter.join([
header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
'time: {time}', 'data: {data}'
])
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {}'.format(header, total_time_str))
def accuracy(output, target, topk=(1, )):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with paddle.no_grad():
maxk = max(topk)
batch_size = target.shape[0]
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.equal(target)
res = []
for k in topk:
correct_k = correct.astype(paddle.int32)[:k].flatten().sum(
dtype='float32')
res.append(correct_k / batch_size)
return res
def get_world_size():
return dist.get_world_size()
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
from .metric import accuracy_torch
from .presets import *
import torch
def accuracy_torch(output, target, topk=(1, )):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target[None])
res = []
for k in topk:
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
res.append(correct_k * (100.0 / batch_size))
return res
from torchvision.transforms import autoaugment, transforms
class ClassificationPresetTrain:
def __init__(self,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
hflip_prob=0.5,
auto_augment_policy=None,
random_erase_prob=0.0):
trans = [transforms.RandomResizedCrop(crop_size)]
# if hflip_prob > 0:
# trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
transforms.Normalize(
mean=mean, std=std),
])
# if random_erase_prob > 0:
# trans.append(transforms.RandomErasing(p=random_erase_prob))
self.transforms = transforms.Compose(trans)
def __call__(self, img):
return self.transforms(img)
class ClassificationPresetEval:
def __init__(self,
crop_size,
resize_size=256,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)):
self.transforms = transforms.Compose([
transforms.Resize(resize_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize(
mean=mean, std=std),
])
def __call__(self, img):
return self.transforms(img)
from . import datasets
from . import models
from . import transforms
import os
import importlib.machinery
def _download_file_from_remote_location(fpath: str, url: str) -> None:
pass
def _is_remote_location_available() -> bool:
return False
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
def _get_extension_path(lib_name):
lib_dir = os.path.dirname(__file__)
if os.name == 'nt':
# Register the main torchvision library location on the default DLL path
import ctypes
import sys
kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
prev_error_mode = kernel32.SetErrorMode(0x0001)
if with_load_library_flags:
kernel32.AddDllDirectory.restype = ctypes.c_void_p
if sys.version_info >= (3, 8):
os.add_dll_directory(lib_dir)
elif with_load_library_flags:
res = kernel32.AddDllDirectory(lib_dir)
if res is None:
err = ctypes.WinError(ctypes.get_last_error())
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
raise err
kernel32.SetErrorMode(prev_error_mode)
loader_details = (importlib.machinery.ExtensionFileLoader,
importlib.machinery.EXTENSION_SUFFIXES)
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
ext_specs = extfinder.find_spec(lib_name)
if ext_specs is None:
raise ImportError
return ext_specs.origin
from .folder import ImageFolder, DatasetFolder
from .vision import VisionDataset
__all__ = ('ImageFolder', 'DatasetFolder', 'VisionDataset')
from .vision import VisionDataset
from PIL import Image
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
def has_file_allowed_extension(filename: str,
extensions: Tuple[str, ...]) -> bool:
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def is_image_file(filename: str) -> bool:
"""Checks if a file is an allowed image extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset.
See :class:`DatasetFolder` for details.
"""
classes = sorted(
entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(
f"Couldn't find any class folder in {directory}.")
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def make_dataset(
directory: str,
class_to_idx: Optional[Dict[str, int]]=None,
extensions: Optional[Tuple[str, ...]]=None,
is_valid_file: Optional[Callable[[str], bool]]=None, ) -> List[Tuple[
str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
See :class:`DatasetFolder` for details.
Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
by default.
"""
directory = os.path.expanduser(directory)
if class_to_idx is None:
_, class_to_idx = find_classes(directory)
elif not class_to_idx:
raise ValueError(
"'class_to_index' must have at least one entry to collect any samples."
)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError(
"Both extensions and is_valid_file cannot be None or not None at the same time"
)
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(
x, cast(Tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
instances = []
available_classes = set()
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
if is_valid_file(fname):
path = os.path.join(root, fname)
item = path, class_index
instances.append(item)
if target_class not in available_classes:
available_classes.add(target_class)
return instances
class DatasetFolder(VisionDataset):
"""A generic data loader.
This default directory structure can be customized by overriding the
:meth:`find_classes` method.
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of a file
and check if the file is a valid file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(
self,
root: str,
loader: Callable[[str], Any],
extensions: Optional[Tuple[str, ...]]=None,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None,
is_valid_file: Optional[Callable[[str], bool]]=None, ) -> None:
super(DatasetFolder, self).__init__(
root, transform=transform, target_transform=target_transform)
classes, class_to_idx = self.find_classes(self.root)
samples = self.make_dataset(self.root, class_to_idx, extensions,
is_valid_file)
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
@staticmethod
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]]=None,
is_valid_file: Optional[Callable[[str], bool]]=None, ) -> List[
Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).
This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
Args:
directory (str): root dataset directory, corresponding to ``self.root``.
class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
extensions (optional): A list of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes path of a file
and checks if the file is a valid file
(used to check of corrupt files) both extensions and
is_valid_file should not be passed. Defaults to None.
Raises:
ValueError: In case ``class_to_idx`` is empty.
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
FileNotFoundError: In case no valid file was found for any class.
Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
if class_to_idx is None:
# prevent potential bug since make_dataset() would use the class_to_idx logic of the
# find_classes() function, instead of using that of the find_classes() method, which
# is potentially overridden and thus could have a different logic.
raise ValueError("The class_to_idx parameter cannot be None.")
return make_dataset(
directory,
class_to_idx,
extensions=extensions,
is_valid_file=is_valid_file)
def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Find the class folders in a dataset structured as follows::
directory/
├── class_x
│ ├── xxx.ext
│ ├── xxy.ext
│ └── ...
│ └── xxz.ext
└── class_y
├── 123.ext
├── nsdf3.ext
└── ...
└── asd932_.ext
This method can be overridden to only consider
a subset of classes, or to adapt to a different dataset directory structure.
Args:
directory(str): Root directory path, corresponding to ``self.root``
Raises:
FileNotFoundError: If ``dir`` has no class folders.
Returns:
(Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
"""
return find_classes(directory)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self) -> int:
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.tiff', '.webp')
def pil_loader(path: str) -> Image.Image:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
# TODO: specify the return type
def accimage_loader(path: str) -> Any:
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path: str) -> Any:
return pil_loader(path)
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way by default: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
the same methods can be overridden to customize the dataset.
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(
self,
root: str,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None,
loader: Callable[[str], Any]=default_loader,
is_valid_file: Optional[Callable[[str], bool]]=None, ):
super(ImageFolder, self).__init__(
root,
loader,
IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples
import os
import torch
import torch.utils.data as data
from typing import Any, Callable, List, Optional, Tuple
class VisionDataset(data.Dataset):
"""
Base Class For making datasets which are compatible with torchvision.
It is necessary to override the ``__getitem__`` and ``__len__`` method.
Args:
root (string): Root directory of dataset.
transforms (callable, optional): A function/transforms that takes in
an image and a label and returns the transformed versions of both.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
.. note::
:attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
"""
_repr_indent = 4
def __init__(
self,
root: str,
transforms: Optional[Callable]=None,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None, ) -> None:
if isinstance(root, torch._six.string_classes):
root = os.path.expanduser(root)
self.root = root
has_transforms = transforms is not None
has_separate_transform = transform is not None or target_transform is not None
if has_transforms and has_separate_transform:
raise ValueError(
"Only transforms or transform/target_transform can "
"be passed as argument")
# for backwards-compatibility
self.transform = transform
self.target_transform = target_transform
if has_separate_transform:
transforms = StandardTransform(transform, target_transform)
self.transforms = transforms
def __getitem__(self, index: int) -> Any:
"""
Args:
index (int): Index
Returns:
(Any): Sample and meta data, optionally transformed by the respective transforms.
"""
raise NotImplementedError
def __len__(self) -> int:
raise NotImplementedError
def __repr__(self) -> str:
head = "Dataset " + self.__class__.__name__
body = ["Number of datapoints: {}".format(self.__len__())]
if self.root is not None:
body.append("Root location: {}".format(self.root))
body += self.extra_repr().splitlines()
if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines)
def _format_transform_repr(self, transform: Callable,
head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])
def extra_repr(self) -> str:
return ""
class StandardTransform(object):
def __init__(self,
transform: Optional[Callable]=None,
target_transform: Optional[Callable]=None) -> None:
self.transform = transform
self.target_transform = target_transform
def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
if self.transform is not None:
input = self.transform(input)
if self.target_transform is not None:
target = self.target_transform(target)
return input, target
def _format_transform_repr(self, transform: Callable,
head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])
def __repr__(self) -> str:
body = [self.__class__.__name__]
if self.transform is not None:
body += self._format_transform_repr(self.transform, "Transform: ")
if self.target_transform is not None:
body += self._format_transform_repr(self.target_transform,
"Target transform: ")
return '\n'.join(body)
from .mobilenet_v3_torch import mobilenet_v3_large, mobilenet_v3_small
from collections import OrderedDict
from typing import Dict, Optional
from torch import nn
class IntermediateLayerGetter(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model
It has a strong assumption that the modules have been registered
into the model in the same order as they are used.
This means that one should **not** reuse the same nn.Module
twice in the forward if you want this to work.
Additionally, it is only able to query submodules that are directly
assigned to the model. So if `model` is passed, `model.feature1` can
be returned, but not `model.feature1.layer2`.
Args:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
Examples::
>>> m = torchvision.models.resnet18(pretrained=True)
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
>>> {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = new_m(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
_version = 2
__annotations__ = {"return_layers": Dict[str, str], }
def __init__(self, model: nn.Module,
return_layers: Dict[str, str]) -> None:
if not set(return_layers).issubset(
[name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers
return_layers = {str(k): str(v) for k, v in return_layers.items()}
layers = OrderedDict()
for name, module in model.named_children():
layers[name] = module
if name in return_layers:
del return_layers[name]
if not return_layers:
break
super().__init__(layers)
self.return_layers = orig_return_layers
def forward(self, x):
out = OrderedDict()
for name, module in self.items():
x = module(x)
if name in self.return_layers:
out_name = self.return_layers[name]
out[out_name] = x
return out
def _make_divisible(v: float, divisor: int,
min_value: Optional[int]=None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
import warnings
from typing import Callable, List, Optional
import torch
from torch import Tensor
class Conv2d(torch.nn.Conv2d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
"torchvision.ops.misc.Conv2d is deprecated and will be "
"removed in future versions, use torch.nn.Conv2d instead.",
FutureWarning, )
class ConvTranspose2d(torch.nn.ConvTranspose2d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
"torchvision.ops.misc.ConvTranspose2d is deprecated and will be "
"removed in future versions, use torch.nn.ConvTranspose2d instead.",
FutureWarning, )
class BatchNorm2d(torch.nn.BatchNorm2d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
"torchvision.ops.misc.BatchNorm2d is deprecated and will be "
"removed in future versions, use torch.nn.BatchNorm2d instead.",
FutureWarning, )
interpolate = torch.nn.functional.interpolate
# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed
Args:
num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
eps (float): a value added to the denominator for numerical stability. Default: 1e-5
"""
def __init__(
self,
num_features: int,
eps: float=1e-5,
n: Optional[int]=None, ):
# n=None for backward-compatibility
if n is not None:
warnings.warn(
"`n` argument is deprecated and has been renamed `num_features`",
DeprecationWarning)
num_features = n
super().__init__()
# _log_api_usage_once("ops", self.__class__.__name__)
self.eps = eps
self.register_buffer("weight", torch.ones(num_features))
self.register_buffer("bias", torch.zeros(num_features))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str], ):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
def forward(self, x: Tensor) -> Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
scale = w * (rv + self.eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
class ConvNormActivation(torch.nn.Sequential):
"""
Configurable block used for Convolution-Normalzation-Activation blocks.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolutiuon layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int=3,
stride: int=1,
padding: Optional[int]=None,
groups: int=1,
norm_layer: Optional[Callable[
..., torch.nn.Module]]=torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[
..., torch.nn.Module]]=torch.nn.ReLU,
dilation: int=1,
inplace: bool=True,
bias: Optional[bool]=None, ) -> None:
if padding is None:
padding = (kernel_size - 1) // 2 * dilation
if bias is None:
bias = norm_layer is None
layers = [
torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=dilation,
groups=groups,
bias=bias, )
]
if norm_layer is not None:
layers.append(norm_layer(out_channels))
if activation_layer is not None:
layers.append(activation_layer(inplace=inplace))
super().__init__(*layers)
# _log_api_usage_once("ops", self.__class__.__name__)
self.out_channels = out_channels
class SqueezeExcitation(torch.nn.Module):
"""
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3.
Args:
input_channels (int): Number of channels in the input image
squeeze_channels (int): Number of squeeze channels
activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
"""
def __init__(
self,
input_channels: int,
squeeze_channels: int,
activation: Callable[..., torch.nn.Module]=torch.nn.ReLU,
scale_activation: Callable[..., torch.nn.Module]=torch.nn.Sigmoid,
) -> None:
super().__init__()
# _log_api_usage_once("ops", self.__class__.__name__)
self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
self.activation = activation()
self.scale_activation = scale_activation()
def _scale(self, input: Tensor) -> Tensor:
scale = self.avgpool(input)
scale = self.fc1(scale)
scale = self.activation(scale)
scale = self.fc2(scale)
return self.scale_activation(scale)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input
import warnings
from functools import partial
from typing import Any, Callable, List, Optional, Sequence
import torch
from torch import nn, Tensor
# from .._internally_replaced_utils import load_state_dict_from_url
from .misc_torch import ConvNormActivation, SqueezeExcitation as SElayer
from ._utils import _make_divisible
__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
model_urls = {
"mobilenet_v3_large":
"https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small":
"https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
}
class SqueezeExcitation(SElayer):
"""DEPRECATED"""
def __init__(self, input_channels: int, squeeze_factor: int=4):
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
super().__init__(
input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
self.relu = self.activation
delattr(self, "activation")
warnings.warn(
"This SqueezeExcitation class is deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.SqueezeExcitation instead.",
FutureWarning, )
class InvertedResidualConfig:
# Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
def __init__(
self,
input_channels: int,
kernel: int,
expanded_channels: int,
out_channels: int,
use_se: bool,
activation: str,
stride: int,
dilation: int,
width_mult: float, ):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels,
width_mult)
self.out_channels = self.adjust_channels(out_channels, width_mult)
self.use_se = use_se
self.use_hs = activation == "HS"
self.stride = stride
self.dilation = dilation
@staticmethod
def adjust_channels(channels: int, width_mult: float):
return _make_divisible(channels * width_mult, 8)
class InvertedResidual(nn.Module):
# Implemented as described at section 5 of MobileNetV3 paper
def __init__(
self,
cnf: InvertedResidualConfig,
norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module]=partial(
SElayer, scale_activation=nn.Hardsigmoid), ):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
layers: List[nn.Module] = []
activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(
ConvNormActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer, ))
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(
ConvNormActivation(
cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
stride=stride,
dilation=cnf.dilation,
groups=cnf.expanded_channels,
norm_layer=norm_layer,
activation_layer=activation_layer, ))
if cnf.use_se:
squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
# project
layers.append(
ConvNormActivation(
cnf.expanded_channels,
cnf.out_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=None))
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self._is_cn = cnf.stride > 1
def forward(self, input: Tensor) -> Tensor:
result = self.block(input)
if self.use_res_connect:
result += input
return result
class MobileNetV3(nn.Module):
def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
num_classes: int=1000,
block: Optional[Callable[..., nn.Module]]=None,
norm_layer: Optional[Callable[..., nn.Module]]=None,
dropout: float=0.2,
**kwargs: Any, ) -> None:
"""
MobileNet V3 main class
Args:
inverted_residual_setting (List[InvertedResidualConfig]): Network structure
last_channel (int): The number of channels on the penultimate layer
num_classes (int): Number of classes
block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
dropout (float): The droupout probability
"""
super().__init__()
if not inverted_residual_setting:
raise ValueError(
"The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and all([
isinstance(s, InvertedResidualConfig)
for s in inverted_residual_setting
])):
raise TypeError(
"The inverted_residual_setting should be List[InvertedResidualConfig]"
)
if block is None:
block = InvertedResidual
if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
layers: List[nn.Module] = []
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(
ConvNormActivation(
3,
firstconv_output_channels,
kernel_size=3,
stride=2,
norm_layer=norm_layer,
activation_layer=nn.Hardswish, ))
# building inverted residual blocks
for cnf in inverted_residual_setting:
layers.append(block(cnf, norm_layer))
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(
ConvNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.Hardswish, ))
self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Sequential(
nn.Linear(lastconv_output_channels, last_channel),
nn.Hardswish(inplace=True),
nn.Dropout(
p=dropout, inplace=True),
nn.Linear(last_channel, num_classes), )
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def _mobilenet_v3_conf(arch: str,
width_mult: float=1.0,
reduced_tail: bool=False,
dilated: bool=False,
**kwargs: Any):
reduce_divider = 2 if reduced_tail else 1
dilation = 2 if dilated else 1
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(
InvertedResidualConfig.adjust_channels, width_mult=width_mult)
if arch == "mobilenet_v3_large":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2,
dilation), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider,
160 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider,
160 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1280 // reduce_divider) # C5
elif arch == "mobilenet_v3_small":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2,
dilation), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider,
96 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider,
96 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1024 // reduce_divider) # C5
else:
raise ValueError(f"Unsupported model type {arch}")
return inverted_residual_setting, last_channel
def _mobilenet_v3(
arch: str,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any, ):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
if model_urls.get(arch, None) is None:
raise ValueError(
f"No checkpoint is available for model type {arch}")
state_dict = load_state_dict_from_url(
model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
def mobilenet_v3_large(pretrained: bool=False,
progress: bool=True,
**kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_large"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch,
**kwargs)
return _mobilenet_v3(arch, inverted_residual_setting, last_channel,
pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool=False,
progress: bool=True,
**kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_small"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch,
**kwargs)
return _mobilenet_v3(arch, inverted_residual_setting, last_channel,
pretrained, progress, **kwargs)
import math
import torch
from enum import Enum
from torch import Tensor
from typing import List, Tuple, Optional
from . import functional as F, InterpolationMode
__all__ = ["AutoAugmentPolicy", "AutoAugment"]
class AutoAugmentPolicy(Enum):
"""AutoAugment policies learned on different datasets.
Available policies are IMAGENET, CIFAR10 and SVHN.
"""
IMAGENET = "imagenet"
CIFAR10 = "cifar10"
SVHN = "svhn"
def _get_transforms(policy: AutoAugmentPolicy):
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
def _get_magnitudes():
_BINS = 10
return {
# name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, _BINS), True),
"ShearY": (torch.linspace(0.0, 0.3, _BINS), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"Rotate": (torch.linspace(0.0, 30.0, _BINS), True),
"Brightness": (torch.linspace(0.0, 0.9, _BINS), True),
"Color": (torch.linspace(0.0, 0.9, _BINS), True),
"Contrast": (torch.linspace(0.0, 0.9, _BINS), True),
"Sharpness": (torch.linspace(0.0, 0.9, _BINS), True),
"Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False),
"Solarize": (torch.linspace(256.0, 0.0, _BINS), False),
"AutoContrast": (None, None),
"Equalize": (None, None),
"Invert": (None, None),
}
class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""
def __init__(self,
policy: AutoAugmentPolicy=AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode=InterpolationMode.NEAREST,
fill: Optional[List[float]]=None):
super().__init__()
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.transforms = _get_transforms(policy)
if self.transforms is None:
raise ValueError(
"The provided policy {} is not recognized.".format(policy))
self._op_meta = _get_magnitudes()
@staticmethod
def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
"""Get parameters for autoaugment transformation
Returns:
params required by the autoaugment transformation
"""
policy_id = torch.randint(transform_num, (1, )).item()
probs = torch.rand((2, ))
signs = torch.randint(2, (2, ))
return policy_id, probs, signs
def _get_op_meta(self,
name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name]
def forward(self, img: Tensor):
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: AutoAugmented image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]
transform_id, probs, signs = self.get_params(len(self.transforms))
for i, (op_name, p,
magnitude_id) in enumerate(self.transforms[transform_id]):
if probs[i] <= p:
magnitudes, signed = self._get_op_meta(op_name)
magnitude = float(magnitudes[magnitude_id].item()) \
if magnitudes is not None and magnitude_id is not None else 0.0
if signed is not None and signed and signs[i] == 0:
magnitude *= -1.0
if op_name == "ShearX":
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[math.degrees(magnitude), 0.0],
interpolation=self.interpolation,
fill=fill)
elif op_name == "ShearY":
img = F.affine(
img,
angle=0.0,
translate=[0, 0],
scale=1.0,
shear=[0.0, math.degrees(magnitude)],
interpolation=self.interpolation,
fill=fill)
elif op_name == "TranslateX":
img = F.affine(
img,
angle=0.0,
translate=[
int(F._get_image_size(img)[0] * magnitude), 0
],
scale=1.0,
interpolation=self.interpolation,
shear=[0.0, 0.0],
fill=fill)
elif op_name == "TranslateY":
img = F.affine(
img,
angle=0.0,
translate=[
0, int(F._get_image_size(img)[1] * magnitude)
],
scale=1.0,
interpolation=self.interpolation,
shear=[0.0, 0.0],
fill=fill)
elif op_name == "Rotate":
img = F.rotate(
img,
magnitude,
interpolation=self.interpolation,
fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError(
"The provided operator {} is not recognized.".format(
op_name))
return img
def __repr__(self):
return self.__class__.__name__ + '(policy={}, fill={})'.format(
self.policy, self.fill)
import math
import numbers
import warnings
from enum import Enum
import numpy as np
from PIL import Image
import torch
from torch import Tensor
from typing import List, Tuple, Any, Optional
try:
import accimage
except ImportError:
accimage = None
from . import functional_pil as F_pil
from . import functional_tensor as F_t
class InterpolationMode(Enum):
"""Interpolation modes
Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``.
"""
NEAREST = "nearest"
BILINEAR = "bilinear"
BICUBIC = "bicubic"
# For PIL compatibility
BOX = "box"
HAMMING = "hamming"
LANCZOS = "lanczos"
# TODO: Once torchscript supports Enums with staticmethod
# this can be put into InterpolationMode as staticmethod
def _interpolation_modes_from_int(i: int) -> InterpolationMode:
inverse_modes_mapping = {
0: InterpolationMode.NEAREST,
2: InterpolationMode.BILINEAR,
3: InterpolationMode.BICUBIC,
4: InterpolationMode.BOX,
5: InterpolationMode.HAMMING,
1: InterpolationMode.LANCZOS,
}
return inverse_modes_mapping[i]
pil_modes_mapping = {
InterpolationMode.NEAREST: 0,
InterpolationMode.BILINEAR: 2,
InterpolationMode.BICUBIC: 3,
InterpolationMode.BOX: 4,
InterpolationMode.HAMMING: 5,
InterpolationMode.LANCZOS: 1,
}
_is_pil_image = F_pil._is_pil_image
def _get_image_size(img: Tensor) -> List[int]:
"""Returns image size as [w, h]
"""
if isinstance(img, torch.Tensor):
return F_t._get_image_size(img)
return F_pil._get_image_size(img)
def _get_image_num_channels(img: Tensor) -> int:
"""Returns number of image channels
"""
if isinstance(img, torch.Tensor):
return F_t._get_image_num_channels(img)
return F_pil._get_image_num_channels(img)
@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}
def to_tensor(pic):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
This function does not support torchscript.
See :class:`~torchvision.transforms.ToTensor` for more details.
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(
type(pic)))
if _is_numpy(pic) and not _is_numpy_image(pic):
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.
format(pic.ndim))
default_float_dtype = torch.get_default_dtype()
if isinstance(pic, np.ndarray):
# handle numpy array
if pic.ndim == 2:
pic = pic[:, :, None]
img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
# backward compatibility
if isinstance(img, torch.ByteTensor):
return img.to(dtype=default_float_dtype).div(255)
else:
return img
if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros(
[pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic).to(dtype=default_float_dtype)
# handle PIL Image
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(
np.array(
pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
if pic.mode == '1':
img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
img = img.permute((2, 0, 1)).contiguous()
if isinstance(img, torch.ByteTensor):
return img.to(dtype=default_float_dtype).div(255)
else:
return img
def pil_to_tensor(pic):
"""Convert a ``PIL Image`` to a tensor of the same type.
This function does not support torchscript.
See :class:`~torchvision.transforms.PILToTensor` for more details.
Args:
pic (PIL Image): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if not F_pil._is_pil_image(pic):
raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))
if accimage is not None and isinstance(pic, accimage.Image):
# accimage format is always uint8 internally, so always return uint8 here
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
pic.copyto(nppic)
return torch.as_tensor(nppic)
# handle PIL Image
img = torch.as_tensor(np.asarray(pic))
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
img = img.permute((2, 0, 1))
return img
def convert_image_dtype(image: torch.Tensor,
dtype: torch.dtype=torch.float) -> torch.Tensor:
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
This function does not support PIL Image.
Args:
image (torch.Tensor): Image to be converted
dtype (torch.dtype): Desired data type of the output
Returns:
Tensor: Converted image
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
if not isinstance(image, torch.Tensor):
raise TypeError('Input img should be Tensor Image')
return F_t.convert_image_dtype(image, dtype)
def to_pil_image(pic, mode=None):
"""Convert a tensor or an ndarray to PIL Image. This function does not support torchscript.
See :class:`~torchvision.transforms.ToPILImage` for more details.
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
Returns:
PIL Image: Image converted to PIL Image.
"""
if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(
type(pic)))
elif isinstance(pic, torch.Tensor):
if pic.ndimension() not in {2, 3}:
raise ValueError(
'pic should be 2/3 dimensional. Got {} dimensions.'.format(
pic.ndimension()))
elif pic.ndimension() == 2:
# if 2D image, add channel dimension (CHW)
pic = pic.unsqueeze(0)
# check number of channels
if pic.shape[-3] > 4:
raise ValueError(
'pic should not have > 4 channels. Got {} channels.'.format(
pic.shape[-3]))
elif isinstance(pic, np.ndarray):
if pic.ndim not in {2, 3}:
raise ValueError(
'pic should be 2/3 dimensional. Got {} dimensions.'.format(
pic.ndim))
elif pic.ndim == 2:
# if 2D image, add channel dimension (HWC)
pic = np.expand_dims(pic, 2)
# check number of channels
if pic.shape[-1] > 4:
raise ValueError(
'pic should not have > 4 channels. Got {} channels.'.format(
pic.shape[-1]))
npimg = pic
if isinstance(pic, torch.Tensor):
if pic.is_floating_point() and mode != 'F':
pic = pic.mul(255).byte()
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
if not isinstance(npimg, np.ndarray):
raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
'not {}'.format(type(npimg)))
if npimg.shape[2] == 1:
expected_mode = None
npimg = npimg[:, :, 0]
if npimg.dtype == np.uint8:
expected_mode = 'L'
elif npimg.dtype == np.int16:
expected_mode = 'I;16'
elif npimg.dtype == np.int32:
expected_mode = 'I'
elif npimg.dtype == np.float32:
expected_mode = 'F'
if mode is not None and mode != expected_mode:
raise ValueError(
"Incorrect mode ({}) supplied for input type {}. Should be {}"
.format(mode, np.dtype, expected_mode))
mode = expected_mode
elif npimg.shape[2] == 2:
permitted_2_channel_modes = ['LA']
if mode is not None and mode not in permitted_2_channel_modes:
raise ValueError("Only modes {} are supported for 2D inputs".
format(permitted_2_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'LA'
elif npimg.shape[2] == 4:
permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
if mode is not None and mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs".
format(permitted_4_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGBA'
else:
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
if mode is not None and mode not in permitted_3_channel_modes:
raise ValueError("Only modes {} are supported for 3D inputs".
format(permitted_3_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGB'
if mode is None:
raise TypeError('Input type {} is not supported'.format(npimg.dtype))
return Image.fromarray(npimg, mode=mode)
def normalize(tensor: Tensor,
mean: List[float],
std: List[float],
inplace: bool=False) -> Tensor:
"""Normalize a float tensor image with mean and standard deviation.
This transform does not support PIL Image.
.. note::
This transform acts out of place by default, i.e., it does not mutates the input tensor.
See :class:`~torchvision.transforms.Normalize` for more details.
Args:
tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace.
Returns:
Tensor: Normalized Tensor image.
"""
if not isinstance(tensor, torch.Tensor):
raise TypeError('Input tensor should be a torch tensor. Got {}.'.
format(type(tensor)))
if not tensor.is_floating_point():
raise TypeError('Input tensor should be a float tensor. Got {}.'.
format(tensor.dtype))
if tensor.ndim < 3:
raise ValueError(
'Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
'{}.'.format(tensor.size()))
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(
'std evaluated to zero after conversion to {}, leading to division by zero.'.
format(dtype))
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std)
return tensor
def resize(img: Tensor,
size: List[int],
interpolation: InterpolationMode=InterpolationMode.BILINEAR,
max_size: Optional[int]=None,
antialias: Optional[bool]=None) -> Tensor:
r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
.. warning::
The output image might be different depending on its type: when downsampling, the interpolation of PIL images
and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
closer.
Args:
img (PIL Image or Tensor): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaining
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
.. note::
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e the
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
``InterpolationMode.BILINEAR`` only mode. This can help making the output for PIL images and tensors
closer.
.. warning::
There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
Returns:
PIL Image or Tensor: Resized image.
"""
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if not isinstance(img, torch.Tensor):
if antialias is not None and not antialias:
warnings.warn(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(
img, size=size, interpolation=pil_interpolation, max_size=max_size)
return F_t.resize(
img,
size=size,
interpolation=interpolation.value,
max_size=max_size,
antialias=antialias)
def scale(*args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
"please use transforms.Resize instead.")
return resize(*args, **kwargs)
def pad(img: Tensor,
padding: List[int],
fill: int=0,
padding_mode: str="constant") -> Tensor:
r"""Pad the given image on all sides with the given "pad" value.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
at most 3 leading dimensions for mode edge,
and an arbitrary number of leading dimensions for mode constant
Args:
img (PIL Image or Tensor): Image to be padded.
padding (int or sequence): Padding on each border. If a single int is provided this
is used to pad all borders. If sequence of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a sequence of length 4 is provided
this is the padding for the left, top, right and bottom borders respectively.
.. note::
In torchscript mode padding as single int is not supported, use a sequence of
length 1: ``[padding, ]``.
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
If a tuple of length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only int or str or tuple value is supported for PIL Image.
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
Returns:
PIL Image or Tensor: Padded image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.pad(img,
padding=padding,
fill=fill,
padding_mode=padding_mode)
return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
"""Crop the given image at specified location and output size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then cropped.
Args:
img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
Returns:
PIL Image or Tensor: Cropped image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.crop(img, top, left, height, width)
return F_t.crop(img, top, left, height, width)
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
"""Crops the given image at the center.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
it is used for both directions.
Returns:
PIL Image or Tensor: Cropped image.
"""
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])
image_width, image_height = _get_image_size(img)
crop_height, crop_width = output_size
if crop_width > image_width or crop_height > image_height:
padding_ltrb = [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2
if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2
if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2
if crop_height > image_height else 0,
]
img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0
image_width, image_height = _get_image_size(img)
if crop_width == image_width and crop_height == image_height:
return img
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
return crop(img, crop_top, crop_left, crop_height, crop_width)
def resized_crop(
img: Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode=InterpolationMode.BILINEAR) -> Tensor:
"""Crop the given image and resize it to desired size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.
Args:
img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
top (int): Vertical component of the top left corner of the crop box.
left (int): Horizontal component of the top left corner of the crop box.
height (int): Height of the crop box.
width (int): Width of the crop box.
size (sequence or int): Desired output size. Same semantics as ``resize``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
Returns:
PIL Image or Tensor: Cropped image.
"""
img = crop(img, top, left, height, width)
img = resize(img, size, interpolation)
return img
def hflip(img: Tensor) -> Tensor:
"""Horizontally flip the given image.
Args:
img (PIL Image or Tensor): Image to be flipped. If img
is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of leading
dimensions.
Returns:
PIL Image or Tensor: Horizontally flipped image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.hflip(img)
return F_t.hflip(img)
def _get_perspective_coeffs(startpoints: List[List[int]],
endpoints: List[List[int]]) -> List[float]:
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
In Perspective Transform each pixel (x, y) in the original image gets transformed as,
(x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
Args:
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
Returns:
octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
"""
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
a_matrix[2 * i, :] = torch.tensor(
[p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
a_matrix[2 * i + 1, :] = torch.tensor(
[0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
res = torch.linalg.lstsq(a_matrix, b_matrix, driver='gels').solution
output: List[float] = res.tolist()
return output
def perspective(img: Tensor,
startpoints: List[List[int]],
endpoints: List[List[int]],
interpolation: InterpolationMode=InterpolationMode.BILINEAR,
fill: Optional[List[float]]=None) -> Tensor:
"""Perform perspective transform of the given image.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
img (PIL Image or Tensor): Image to be transformed.
startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
.. note::
In torchscript mode single int/float value is not supported, please use a sequence
of length 1: ``[value, ]``.
Returns:
PIL Image or Tensor: transformed Image.
"""
coeffs = _get_perspective_coeffs(startpoints, endpoints)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.perspective(
img, coeffs, interpolation=pil_interpolation, fill=fill)
return F_t.perspective(
img, coeffs, interpolation=interpolation.value, fill=fill)
def vflip(img: Tensor) -> Tensor:
"""Vertically flip the given image.
Args:
img (PIL Image or Tensor): Image to be flipped. If img
is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of leading
dimensions.
Returns:
PIL Image or Tensor: Vertically flipped image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.vflip(img)
return F_t.vflip(img)
def five_crop(
img: Tensor,
size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Crop the given image into four corners and the central crop.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
.. Note::
This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args:
img (PIL Image or Tensor): Image to be cropped.
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
Returns:
tuple: tuple (tl, tr, bl, br, center)
Corresponding top left, top right, bottom left, bottom right and center crop.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
elif isinstance(size, (tuple, list)) and len(size) == 1:
size = (size[0], size[0])
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
image_width, image_height = _get_image_size(img)
crop_height, crop_width = size
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))
tl = crop(img, 0, 0, crop_height, crop_width)
tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
br = crop(img, image_height - crop_height, image_width - crop_width,
crop_height, crop_width)
center = center_crop(img, [crop_height, crop_width])
return tl, tr, bl, br, center
def ten_crop(img: Tensor, size: List[int],
vertical_flip: bool=False) -> List[Tensor]:
"""Generate ten cropped images from the given image.
Crop the given image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default).
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
.. Note::
This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args:
img (PIL Image or Tensor): Image to be cropped.
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
vertical_flip (bool): Use vertical flipping instead of horizontal
Returns:
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
Corresponding top left, top right, bottom left, bottom right and
center crop and same for the flipped image.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
elif isinstance(size, (tuple, list)) and len(size) == 1:
size = (size[0], size[0])
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
first_five = five_crop(img, size)
if vertical_flip:
img = vflip(img)
else:
img = hflip(img)
second_five = five_crop(img, size)
return first_five + second_five
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
"""Adjust brightness of an image.
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
PIL Image or Tensor: Brightness adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_brightness(img, brightness_factor)
return F_t.adjust_brightness(img, brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
"""Adjust contrast of an image.
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
PIL Image or Tensor: Contrast adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_contrast(img, contrast_factor)
return F_t.adjust_contrast(img, contrast_factor)
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
"""Adjust color saturation of an image.
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
PIL Image or Tensor: Saturation adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_saturation(img, saturation_factor)
return F_t.adjust_saturation(img, saturation_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
See `Hue`_ for more details.
.. _Hue: https://en.wikipedia.org/wiki/Hue
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
PIL Image or Tensor: Hue adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_hue(img, hue_factor)
return F_t.adjust_hue(img, hue_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float=1) -> Tensor:
r"""Perform gamma correction on an image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:
.. math::
I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma}
See `Gamma Correction`_ for more details.
.. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction
Args:
img (PIL Image or Tensor): PIL Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, modes with transparency (alpha channel) are not supported.
gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
gamma larger than 1 make the shadows darker,
while gamma smaller than 1 make dark regions lighter.
gain (float): The constant multiplier.
Returns:
PIL Image or Tensor: Gamma correction adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_gamma(img, gamma, gain)
return F_t.adjust_gamma(img, gamma, gain)
def _get_inverse_affine_matrix(center: List[float],
angle: float,
translate: List[float],
scale: float,
shear: List[float]) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
# As it is explained in PIL.Image.rotate
# We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RSS is rotation with scale and shear matrix
# RSS(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
#
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
rot = math.radians(angle)
sx, sy = [math.radians(s) for s in shear]
cx, cy = center
tx, ty = translate
# RSS without scaling
a = math.cos(rot - sy) / math.cos(sy)
b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
c = math.sin(rot - sy) / math.cos(sy)
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
matrix = [d, -b, 0.0, -c, a, 0.0]
matrix = [x / scale for x in matrix]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += cx
matrix[5] += cy
return matrix
def rotate(img: Tensor,
angle: float,
interpolation: InterpolationMode=InterpolationMode.NEAREST,
expand: bool=False,
center: Optional[List[int]]=None,
fill: Optional[List[float]]=None,
resample: Optional[int]=None) -> Tensor:
"""Rotate the image by angle.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
img (PIL Image or Tensor): image to be rotated.
angle (number): rotation angle value in degrees, counter-clockwise.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
expand (bool, optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (sequence, optional): Optional center of rotation. Origin is the upper left corner.
Default is the center of the image.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
.. note::
In torchscript mode single int/float value is not supported, please use a sequence
of length 1: ``[value, ]``.
Returns:
PIL Image or Tensor: Rotated image.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.rotate(
img,
angle=angle,
interpolation=pil_interpolation,
expand=expand,
center=center,
fill=fill)
center_f = [0.0, 0.0]
if center is not None:
img_size = _get_image_size(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0,
[0.0, 0.0])
return F_t.rotate(
img,
matrix=matrix,
interpolation=interpolation.value,
expand=expand,
fill=fill)
def affine(img: Tensor,
angle: float,
translate: List[int],
scale: float,
shear: List[float],
interpolation: InterpolationMode=InterpolationMode.NEAREST,
fill: Optional[List[float]]=None,
resample: Optional[int]=None,
fillcolor: Optional[List[float]]=None) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
img (PIL Image or Tensor): image to transform.
angle (number): rotation angle in degrees between -180 and 180, clockwise direction.
translate (sequence of integers): horizontal and vertical translations (post-rotation translation)
scale (float): overall scale
shear (float or sequence): shear angle value in degrees between -180 to 180, clockwise direction.
If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
the second value corresponds to a shear parallel to the y axis.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
.. note::
In torchscript mode single int/float value is not supported, please use a sequence
of length 1: ``[value, ]``.
fillcolor (sequence, int, float): deprecated argument and will be removed since v0.10.0.
Please use the ``fill`` parameter instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead.
Returns:
PIL Image or Tensor: Transformed image.
"""
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
if fillcolor is not None:
warnings.warn(
"Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
)
fill = fillcolor
if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")
if not isinstance(translate, (list, tuple)):
raise TypeError("Argument translate should be a sequence")
if len(translate) != 2:
raise ValueError("Argument translate should be a sequence of length 2")
if scale <= 0.0:
raise ValueError("Argument scale should be positive")
if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError(
"Shear should be either a single value or a sequence of two values")
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(angle, int):
angle = float(angle)
if isinstance(translate, tuple):
translate = list(translate)
if isinstance(shear, numbers.Number):
shear = [shear, 0.0]
if isinstance(shear, tuple):
shear = list(shear)
if len(shear) == 1:
shear = [shear[0], shear[0]]
if len(shear) != 2:
raise ValueError(
"Shear should be a sequence containing two values. Got {}".format(
shear))
img_size = _get_image_size(img)
if not isinstance(img, torch.Tensor):
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
center = [img_size[0] * 0.5, img_size[1] * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale,
shear)
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.affine(
img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale,
shear)
return F_t.affine(
img, matrix=matrix, interpolation=interpolation.value, fill=fill)
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
This transform does not support torch Tensor.
Args:
img (PIL Image): PIL Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default is 1.
Returns:
PIL Image: Grayscale version of the image.
- if num_output_channels = 1 : returned image is single channel
- if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if isinstance(img, Image.Image):
return F_pil.to_grayscale(img, num_output_channels)
raise TypeError("Input should be PIL Image")
def rgb_to_grayscale(img: Tensor, num_output_channels: int=1) -> Tensor:
"""Convert RGB image to grayscale version of image.
If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
Note:
Please, note that this method supports only RGB images as input. For inputs in other color spaces,
please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
Args:
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
PIL Image or Tensor: Grayscale version of the image.
- if num_output_channels = 1 : returned image is single channel
- if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not isinstance(img, torch.Tensor):
return F_pil.to_grayscale(img, num_output_channels)
return F_t.rgb_to_grayscale(img, num_output_channels)
def erase(img: Tensor,
i: int,
j: int,
h: int,
w: int,
v: Tensor,
inplace: bool=False) -> Tensor:
""" Erase the input Tensor Image with given value.
This transform does not support PIL Image.
Args:
img (Tensor Image): Tensor image of size (C, H, W) to be erased
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
inplace(bool, optional): For in-place operations. By default is set False.
Returns:
Tensor Image: Erased image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
if not inplace:
img = img.clone()
img[..., i:i + h, j:j + w] = v
return img
def gaussian_blur(img: Tensor,
kernel_size: List[int],
sigma: Optional[List[float]]=None) -> Tensor:
"""Performs Gaussian blurring on the image by given kernel.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
img (PIL Image or Tensor): Image to be blurred
kernel_size (sequence of ints or int): Gaussian kernel size. Can be a sequence of integers
like ``(kx, ky)`` or a single integer for square kernels.
.. note::
In torchscript mode kernel_size as single int is not supported, use a sequence of
length 1: ``[ksize, ]``.
sigma (sequence of floats or float, optional): Gaussian kernel standard deviation. Can be a
sequence of floats like ``(sigma_x, sigma_y)`` or a single float to define the
same sigma in both X/Y directions. If None, then it is computed using
``kernel_size`` as ``sigma = 0.3 * ((kernel_size - 1) * 0.5 - 1) + 0.8``.
Default, None.
.. note::
In torchscript mode sigma as single float is
not supported, use a sequence of length 1: ``[sigma, ]``.
Returns:
PIL Image or Tensor: Gaussian Blurred version of the image.
"""
if not isinstance(kernel_size, (int, list, tuple)):
raise TypeError(
'kernel_size should be int or a sequence of integers. Got {}'.
format(type(kernel_size)))
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
if len(kernel_size) != 2:
raise ValueError(
'If kernel_size is a sequence its length should be 2. Got {}'.
format(len(kernel_size)))
for ksize in kernel_size:
if ksize % 2 == 0 or ksize < 0:
raise ValueError(
'kernel_size should have odd and positive integers. Got {}'.
format(kernel_size))
if sigma is None:
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
raise TypeError(
'sigma should be either float or sequence of floats. Got {}'.
format(type(sigma)))
if isinstance(sigma, (int, float)):
sigma = [float(sigma), float(sigma)]
if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
sigma = [sigma[0], sigma[0]]
if len(sigma) != 2:
raise ValueError(
'If sigma is a sequence, its length should be 2. Got {}'.format(
len(sigma)))
for s in sigma:
if s <= 0.:
raise ValueError(
'sigma should have positive values. Got {}'.format(sigma))
t_img = img
if not isinstance(img, torch.Tensor):
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image or Tensor. Got {}'.format(
type(img)))
t_img = to_tensor(img)
output = F_t.gaussian_blur(t_img, kernel_size, sigma)
if not isinstance(img, torch.Tensor):
output = to_pil_image(output)
return output
def invert(img: Tensor) -> Tensor:
"""Invert the colors of an RGB/grayscale image.
Args:
img (PIL Image or Tensor): Image to have its colors inverted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Returns:
PIL Image or Tensor: Color inverted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.invert(img)
return F_t.invert(img)
def posterize(img: Tensor, bits: int) -> Tensor:
"""Posterize an image by reducing the number of bits for each color channel.
Args:
img (PIL Image or Tensor): Image to have its colors posterized.
If img is torch Tensor, it should be of type torch.uint8 and
it is expected to be in [..., 1 or 3, H, W] format, where ... means
it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
bits (int): The number of bits to keep for each channel (0-8).
Returns:
PIL Image or Tensor: Posterized image.
"""
if not (0 <= bits <= 8):
raise ValueError(
'The number if bits should be between 0 and 8. Got {}'.format(
bits))
if not isinstance(img, torch.Tensor):
return F_pil.posterize(img, bits)
return F_t.posterize(img, bits)
def solarize(img: Tensor, threshold: float) -> Tensor:
"""Solarize an RGB/grayscale image by inverting all pixel values above a threshold.
Args:
img (PIL Image or Tensor): Image to have its colors inverted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
threshold (float): All pixels equal or above this value are inverted.
Returns:
PIL Image or Tensor: Solarized image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.solarize(img, threshold)
return F_t.solarize(img, threshold)
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
"""Adjust the sharpness of an image.
Args:
img (PIL Image or Tensor): Image to be adjusted.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
Returns:
PIL Image or Tensor: Sharpness adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_sharpness(img, sharpness_factor)
return F_t.adjust_sharpness(img, sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
"""Maximize contrast of an image by remapping its
pixels per channel so that the lowest becomes black and the lightest
becomes white.
Args:
img (PIL Image or Tensor): Image on which autocontrast is applied.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Returns:
PIL Image or Tensor: An image that was autocontrasted.
"""
if not isinstance(img, torch.Tensor):
return F_pil.autocontrast(img)
return F_t.autocontrast(img)
def equalize(img: Tensor) -> Tensor:
"""Equalize the histogram of an image by applying
a non-linear mapping to the input in order to create a uniform
distribution of grayscale values in the output.
Args:
img (PIL Image or Tensor): Image on which equalize is applied.
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
The tensor dtype must be ``torch.uint8`` and values are expected to be in ``[0, 255]``.
If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
Returns:
PIL Image or Tensor: An image that was equalized.
"""
if not isinstance(img, torch.Tensor):
return F_pil.equalize(img)
return F_t.equalize(img)
import numbers
from typing import Any, List, Sequence
import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance
try:
import accimage
except ImportError:
accimage = None
@torch.jit.unused
def _is_pil_image(img: Any) -> bool:
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
@torch.jit.unused
def _get_image_size(img: Any) -> List[int]:
if _is_pil_image(img):
return img.size
raise TypeError("Unexpected type {}".format(type(img)))
@torch.jit.unused
def _get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
raise TypeError("Unexpected type {}".format(type(img)))
@torch.jit.unused
def hflip(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_LEFT_RIGHT)
@torch.jit.unused
def vflip(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_TOP_BOTTOM)
@torch.jit.unused
def adjust_brightness(img, brightness_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
@torch.jit.unused
def adjust_contrast(img, contrast_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
@torch.jit.unused
def adjust_saturation(img, saturation_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
@torch.jit.unused
def adjust_hue(img, hue_factor):
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(
hue_factor))
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img
h, s, v = img.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
input_mode = img.mode
img = img.convert('RGB')
gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma)
for ele in range(256)] * 3
img = img.point(
gamma_map) # use PIL's point-function to accelerate this part
img = img.convert(input_mode)
return img
@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, list):
padding = tuple(padding)
if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]:
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
if isinstance(padding, tuple) and len(padding) == 1:
# Compatibility with `functional_tensor.pad`
padding = padding[0]
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
"Padding mode should be either constant, edge, reflect or symmetric"
)
if padding_mode == "constant":
opts = _parse_fill(fill, img, name="fill")
if img.mode == "P":
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, **opts)
image.putpalette(palette)
return image
return ImageOps.expand(img, border=padding, **opts)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, tuple) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, tuple) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_top, pad_right, pad_bottom]
cropping = -np.minimum(p, 0)
if cropping.any():
crop_left, crop_top, crop_right, crop_bottom = cropping
img = img.crop((crop_left, crop_top, img.width - crop_right,
img.height - crop_bottom))
pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0)
if img.mode == 'P':
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)),
padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img
img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right),
(0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)),
padding_mode)
return Image.fromarray(img)
@torch.jit.unused
def crop(img: Image.Image, top: int, left: int, height: int,
width: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.crop((left, top, left + width, top + height))
@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or
(isinstance(size, Sequence) and len(size) in (1, 2))):
raise TypeError('Got inappropriate size arg: {}'.format(size))
if isinstance(size, Sequence) and len(size) == 1:
size = size[0]
if isinstance(size, int):
w, h = img.size
short, long = (w, h) if w <= h else (h, w)
if short == size:
return img
new_short, new_long = size, int(size * long / short)
if max_size is not None:
if max_size <= size:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}")
if new_long > max_size:
new_short, new_long = int(max_size * new_short /
new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long,
new_short)
return img.resize((new_w, new_h), interpolation)
else:
if max_size is not None:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
return img.resize(size[::-1], interpolation)
@torch.jit.unused
def _parse_fill(fill, img, name="fillcolor"):
# Process fill color for affine transforms
num_bands = len(img.getbands())
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if isinstance(fill, (list, tuple)):
if len(fill) != num_bands:
msg = (
"The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))
fill = tuple(fill)
return {name: fill}
@torch.jit.unused
def affine(img, matrix, interpolation=0, fill=None):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
output_size = img.size
opts = _parse_fill(fill, img)
return img.transform(output_size, Image.AFFINE, matrix, interpolation,
**opts)
@torch.jit.unused
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
opts = _parse_fill(fill, img)
return img.rotate(angle, interpolation, expand, center, **opts)
@torch.jit.unused
def perspective(img,
perspective_coeffs,
interpolation=Image.BICUBIC,
fill=None):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
opts = _parse_fill(fill, img)
return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs,
interpolation, **opts)
@torch.jit.unused
def to_grayscale(img, num_output_channels):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')
return img
@torch.jit.unused
def invert(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.invert(img)
@torch.jit.unused
def posterize(img, bits):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img, threshold):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img, sharpness_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
return img
@torch.jit.unused
def autocontrast(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.equalize(img)
import warnings
import torch
from torch import Tensor
from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad
from torch.jit.annotations import BroadcastingList2
from typing import Optional, Tuple, List
def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2
def _assert_image_tensor(img):
if not _is_tensor_a_torch_image(img):
raise TypeError("Tensor is not a torch image.")
def _get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]
def _get_image_num_channels(img: Tensor) -> int:
if img.ndim == 2:
return 1
elif img.ndim > 2:
return img.shape[-3]
raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))
def _max_value(dtype: torch.dtype) -> float:
# TODO: replace this method with torch.iinfo when it gets torchscript support.
# https://github.com/pytorch/pytorch/issues/41492
a = torch.tensor(2, dtype=dtype)
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
bits = 1
max_value = torch.tensor(-signed, dtype=torch.long)
while True:
next_value = a.pow(bits - signed).sub(1)
if next_value > max_value:
max_value = next_value
bits *= 2
else:
break
return max_value.item()
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
c = _get_image_num_channels(img)
if c not in permitted:
raise TypeError(
"Input image tensor permitted channel values are {}, but found {}".
format(permitted, c))
def convert_image_dtype(image: torch.Tensor,
dtype: torch.dtype=torch.float) -> torch.Tensor:
if image.dtype == dtype:
return image
if image.is_floating_point():
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in
(torch.int32, torch.int64)) or (image.dtype == torch.float64 and
dtype == torch.int64):
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
raise RuntimeError(msg)
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# For data in the range 0-1, (float * 255).to(uint) is only 255
# when float is exactly 1.0.
# `max + 1 - epsilon` provides more evenly distributed mapping of
# ranges of floats to ints.
eps = 1e-3
max_val = _max_value(dtype)
result = image.mul(max_val + 1.0 - eps)
return result.to(dtype)
else:
input_max = _max_value(image.dtype)
# int to float
# TODO: replace with dtype.is_floating_point when torchscript supports it
if torch.tensor(0, dtype=dtype).is_floating_point():
image = image.to(dtype)
return image / input_max
output_max = _max_value(dtype)
# int to int
if input_max > output_max:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image // factor can produce different results
factor = int((input_max + 1) // (output_max + 1))
image = torch.div(image, factor, rounding_mode='floor')
return image.to(dtype)
else:
# factor should be forced to int for torch jit script
# otherwise factor is a float and image * factor can produce different results
factor = int((output_max + 1) // (input_max + 1))
image = image.to(dtype)
return image * factor
def vflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-2)
def hflip(img: Tensor) -> Tensor:
_assert_image_tensor(img)
return img.flip(-1)
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)
w, h = _get_image_size(img)
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
padding_ltrb = [
max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)
]
return pad(img[..., max(top, 0):bottom, max(left, 0):right],
padding_ltrb,
fill=0)
return img[..., top:bottom, left:right]
def rgb_to_grayscale(img: Tensor, num_output_channels: int=1) -> Tensor:
if img.ndim < 3:
raise TypeError(
"Input image tensor should have at least 3 dimensions, but found {}".
format(img.ndim))
_assert_channels(img, [3])
if num_output_channels not in (1, 3):
raise ValueError('num_output_channels should be either 1 or 3')
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
if num_output_channels == 3:
return l_img.expand(img.shape)
return l_img
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError('brightness_factor ({}) is not non-negative.'.format(
brightness_factor))
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
return _blend(img, torch.zeros_like(img), brightness_factor)
def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError('contrast_factor ({}) is not non-negative.'.format(
contrast_factor))
_assert_image_tensor(img)
_assert_channels(img, [3])
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
mean = torch.mean(
rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
return _blend(img, mean, contrast_factor)
def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(
hue_factor))
if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor image')
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if _get_image_num_channels(img) == 1: # Match PIL behaviour
return img
orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)
if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
return img_hue_adj
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError('saturation_factor ({}) is not non-negative.'.format(
saturation_factor))
_assert_image_tensor(img)
_assert_channels(img, [3])
return _blend(img, rgb_to_grayscale(img), saturation_factor)
def adjust_gamma(img: Tensor, gamma: float, gain: float=1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError('Input img should be a Tensor.')
_assert_channels(img, [1, 3])
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
result = img
dtype = img.dtype
if not torch.is_floating_point(img):
result = convert_image_dtype(result, torch.float32)
result = (gain * result**gamma).clamp(0, 1)
result = convert_image_dtype(result, dtype)
return result
def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"""DEPRECATED
"""
warnings.warn(
"This method is deprecated and will be removed in future releases. "
"Please, use ``F.center_crop`` instead.")
_assert_image_tensor(img)
_, image_width, image_height = img.size()
crop_height, crop_width = output_size
# crop_top = int(round((image_height - crop_height) / 2.))
# Result can be different between python func and scripted func
# Temporary workaround:
crop_top = int((image_height - crop_height + 1) * 0.5)
# crop_left = int(round((image_width - crop_width) / 2.))
# Result can be different between python func and scripted func
# Temporary workaround:
crop_left = int((image_width - crop_width + 1) * 0.5)
return crop(img, crop_top, crop_left, crop_height, crop_width)
def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
"""DEPRECATED
"""
warnings.warn(
"This method is deprecated and will be removed in future releases. "
"Please, use ``F.five_crop`` instead.")
_assert_image_tensor(img)
assert len(
size) == 2, "Please provide only two dimensions (h, w) for size."
_, image_width, image_height = img.size()
crop_height, crop_width = size
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))
tl = crop(img, 0, 0, crop_width, crop_height)
tr = crop(img, image_width - crop_width, 0, image_width, crop_height)
bl = crop(img, 0, image_height - crop_height, crop_width, image_height)
br = crop(img, image_width - crop_width, image_height - crop_height,
image_width, image_height)
center = center_crop(img, (crop_height, crop_width))
return [tl, tr, bl, br, center]
def ten_crop(img: Tensor,
size: BroadcastingList2[int],
vertical_flip: bool=False) -> List[Tensor]:
"""DEPRECATED
"""
warnings.warn(
"This method is deprecated and will be removed in future releases. "
"Please, use ``F.ten_crop`` instead.")
_assert_image_tensor(img)
assert len(
size) == 2, "Please provide only two dimensions (h, w) for size."
first_five = five_crop(img, size)
if vertical_flip:
img = vflip(img)
else:
img = hflip(img)
second_five = five_crop(img, size)
return first_five + second_five
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = 1.0 if img1.is_floating_point() else 255.0
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img):
r, g, b = img.unbind(dim=-3)
# Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
# src/libImaging/Convert.c#L330
maxc = torch.max(img, dim=-3).values
minc = torch.min(img, dim=-3).values
# The algorithm erases S and H channel where `maxc = minc`. This avoids NaN
# from happening in the results, because
# + S channel has division by `maxc`, which is zero only if `maxc = minc`
# + H channel has division by `(maxc - minc)`.
#
# Instead of overwriting NaN afterwards, we just prevent it from occuring so
# we don't need to deal with it in case we save the NaN in a buffer in
# backprop, if it is ever supported, but it doesn't hurt to do so.
eqc = maxc == minc
cr = maxc - minc
# Since `eqc => cr = 0`, replacing denominator with 1 when `eqc` is fine.
ones = torch.ones_like(maxc)
s = cr / torch.where(eqc, ones, maxc)
# Note that `eqc => maxc = minc = r = g = b`. So the following calculation
# of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6` so it
# would not matter what values `rc`, `gc`, and `bc` have here, and thus
# replacing denominator with 1 when `eqc` is fine.
cr_divisor = torch.where(eqc, ones, cr)
rc = (maxc - r) / cr_divisor
gc = (maxc - g) / cr_divisor
bc = (maxc - b) / cr_divisor
hr = (maxc == r) * (bc - gc)
hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
h = (hr + hg + hb)
h = torch.fmod((h / 6.0 + 1.0), 1.0)
return torch.stack((h, s, maxc), dim=-3)
def _hsv2rgb(img):
h, s, v = img.unbind(dim=-3)
i = torch.floor(h * 6.0)
f = (h * 6.0) - i
i = i.to(dtype=torch.int32)
p = torch.clamp((v * (1.0 - s)), 0.0, 1.0)
q = torch.clamp((v * (1.0 - s * f)), 0.0, 1.0)
t = torch.clamp((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
i = i % 6
mask = i.unsqueeze(dim=-3) == torch.arange(
6, device=i.device).view(-1, 1, 1)
a1 = torch.stack((v, q, p, p, t, v), dim=-3)
a2 = torch.stack((t, v, v, q, p, p), dim=-3)
a3 = torch.stack((p, p, t, v, v, q), dim=-3)
a4 = torch.stack((a1, a2, a3), dim=-4)
return torch.einsum(
"...ijk, ...xijk -> ...xjk", mask.to(dtype=img.dtype), a4)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
# crop if needed
if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
crop_left, crop_right, crop_top, crop_bottom = [
-min(x, 0) for x in padding
]
img = img[..., crop_top:img.shape[-2] - crop_bottom, crop_left:
img.shape[-1] - crop_right]
padding = [max(x, 0) for x in padding]
in_sizes = img.size()
x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)
] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(
left_indices + x_indices + right_indices, device=img.device)
y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(
top_indices + y_indices + bottom_indices, device=img.device)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError(
"Symmetric padding of N-D tensors are not supported yet")
def pad(img: Tensor,
padding: List[int],
fill: int=0,
padding_mode: str="constant") -> Tensor:
_assert_image_tensor(img)
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
if isinstance(padding, tuple):
padding = list(padding)
if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
"Padding mode should be either constant, edge, reflect or symmetric"
)
if isinstance(padding, int):
if torch.jit.is_scripting():
# This maybe unreachable
raise ValueError(
"padding can't be an int while torchscripting, set it as a list [value, ]"
)
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
p = [pad_left, pad_right, pad_top, pad_bottom]
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32,
torch.float64):
# Here we temporary cast input tensor to float
# until pytorch issue is resolved :
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
img = torch_pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img
def resize(img: Tensor,
size: List[int],
interpolation: str="bilinear",
max_size: Optional[int]=None,
antialias: Optional[bool]=None) -> Tensor:
_assert_image_tensor(img)
if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
if not isinstance(interpolation, str):
raise TypeError("Got inappropriate interpolation arg")
if interpolation not in ["nearest", "bilinear", "bicubic"]:
raise ValueError(
"This interpolation mode is unsupported with Tensor input")
if isinstance(size, tuple):
size = list(size)
if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError(
"Size must be an int or a 1 or 2 element tuple/list, not a "
"{} element tuple/list".format(len(size)))
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError(
"Antialias option is supported for bilinear and bicubic interpolation modes only"
)
w, h = _get_image_size(img)
if isinstance(size, int) or len(
size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]
if short == requested_new_short:
return img
new_short, new_long = requested_new_short, int(requested_new_short *
long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}")
if new_long > max_size:
new_short, new_long = int(max_size * new_short /
new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long,
new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
if antialias:
if interpolation == "bilinear":
img = torch.ops.torchvision._interpolate_bilinear2d_aa(
img, [new_h, new_w], align_corners=False)
elif interpolation == "bicubic":
img = torch.ops.torchvision._interpolate_bicubic2d_aa(
img, [new_h, new_w], align_corners=False)
else:
img = interpolate(
img,
size=[new_h, new_w],
mode=interpolation,
align_corners=align_corners)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
img = _cast_squeeze_out(
img,
need_cast=need_cast,
need_squeeze=need_squeeze,
out_dtype=out_dtype)
return img
def _assert_grid_transform_inputs(
img: Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: Optional[List[float]],
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]]=None, ):
if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor")
_assert_image_tensor(img)
if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
if matrix is not None and len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn(
"Argument fill should be either int, float, tuple or list")
# Check fill
num_channels = _get_image_num_channels(img)
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and
len(fill) != num_channels):
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes:
raise ValueError(
"Interpolation mode '{}' is unsupported with Tensor input".format(
interpolation))
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[
Tensor, bool, bool, torch.dtype]:
need_squeeze = False
# make image NCHW
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
img = img.to(req_dtype)
return img, need_cast, need_squeeze, out_dtype
def _cast_squeeze_out(img: Tensor,
need_cast: bool,
need_squeeze: bool,
out_dtype: torch.dtype):
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32,
torch.int64):
# it is better to round before cast
img = torch.round(img)
img = img.to(out_dtype)
return img
def _apply_grid_transform(img: Tensor,
grid: Tensor,
mode: str,
fill: Optional[List[float]]) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img,
[grid.dtype, ])
if img.shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2],
grid.shape[3])
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
dummy = torch.ones(
(img.shape[0], 1, img.shape[2], img.shape[3]),
dtype=img.dtype,
device=img.device)
img = torch.cat((img, dummy), dim=1)
img = grid_sample(
img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
fill_img = torch.tensor(
fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1,
1).expand_as(img)
if mode == 'nearest':
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
img = img * mask + (1.0 - mask) * fill_img
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def _gen_affine_grid(
theta: Tensor,
w: int,
h: int,
ow: int,
oh: int, ) -> Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
# Difference with AffineGridGenerator is that:
# 1) we normalize grid values after applying theta
# 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
d = 0.5
base_grid = torch.empty(
1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
x_grid = torch.linspace(
-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(
-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh,
device=theta.device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2) / torch.tensor(
[0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
def affine(img: Tensor,
matrix: List[float],
interpolation: str="nearest",
fill: Optional[List[float]]=None) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill,
["nearest", "bilinear"])
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(
matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
shape = img.shape
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(
theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _compute_output_size(matrix: List[float], w: int,
h: int) -> Tuple[int, int]:
# Inspired of PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
pts = torch.tensor([
[-0.5 * w, -0.5 * h, 1.0],
[-0.5 * w, 0.5 * h, 1.0],
[0.5 * w, 0.5 * h, 1.0],
[0.5 * w, -0.5 * h, 1.0],
])
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2)
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
# Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
tol = 1e-4
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
cmin = torch.floor((min_vals / tol).trunc_() * tol)
size = cmax - cmin
return int(size[0]), int(size[1])
def rotate(img: Tensor,
matrix: List[float],
interpolation: str="nearest",
expand: bool=False,
fill: Optional[List[float]]=None) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill,
["nearest", "bilinear"])
w, h = img.shape[-1], img.shape[-2]
ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
theta = torch.tensor(
matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float],
ow: int,
oh: int,
dtype: torch.dtype,
device: torch.device):
# https://github.com/python-pillow/Pillow/blob/4634eafe3c695a014267eefdce830b4a825beed7/
# src/libImaging/Geometry.c#L394
#
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]],
[coeffs[3], coeffs[4], coeffs[5]]]],
dtype=dtype,
device=device)
theta2 = torch.tensor(
[[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]],
dtype=dtype,
device=device)
d = 0.5
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace(
d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor(
[0.5 * ow, 0.5 * oh], dtype=dtype, device=device)
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
output_grid = output_grid1 / output_grid2 - 1.0
return output_grid.view(1, oh, ow, 2)
def perspective(img: Tensor,
perspective_coeffs: List[float],
interpolation: str="bilinear",
fill: Optional[List[float]]=None) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor.')
_assert_image_tensor(img)
_assert_grid_transform_inputs(
img,
matrix=None,
interpolation=interpolation,
fill=fill,
supported_interpolation_modes=["nearest", "bilinear"],
coeffs=perspective_coeffs)
ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(
perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()
return kernel1d
def _get_gaussian_kernel2d(kernel_size: List[int],
sigma: List[float],
dtype: torch.dtype,
device: torch.device) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(
device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(
device, dtype=dtype)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d
def gaussian_blur(img: Tensor, kernel_size: List[int],
sigma: List[float]) -> Tensor:
if not (isinstance(img, torch.Tensor)):
raise TypeError('img should be Tensor. Got {}'.format(type(img)))
_assert_image_tensor(img)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(
kernel_size, sigma, dtype=dtype, device=img.device)
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img, [kernel.dtype, ])
# padding = (left, right, top, bottom)
padding = [
kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2,
kernel_size[1] // 2
]
img = torch_pad(img, padding, mode="reflect")
img = conv2d(img, kernel, groups=img.shape[-3])
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img
def invert(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(
"Input image tensor should have at least 3 dimensions, but found {}".
format(img.ndim))
_assert_channels(img, [1, 3])
bound = torch.tensor(
1 if img.is_floating_point() else 255,
dtype=img.dtype,
device=img.device)
return bound - img
def posterize(img: Tensor, bits: int) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(
"Input image tensor should have at least 3 dimensions, but found {}".
format(img.ndim))
if img.dtype != torch.uint8:
raise TypeError(
"Only torch.uint8 image tensors are supported, but found {}".
format(img.dtype))
_assert_channels(img, [1, 3])
mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(
"Input image tensor should have at least 3 dimensions, but found {}".
format(img.ndim))
_assert_channels(img, [1, 3])
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
def _blurred_degenerate_image(img: Tensor) -> Tensor:
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel[1, 1] = 5.0
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
img, [kernel.dtype, ])
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze,
out_dtype)
result = img.clone()
result[..., 1:-1, 1:-1] = result_tmp
return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError('sharpness_factor ({}) is not non-negative.'.format(
sharpness_factor))
_assert_image_tensor(img)
_assert_channels(img, [1, 3])
if img.size(-1) <= 2 or img.size(-2) <= 2:
return img
return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if img.ndim < 3:
raise TypeError(
"Input image tensor should have at least 3 dimensions, but found {}".
format(img.ndim))
_assert_channels(img, [1, 3])
bound = 1.0 if img.is_floating_point() else 255.0
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
eq_idxs = torch.where(minimum == maximum)[0]
minimum[eq_idxs] = 0
maximum[eq_idxs] = bound
scale = bound / (maximum - minimum)
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan):
# TODO: we should expect bincount to always be faster than histc, but this
# isn't always the case. Once
# https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
# block and only use bincount.
if img_chan.is_cuda:
hist = torch.histc(
img_chan.to(torch.float32), bins=256, min=0, max=255)
else:
hist = torch.bincount(img_chan.view(-1), minlength=256)
nonzero_hist = hist[hist != 0]
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
if step == 0:
return img_chan
lut = torch.div(
torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'),
step,
rounding_mode='floor')
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
def equalize(img: Tensor) -> Tensor:
_assert_image_tensor(img)
if not (3 <= img.ndim <= 4):
raise TypeError(
"Input image tensor should have 3 or 4 dimensions, but found {}".
format(img.ndim))
if img.dtype != torch.uint8:
raise TypeError(
"Only torch.uint8 image tensors are supported, but found {}".
format(img.dtype))
_assert_channels(img, [1, 3])
if img.ndim == 3:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List, Optional
import torch
from torch import Tensor
try:
import accimage
except ImportError:
accimage = None
from . import functional as F
from .functional import InterpolationMode, _interpolation_modes_from_int
__all__ = [
"Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage",
"Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda",
"RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale",
"RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur",
"InterpolationMode", "RandomInvert", "RandomPosterize", "RandomSolarize",
"RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"
]
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/master/references/segmentation
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
return F.to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
class PILToTensor:
"""Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
return F.pil_to_tensor(pic)
def __repr__(self):
return self.__class__.__name__ + '()'
class ConvertImageDtype(torch.nn.Module):
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
This function does not support PIL Image.
Args:
dtype (torch.dtype): Desired data type of the output
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
def __init__(self, dtype: torch.dtype) -> None:
super().__init__()
self.dtype = dtype
def forward(self, image):
return F.convert_image_dtype(image, self.dtype)
class ToPILImage:
"""Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL Image while preserving the value range.
Args:
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
- If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
- If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
- If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
- If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
``short``).
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
"""
def __init__(self, mode=None):
self.mode = mode
def __call__(self, pic):
"""
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
Returns:
PIL Image: Image converted to PIL Image.
"""
return F.to_pil_image(pic, self.mode)
def __repr__(self):
format_string = self.__class__.__name__ + '('
if self.mode is not None:
format_string += 'mode={0}'.format(self.mode)
format_string += ')'
return format_string
class Normalize(torch.nn.Module):
"""Normalize a tensor image with mean and standard deviation.
This transform does not support PIL Image.
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
channels, this transform will normalize each channel of the input
``torch.*Tensor`` i.e.,
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
.. note::
This transform acts out of place, i.e., it does not mutate the input tensor.
Args:
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation in-place.
"""
def __init__(self, mean, std, inplace=False):
super().__init__()
self.mean = mean
self.std = std
self.inplace = inplace
def forward(self, tensor: Tensor) -> Tensor:
"""
Args:
tensor (Tensor): Tensor image to be normalized.
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(tensor, self.mean, self.std, self.inplace)
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(
self.mean, self.std)
class Resize(torch.nn.Module):
"""Resize the input image to the given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
.. warning::
The output image might be different depending on its type: when downsampling, the interpolation of PIL images
and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
closer.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size).
.. note::
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
``max_size``. As a result, ``size`` might be overruled, i.e the
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
``InterpolationMode.BILINEAR`` only mode. This can help making the output for PIL images and tensors
closer.
.. warning::
There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
"""
def __init__(self,
size,
interpolation=InterpolationMode.BILINEAR,
max_size=None,
antialias=None):
super().__init__()
if not isinstance(size, (int, Sequence)):
raise TypeError("Size should be int or sequence. Got {}".format(
type(size)))
if isinstance(size, Sequence) and len(size) not in (1, 2):
raise ValueError(
"If size is a sequence, it should have 1 or 2 values")
self.size = size
self.max_size = max_size
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation
self.antialias = antialias
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be scaled.
Returns:
PIL Image or Tensor: Rescaled image.
"""
return F.resize(img, self.size, self.interpolation, self.max_size,
self.antialias)
def __repr__(self):
interpolate_str = self.interpolation.value
return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format(
self.size, interpolate_str, self.max_size, self.antialias)
class Scale(Resize):
"""
Note: This transform is deprecated in favor of Resize.
"""
def __init__(self, *args, **kwargs):
warnings.warn(
"The use of the transforms.Scale transform is deprecated, " +
"please use transforms.Resize instead.")
super(Scale, self).__init__(*args, **kwargs)
class CenterCrop(torch.nn.Module):
"""Crops the given image at the center.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
"""
def __init__(self, size):
super().__init__()
self.size = _setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size.")
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
PIL Image or Tensor: Cropped image.
"""
return F.center_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class Pad(torch.nn.Module):
"""Pad the given image on all sides with the given "pad" value.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
at most 3 leading dimensions for mode edge,
and an arbitrary number of leading dimensions for mode constant
Args:
padding (int or sequence): Padding on each border. If a single int is provided this
is used to pad all borders. If sequence of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a sequence of length 4 is provided
this is the padding for the left, top, right and bottom borders respectively.
.. note::
In torchscript mode padding as single int is not supported, use a sequence of
length 1: ``[padding, ]``.
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only int or str or tuple value is supported for PIL Image.
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
"""
def __init__(self, padding, fill=0, padding_mode="constant"):
super().__init__()
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError("Got inappropriate fill arg")
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
"Padding mode should be either constant, edge, reflect or symmetric"
)
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
raise ValueError(
"Padding must be an int or a 1, 2, or 4 element tuple, not a "
+ "{} element tuple".format(len(padding)))
self.padding = padding
self.fill = fill
self.padding_mode = padding_mode
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be padded.
Returns:
PIL Image or Tensor: Padded image.
"""
return F.pad(img, self.padding, self.fill, self.padding_mode)
def __repr__(self):
return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
format(self.padding, self.fill, self.padding_mode)
class Lambda:
"""Apply a user-defined lambda as a transform. This transform does not support torchscript.
Args:
lambd (function): Lambda/function to be used for transform.
"""
def __init__(self, lambd):
if not callable(lambd):
raise TypeError("Argument lambd should be callable, got {}".format(
repr(type(lambd).__name__)))
self.lambd = lambd
def __call__(self, img):
return self.lambd(img)
def __repr__(self):
return self.__class__.__name__ + '()'
class RandomTransforms:
"""Base class for a list of transformations with randomness
Args:
transforms (sequence): list of transformations
"""
def __init__(self, transforms):
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence")
self.transforms = transforms
def __call__(self, *args, **kwargs):
raise NotImplementedError()
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class RandomApply(torch.nn.Module):
"""Apply randomly a list of transformations with a given probability.
.. note::
In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
transforms as shown below:
>>> transforms = transforms.RandomApply(torch.nn.ModuleList([
>>> transforms.ColorJitter(),
>>> ]), p=0.3)
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (float): probability
"""
def __init__(self, transforms, p=0.5):
super().__init__()
self.transforms = transforms
self.p = p
def forward(self, img):
if self.p < torch.rand(1):
return img
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += '\n p={}'.format(self.p)
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class RandomOrder(RandomTransforms):
"""Apply a list of transformations in a random order. This transform does not support torchscript.
"""
def __call__(self, img):
order = list(range(len(self.transforms)))
random.shuffle(order)
for i in order:
img = self.transforms[i](img)
return img
class RandomChoice(RandomTransforms):
"""Apply single transformation randomly picked from a list. This transform does not support torchscript.
"""
def __call__(self, img):
t = random.choice(self.transforms)
return t(img)
class RandomCrop(torch.nn.Module):
"""Crop the given image at a random location.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
but if non-constant padding is used, the input is expected to have at most 2 leading dimensions
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
padding (int or sequence, optional): Optional padding on each border
of the image. Default is None. If a single int is provided this
is used to pad all borders. If sequence of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a sequence of length 4 is provided
this is the padding for the left, top, right and bottom borders respectively.
.. note::
In torchscript mode padding as single int is not supported, use a sequence of
length 1: ``[padding, ]``.
pad_if_needed (boolean): It will pad the image if smaller than the
desired size to avoid raising an exception. Since cropping is done
after padding, the padding seems to be done at a random offset.
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Only number is supported for torch Tensor.
Only int or str or tuple value is supported for PIL Image.
padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
Default is constant.
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value at the edge of the image.
If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
- reflect: pads with reflection of image without repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image repeating the last value on the edge.
For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
"""
@staticmethod
def get_params(img: Tensor,
output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
w, h = F._get_image_size(img)
th, tw = output_size
if h + 1 < th or w + 1 < tw:
raise ValueError(
"Required crop size {} is larger then input image size {}".
format((th, tw), (h, w)))
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1, )).item()
j = torch.randint(0, w - tw + 1, size=(1, )).item()
return i, j, th, tw
def __init__(self,
size,
padding=None,
pad_if_needed=False,
fill=0,
padding_mode="constant"):
super().__init__()
self.size = tuple(
_setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size."
))
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill = fill
self.padding_mode = padding_mode
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
PIL Image or Tensor: Cropped image.
"""
if self.padding is not None:
img = F.pad(img, self.padding, self.fill, self.padding_mode)
width, height = F._get_image_size(img)
# pad the width if needed
if self.pad_if_needed and width < self.size[1]:
padding = [self.size[1] - width, 0]
img = F.pad(img, padding, self.fill, self.padding_mode)
# pad the height if needed
if self.pad_if_needed and height < self.size[0]:
padding = [0, self.size[0] - height]
img = F.pad(img, padding, self.fill, self.padding_mode)
i, j, h, w = self.get_params(img, self.size)
return F.crop(img, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + "(size={0}, padding={1})".format(
self.size, self.padding)
class RandomHorizontalFlip(torch.nn.Module):
"""Horizontally flip the given image randomly with a given probability.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be flipped.
Returns:
PIL Image or Tensor: Randomly flipped image.
"""
if torch.rand(1) < self.p:
return F.hflip(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomVerticalFlip(torch.nn.Module):
"""Vertically flip the given image randomly with a given probability.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
Args:
p (float): probability of the image being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be flipped.
Returns:
PIL Image or Tensor: Randomly flipped image.
"""
if torch.rand(1) < self.p:
return F.vflip(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomPerspective(torch.nn.Module):
"""Performs a random perspective transformation of the given image with a given probability.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
Default is 0.5.
p (float): probability of the image being transformed. Default is 0.5.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively.
"""
def __init__(self,
distortion_scale=0.5,
p=0.5,
interpolation=InterpolationMode.BILINEAR,
fill=0):
super().__init__()
self.p = p
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation
self.distortion_scale = distortion_scale
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fill = fill
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be Perspectively transformed.
Returns:
PIL Image or Tensor: Randomly transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
if torch.rand(1) < self.p:
width, height = F._get_image_size(img)
startpoints, endpoints = self.get_params(width, height,
self.distortion_scale)
return F.perspective(img, startpoints, endpoints,
self.interpolation, fill)
return img
@staticmethod
def get_params(width: int, height: int, distortion_scale: float) -> Tuple[
List[List[int]], List[List[int]]]:
"""Get parameters for ``perspective`` for a random perspective transform.
Args:
width (int): width of the image.
height (int): height of the image.
distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
Returns:
List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
"""
half_height = height // 2
half_width = width // 2
topleft = [
int(
torch.randint(
0, int(distortion_scale * half_width) + 1, size=(1, ))
.item()), int(
torch.randint(
0, int(distortion_scale * half_height) + 1, size=(1, ))
.item())
]
topright = [
int(
torch.randint(
width - int(distortion_scale * half_width) - 1,
width,
size=(1, )).item()),
int(
torch.randint(
0, int(distortion_scale * half_height) + 1, size=(1, ))
.item())
]
botright = [
int(
torch.randint(
width - int(distortion_scale * half_width) - 1,
width,
size=(1, )).item()), int(
torch.randint(
height - int(distortion_scale * half_height) - 1,
height,
size=(1, )).item())
]
botleft = [
int(
torch.randint(
0, int(distortion_scale * half_width) + 1, size=(1, ))
.item()), int(
torch.randint(
height - int(distortion_scale * half_height) - 1,
height,
size=(1, )).item())
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1],
[0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
return startpoints, endpoints
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomResizedCrop(torch.nn.Module):
"""Crop a random portion of image and resize it to a given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
A crop of the original image is made: the crop has a random area (H * W)
and a random aspect ratio. This crop is finally resized to the given
size. This is popularly used to train the Inception networks.
Args:
size (int or sequence): expected output size of the crop, for each edge. If size is an
int instead of sequence like (h, w), a square output size ``(size, size)`` is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
.. note::
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
before resizing. The scale is defined with respect to the area of the original image.
ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
resizing.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
"""
def __init__(self,
size,
scale=(0.08, 1.0),
ratio=(3. / 4., 4. / 3.),
interpolation=InterpolationMode.BILINEAR):
super().__init__()
self.size = _setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size.")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img: Tensor, scale: List[float],
ratio: List[float]) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image or Tensor): Input image.
scale (list): range of scale of the origin size cropped
ratio (list): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
width, height = F._get_image_size(img)
area = height * width
log_ratio = torch.log(torch.tensor(ratio))
for _ in range(10):
target_area = area * torch.empty(1).uniform_(scale[0],
scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1, )).item()
j = torch.randint(0, width - w + 1, size=(1, )).item()
return i, j, h, w
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(ratio):
w = width
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = height
w = int(round(h * max(ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return i, j, h, w
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped and resized.
Returns:
PIL Image or Tensor: Randomly cropped and resized image.
"""
i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self):
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(
tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(
tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string
class RandomSizedCrop(RandomResizedCrop):
"""
Note: This transform is deprecated in favor of RandomResizedCrop.
"""
def __init__(self, *args, **kwargs):
warnings.warn(
"The use of the transforms.RandomSizedCrop transform is deprecated, "
+ "please use transforms.RandomResizedCrop instead.")
super(RandomSizedCrop, self).__init__(*args, **kwargs)
class FiveCrop(torch.nn.Module):
"""Crop the given image into four corners and the central crop.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
.. Note::
This transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your Dataset returns. See below for an example of how to deal with
this.
Args:
size (sequence or int): Desired output size of the crop. If size is an ``int``
instead of sequence like (h, w), a square crop of size (size, size) is made.
If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
Example:
>>> transform = Compose([
>>> FiveCrop(size), # this is a list of PIL Images
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ])
>>> #In your test loop you can do the following:
>>> input, target = batch # input is a 5d tensor, target is 2d
>>> bs, ncrops, c, h, w = input.size()
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
"""
def __init__(self, size):
super().__init__()
self.size = _setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size.")
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
tuple of 5 images. Image can be PIL Image or Tensor
"""
return F.five_crop(img, self.size)
def __repr__(self):
return self.__class__.__name__ + '(size={0})'.format(self.size)
class TenCrop(torch.nn.Module):
"""Crop the given image into four corners and the central crop plus the flipped version of
these (horizontal flipping is used by default).
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
.. Note::
This transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your Dataset returns. See below for an example of how to deal with
this.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
vertical_flip (bool): Use vertical flipping instead of horizontal
Example:
>>> transform = Compose([
>>> TenCrop(size), # this is a list of PIL Images
>>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ])
>>> #In your test loop you can do the following:
>>> input, target = batch # input is a 5d tensor, target is 2d
>>> bs, ncrops, c, h, w = input.size()
>>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
>>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
"""
def __init__(self, size, vertical_flip=False):
super().__init__()
self.size = _setup_size(
size,
error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
tuple of 10 images. Image can be PIL Image or Tensor
"""
return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self):
return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(
self.size, self.vertical_flip)
class LinearTransformation(torch.nn.Module):
"""Transform a tensor image with a square transformation matrix and a mean_vector computed
offline.
This transform does not support PIL Image.
Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
subtract mean_vector from it which is then followed by computing the dot
product with the transformation matrix and then reshaping the tensor to its
original shape.
Applications:
whitening transformation: Suppose X is a column vector zero-centered data.
Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
perform SVD on this matrix and pass it as transformation_matrix.
Args:
transformation_matrix (Tensor): tensor [D x D], D = C x H x W
mean_vector (Tensor): tensor [D], D = C x H x W
"""
def __init__(self, transformation_matrix, mean_vector):
super().__init__()
if transformation_matrix.size(0) != transformation_matrix.size(1):
raise ValueError("transformation_matrix should be square. Got " +
"[{} x {}] rectangular matrix.".format(
*transformation_matrix.size()))
if mean_vector.size(0) != transformation_matrix.size(0):
raise ValueError(
"mean_vector should have the same length {}".format(
mean_vector.size(0)) +
" as any one of the dimensions of the transformation_matrix [{}]"
.format(tuple(transformation_matrix.size())))
if transformation_matrix.device != mean_vector.device:
raise ValueError(
"Input tensors should be on the same device. Got {} and {}"
.format(transformation_matrix.device, mean_vector.device))
self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector
def forward(self, tensor: Tensor) -> Tensor:
"""
Args:
tensor (Tensor): Tensor image to be whitened.
Returns:
Tensor: Transformed image.
"""
shape = tensor.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
raise ValueError(
"Input tensor and transformation matrix have incompatible shape."
+ "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[
-1]) + "{}".format(self.transformation_matrix.shape[0]))
if tensor.device.type != self.mean_vector.device.type:
raise ValueError(
"Input tensor should be on the same device as transformation matrix and mean vector. "
"Got {} vs {}".format(tensor.device, self.mean_vector.device))
flat_tensor = tensor.view(-1, n) - self.mean_vector
transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
tensor = transformed_tensor.view(shape)
return tensor
def __repr__(self):
format_string = self.__class__.__name__ + '(transformation_matrix='
format_string += (str(self.transformation_matrix.tolist()) + ')')
format_string += (
", (mean_vector=" + str(self.mean_vector.tolist()) + ')')
return format_string
class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast, saturation and hue of an image.
If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness.
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
or the given [min, max]. Should be non negative numbers.
contrast (float or tuple of float (min, max)): How much to jitter contrast.
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
or the given [min, max]. Should be non negative numbers.
saturation (float or tuple of float (min, max)): How much to jitter saturation.
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
or the given [min, max]. Should be non negative numbers.
hue (float or tuple of float (min, max)): How much to jitter hue.
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(
hue, 'hue', center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
@torch.jit.unused
def _check_input(self,
value,
name,
center=1,
bound=(0, float('inf')),
clip_first_on_zero=True):
if isinstance(value, numbers.Number):
if value < 0:
raise ValueError(
"If {} is a single number, it must be non negative.".
format(name))
value = [center - float(value), center + float(value)]
if clip_first_on_zero:
value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError("{} values should be between {}".format(
name, bound))
else:
raise TypeError(
"{} should be a single number or a list/tuple with length 2.".
format(name))
# if value is 0 or (1., 1.) for brightness/contrast/saturation
# or (0., 0.) for hue, do nothing
if value[0] == value[1] == center:
value = None
return value
@staticmethod
def get_params(
brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]]) -> Tuple[Tensor, Optional[
float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.
Args:
brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
uniformly. Pass None to turn off the transformation.
contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
uniformly. Pass None to turn off the transformation.
saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.
Returns:
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
fn_idx = torch.randperm(4)
b = None if brightness is None else float(
torch.empty(1).uniform_(brightness[0], brightness[1]))
c = None if contrast is None else float(
torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(
torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(
torch.empty(1).uniform_(hue[0], hue[1]))
return fn_idx, b, c, s, h
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Input image.
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
for fn_id in fn_idx:
if fn_id == 0 and brightness_factor is not None:
img = F.adjust_brightness(img, brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
img = F.adjust_contrast(img, contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
img = F.adjust_saturation(img, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
img = F.adjust_hue(img, hue_factor)
return img
def __repr__(self):
format_string = self.__class__.__name__ + '('
format_string += 'brightness={0}'.format(self.brightness)
format_string += ', contrast={0}'.format(self.contrast)
format_string += ', saturation={0}'.format(self.saturation)
format_string += ', hue={0})'.format(self.hue)
return format_string
class RandomRotation(torch.nn.Module):
"""Rotate the image by angle.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
degrees (sequence or number): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees).
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
expand (bool, optional): Optional expansion flag.
If true, expands the output to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image.
fill (sequence or number): Pixel fill value for the area outside the rotated
image. Default is ``0``. If given a number, the value is used for all bands respectively.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
def __init__(self,
degrees,
interpolation=InterpolationMode.NEAREST,
expand=False,
center=None,
fill=0,
resample=None):
super().__init__()
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2, ))
self.center = center
self.resample = self.interpolation = interpolation
self.expand = expand
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fill = fill
@staticmethod
def get_params(degrees: List[float]) -> float:
"""Get parameters for ``rotate`` for a random rotation.
Returns:
float: angle parameter to be passed to ``rotate`` for random rotation.
"""
angle = float(
torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item(
))
return angle
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be rotated.
Returns:
PIL Image or Tensor: Rotated image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
angle = self.get_params(self.degrees)
return F.rotate(img, angle, self.resample, self.expand, self.center,
fill)
def __repr__(self):
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(degrees={0}'.format(
self.degrees)
format_string += ', interpolation={0}'.format(interpolate_str)
format_string += ', expand={0}'.format(self.expand)
if self.center is not None:
format_string += ', center={0}'.format(self.center)
if self.fill is not None:
format_string += ', fill={0}'.format(self.fill)
format_string += ')'
return format_string
class RandomAffine(torch.nn.Module):
"""Random affine transformation of the image keeping center invariant.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
degrees (sequence or number): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees). Set to 0 to deactivate rotations.
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
and vertical translations. For example translate=(a, b), then horizontal shift
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
shear (sequence or number, optional): Range of degrees to select from.
If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the
range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
Will not apply shear by default.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (sequence or number): Pixel fill value for the area outside the transformed
image. Default is ``0``. If given a number, the value is used for all bands respectively.
fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``fill`` parameter instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
def __init__(self,
degrees,
translate=None,
scale=None,
shear=None,
interpolation=InterpolationMode.NEAREST,
fill=0,
fillcolor=None,
resample=None):
super().__init__()
if resample is not None:
warnings.warn(
"Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
)
interpolation = _interpolation_modes_from_int(resample)
# Backward compatibility with integer value
if isinstance(interpolation, int):
warnings.warn(
"Argument interpolation should be of type InterpolationMode instead of int. "
"Please, use InterpolationMode enum.")
interpolation = _interpolation_modes_from_int(interpolation)
if fillcolor is not None:
warnings.warn(
"Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
)
fill = fillcolor
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))
if translate is not None:
_check_sequence_input(translate, "translate", req_sizes=(2, ))
for t in translate:
if not (0.0 <= t <= 1.0):
raise ValueError(
"translation values should be between 0 and 1")
self.translate = translate
if scale is not None:
_check_sequence_input(scale, "scale", req_sizes=(2, ))
for s in scale:
if s <= 0:
raise ValueError("scale values should be positive")
self.scale = scale
if shear is not None:
self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
else:
self.shear = shear
self.resample = self.interpolation = interpolation
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fillcolor = self.fill = fill
@staticmethod
def get_params(degrees: List[float],
translate: Optional[List[float]],
scale_ranges: Optional[List[float]],
shears: Optional[List[float]],
img_size: List[int]) -> Tuple[float, Tuple[int, int], float,
Tuple[float, float]]:
"""Get parameters for affine transformation
Returns:
params to be passed to the affine transformation
"""
angle = float(
torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item(
))
if translate is not None:
max_dx = float(translate[0] * img_size[0])
max_dy = float(translate[1] * img_size[1])
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
translations = (tx, ty)
else:
translations = (0, 0)
if scale_ranges is not None:
scale = float(
torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item(
))
else:
scale = 1.0
shear_x = shear_y = 0.0
if shears is not None:
shear_x = float(
torch.empty(1).uniform_(shears[0], shears[1]).item())
if len(shears) == 4:
shear_y = float(
torch.empty(1).uniform_(shears[2], shears[3]).item())
shear = (shear_x, shear_y)
return angle, translations, scale, shear
def forward(self, img):
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Affine transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
img_size = F._get_image_size(img)
ret = self.get_params(self.degrees, self.translate, self.scale,
self.shear, img_size)
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
def __repr__(self):
s = '{name}(degrees={degrees}'
if self.translate is not None:
s += ', translate={translate}'
if self.scale is not None:
s += ', scale={scale}'
if self.shear is not None:
s += ', shear={shear}'
if self.interpolation != InterpolationMode.NEAREST:
s += ', interpolation={interpolation}'
if self.fill != 0:
s += ', fill={fill}'
s += ')'
d = dict(self.__dict__)
d['interpolation'] = self.interpolation.value
return s.format(name=self.__class__.__name__, **d)
class Grayscale(torch.nn.Module):
"""Convert image to grayscale.
If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
Args:
num_output_channels (int): (1 or 3) number of channels desired for output image
Returns:
PIL Image: Grayscale version of the input.
- If ``num_output_channels == 1`` : returned image is single channel
- If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
"""
def __init__(self, num_output_channels=1):
super().__init__()
self.num_output_channels = num_output_channels
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be converted to grayscale.
Returns:
PIL Image or Tensor: Grayscaled image.
"""
return F.rgb_to_grayscale(
img, num_output_channels=self.num_output_channels)
def __repr__(self):
return self.__class__.__name__ + '(num_output_channels={0})'.format(
self.num_output_channels)
class RandomGrayscale(torch.nn.Module):
"""Randomly convert image to grayscale with a probability of p (default 0.1).
If the image is torch Tensor, it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
Args:
p (float): probability that image should be converted to grayscale.
Returns:
PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
with probability (1-p).
- If input image is 1 channel: grayscale version is 1 channel
- If input image is 3 channel: grayscale version is 3 channel with r == g == b
"""
def __init__(self, p=0.1):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be converted to grayscale.
Returns:
PIL Image or Tensor: Randomly grayscaled image.
"""
num_output_channels = F._get_image_num_channels(img)
if torch.rand(1) < self.p:
return F.rgb_to_grayscale(
img, num_output_channels=num_output_channels)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={0})'.format(self.p)
class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an torch Tensor image and erases its pixels.
This transform does not support PIL Image.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
Args:
p: probability that the random erasing operation will be performed.
scale: range of proportion of erased area against input image.
ratio: range of aspect ratio of erased area.
value: erasing value. Default is 0. If a single int, it is used to
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
inplace: boolean to make this transform inplace. Default set to False.
Returns:
Erased Image.
Example:
>>> transform = transforms.Compose([
>>> transforms.RandomHorizontalFlip(),
>>> transforms.ToTensor(),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> transforms.RandomErasing(),
>>> ])
"""
def __init__(self,
p=0.5,
scale=(0.02, 0.33),
ratio=(0.3, 3.3),
value=0,
inplace=False):
super().__init__()
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError(
"Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError(
"Random erasing probability should be between 0 and 1")
self.p = p
self.scale = scale
self.ratio = ratio
self.value = value
self.inplace = inplace
@staticmethod
def get_params(img: Tensor,
scale: Tuple[float, float],
ratio: Tuple[float, float],
value: Optional[List[float]]=None) -> Tuple[int, int, int,
int, Tensor]:
"""Get parameters for ``erase`` for a random erasing.
Args:
img (Tensor): Tensor image to be erased.
scale (sequence): range of proportion of erased area against input image.
ratio (sequence): range of aspect ratio of erased area.
value (list, optional): erasing value. If None, it is interpreted as "random"
(erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
i.e. ``value[0]``.
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
"""
img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
area = img_h * img_w
log_ratio = torch.log(torch.tensor(ratio))
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(scale[0],
scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1, )).item()
j = torch.randint(0, img_w - w + 1, size=(1, )).item()
return i, j, h, w, v
# Return original image
return 0, 0, img_h, img_w, img
def forward(self, img):
"""
Args:
img (Tensor): Tensor image to be erased.
Returns:
img (Tensor): Erased Tensor image.
"""
if torch.rand(1) < self.p:
# cast self.value to script acceptable type
if isinstance(self.value, (int, float)):
value = [self.value, ]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value
if value is not None and not (len(value) in (1, img.shape[-3])):
raise ValueError(
"If value is a sequence, it should have either a single value or "
"{} (number of input channels)".format(img.shape[-3]))
x, y, h, w, v = self.get_params(
img, scale=self.scale, ratio=self.ratio, value=value)
return F.erase(img, x, y, h, w, v, self.inplace)
return img
def __repr__(self):
s = '(p={}, '.format(self.p)
s += 'scale={}, '.format(self.scale)
s += 'ratio={}, '.format(self.ratio)
s += 'value={}, '.format(self.value)
s += 'inplace={})'.format(self.inplace)
return self.__class__.__name__ + s
class GaussianBlur(torch.nn.Module):
"""Blurs image with randomly chosen Gaussian blur.
If the image is torch Tensor, it is expected
to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
kernel_size (int or sequence): Size of the Gaussian kernel.
sigma (float or tuple of float (min, max)): Standard deviation to be used for
creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
of float (min, max), sigma is chosen uniformly at random to lie in the
given range.
Returns:
PIL Image or Tensor: Gaussian blurred version of the input image.
"""
def __init__(self, kernel_size, sigma=(0.1, 2.0)):
super().__init__()
self.kernel_size = _setup_size(
kernel_size, "Kernel size should be a tuple/list of two integers")
for ks in self.kernel_size:
if ks <= 0 or ks % 2 == 0:
raise ValueError(
"Kernel size value should be an odd and positive number.")
if isinstance(sigma, numbers.Number):
if sigma <= 0:
raise ValueError(
"If sigma is a single number, it must be positive.")
sigma = (sigma, sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0. < sigma[0] <= sigma[1]:
raise ValueError(
"sigma values should be positive and of the form (min, max)."
)
else:
raise ValueError(
"sigma should be a single number or a list/tuple with length 2."
)
self.sigma = sigma
@staticmethod
def get_params(sigma_min: float, sigma_max: float) -> float:
"""Choose sigma for random gaussian blurring.
Args:
sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.
Returns:
float: Standard deviation to be passed to calculate kernel for gaussian blurring.
"""
return torch.empty(1).uniform_(sigma_min, sigma_max).item()
def forward(self, img: Tensor) -> Tensor:
"""
Args:
img (PIL Image or Tensor): image to be blurred.
Returns:
PIL Image or Tensor: Gaussian blurred image
"""
sigma = self.get_params(self.sigma[0], self.sigma[1])
return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
def __repr__(self):
s = '(kernel_size={}, '.format(self.kernel_size)
s += 'sigma={})'.format(self.sigma)
return self.__class__.__name__ + s
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
def _check_sequence_input(x, name, req_sizes):
msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join(
[str(s) for s in req_sizes])
if not isinstance(x, Sequence):
raise TypeError("{} should be a sequence of length {}.".format(name,
msg))
if len(x) not in req_sizes:
raise ValueError("{} should be sequence of length {}.".format(name,
msg))
def _setup_angle(x, name, req_sizes=(2, )):
if isinstance(x, numbers.Number):
if x < 0:
raise ValueError(
"If {} is a single number, it must be positive.".format(name))
x = [-x, x]
else:
_check_sequence_input(x, name, req_sizes)
return [float(d) for d in x]
class RandomInvert(torch.nn.Module):
"""Inverts the colors of the given image randomly with a given probability.
If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be inverted.
Returns:
PIL Image or Tensor: Randomly color inverted image.
"""
if torch.rand(1).item() < self.p:
return F.invert(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomPosterize(torch.nn.Module):
"""Posterize the image randomly with a given probability by reducing the
number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
bits (int): number of bits to keep for each channel (0-8)
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, bits, p=0.5):
super().__init__()
self.bits = bits
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be posterized.
Returns:
PIL Image or Tensor: Randomly posterized image.
"""
if torch.rand(1).item() < self.p:
return F.posterize(img, self.bits)
return img
def __repr__(self):
return self.__class__.__name__ + '(bits={},p={})'.format(self.bits,
self.p)
class RandomSolarize(torch.nn.Module):
"""Solarize the image randomly with a given probability by inverting all pixel
values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
threshold (float): all pixels equal or above this value are inverted.
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, threshold, p=0.5):
super().__init__()
self.threshold = threshold
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be solarized.
Returns:
PIL Image or Tensor: Randomly solarized image.
"""
if torch.rand(1).item() < self.p:
return F.solarize(img, self.threshold)
return img
def __repr__(self):
return self.__class__.__name__ + '(threshold={},p={})'.format(
self.threshold, self.p)
class RandomAdjustSharpness(torch.nn.Module):
"""Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, sharpness_factor, p=0.5):
super().__init__()
self.sharpness_factor = sharpness_factor
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be sharpened.
Returns:
PIL Image or Tensor: Randomly sharpened image.
"""
if torch.rand(1).item() < self.p:
return F.adjust_sharpness(img, self.sharpness_factor)
return img
def __repr__(self):
return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(
self.sharpness_factor, self.p)
class RandomAutocontrast(torch.nn.Module):
"""Autocontrast the pixels of the given image randomly with a given probability.
If the image is torch Tensor, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
p (float): probability of the image being autocontrasted. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be autocontrasted.
Returns:
PIL Image or Tensor: Randomly autocontrasted image.
"""
if torch.rand(1).item() < self.p:
return F.autocontrast(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomEqualize(torch.nn.Module):
"""Equalize the histogram of the given image randomly with a given probability.
If the image is torch Tensor, it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
Args:
p (float): probability of the image being equalized. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be equalized.
Returns:
PIL Image or Tensor: Randomly equalized image.
"""
if torch.rand(1).item() < self.p:
return F.equalize(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
import datetime
import os
import time
import torch
import torch.utils.data
from torch import nn
import torchvision
import presets
import utils
try:
from apex import amp
except ImportError:
amp = None
import sys
sys.path.insert(0, ".")
import numpy as np
from reprod_log import ReprodLogger
def train_one_epoch(model,
criterion,
optimizer,
data_loader,
device,
epoch,
print_freq,
apex=False):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter(
'lr', utils.SmoothedValue(
window_size=1, fmt='{value}'))
metric_logger.add_meter(
'img/s', utils.SmoothedValue(
window_size=10, fmt='{value}'))
header = 'Epoch: [{}]'.format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq,
header):
start_time = time.time()
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
if apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(
loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
metric_logger.meters['img/s'].update(batch_size /
(time.time() - start_time))
def evaluate(model, criterion, data_loader, device, print_freq=100):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
with torch.no_grad():
for image, target in metric_logger.log_every(data_loader, print_freq,
header):
image = image.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(image)
loss = criterion(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
batch_size = image.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'.format(
top1=metric_logger.acc1, top5=metric_logger.acc5))
return metric_logger.acc1.global_avg
def _get_cache_path(filepath):
import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets",
"imagefolder", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path)
return cache_path
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (
256, 224)
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
auto_augment_policy = getattr(args, "auto_augment", None)
random_erase_prob = getattr(args, "random_erase", 0.0)
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=crop_size,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)
print("Loading validation data")
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(
crop_size=crop_size, resize_size=resize_size))
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)
print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(
dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(
dataset_test)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
return dataset, dataset_test, train_sampler, test_sampler
def main(args):
if args.apex and amp is None:
raise RuntimeError(
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training.")
if args.output_dir:
utils.mkdir(args.output_dir)
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
torch.backends.cudnn.benchmark = True
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(
train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=args.workers,
pin_memory=True)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=args.batch_size,
sampler=test_sampler,
num_workers=args.workers,
pin_memory=True)
print("Creating model")
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
criterion = nn.CrossEntropyLoss()
opt_name = args.opt.lower()
if opt_name == 'sgd':
optimizer = torch.optim.SGD(model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = torch.optim.RMSprop(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
eps=0.0316,
alpha=0.9)
else:
raise RuntimeError(
"Invalid optimizer {}. Only SGD and RMSprop are supported.".format(
args.opt))
if args.apex:
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.apex_opt_level)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.gpu])
model_without_ddp = model.module
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.test_only:
# return top1 for record
top1 = evaluate(model, criterion, data_loader_test, device=device)
return top1
print("Start training")
start_time = time.time()
best_top1 = 0.0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device,
epoch, args.print_freq, args.apex)
lr_scheduler.step()
top1 = evaluate(model, criterion, data_loader_test, device=device)
best_top1 = max(best_top1, top1)
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args
}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint, os.path.join(args.output_dir, 'checkpoint.pth'))
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
return best_top1
def get_args_parser(add_help=True):
import argparse
parser = argparse.ArgumentParser(
description='PyTorch Classification Training', add_help=add_help)
parser.add_argument('--data-path', default='../lite_data', help='dataset')
parser.add_argument('--model', default='mobilenet_v3_small', help='model')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('-b', '--batch-size', default=32, type=int)
parser.add_argument(
'--epochs',
default=90,
type=int,
metavar='N',
help='number of total epochs to run')
parser.add_argument(
'-j',
'--workers',
default=16,
type=int,
metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument(
'--lr', default=0.00125, type=float, help='initial learning rate')
parser.add_argument(
'--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument(
'--wd',
'--weight-decay',
default=1e-4,
type=float,
metavar='W',
help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument(
'--lr-step-size',
default=30,
type=int,
help='decrease lr every step-size epochs')
parser.add_argument(
'--lr-gamma',
default=0.1,
type=float,
help='decrease lr by a factor of lr-gamma')
parser.add_argument(
'--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument(
'--start-epoch', default=0, type=int, metavar='N', help='start epoch')
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
help="Cache the datasets for quicker initialization. It also serializes the transforms",
action="store_true", )
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true", )
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true", )
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true", )
parser.add_argument(
'--auto-augment',
default=None,
help='auto augment policy (default: None)')
parser.add_argument(
'--random-erase',
default=0.0,
type=float,
help='random erasing probability (default: 0.0)')
# Mixed precision training parameters
parser.add_argument(
'--apex',
action='store_true',
help='Use apex for mixed precision training')
parser.add_argument(
'--apex-opt-level',
default='O1',
type=str,
help='For apex mixed precision training'
'O0 for FP32 training, O1 for mixed precision training.'
'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
)
# distributed training parameters
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument(
'--world-size',
default=1,
type=int,
help='number of distributed processes')
parser.add_argument(
'--dist-url',
default='env://',
help='url used to set up distributed training')
return parser
if __name__ == "__main__":
args = get_args_parser().parse_args()
top1 = main(args)
reprod_logger = ReprodLogger()
reprod_logger.add("top1", np.array([top1]))
reprod_logger.save("train_align_torch.npy")
from collections import defaultdict, deque, OrderedDict
import copy
import datetime
import hashlib
import time
import torch
import torch.distributed as dist
import errno
import os
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor(
[self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
'time: {time}', 'data: {data}', 'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
'time: {time}', 'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {}'.format(header, total_time_str))
def accuracy(output, target, topk=(1, )):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target[None])
res = []
for k in topk:
correct_k = correct[:k].flatten().sum(dtype=torch.float32)
res.append(correct_k * (100.0 / batch_size))
return res
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print(
'| distributed init (rank {}): {}'.format(args.rank, args.dist_url),
flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank)
setup_for_distributed(args.rank == 0)
def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
Args:
inputs (List[str]): An iterable of string paths of checkpoints to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = OrderedDict()
params_keys = None
new_state = None
num_models = len(inputs)
for fpath in inputs:
with open(fpath, "rb") as f:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
), )
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
model_params = state["model"]
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError("For checkpoint {}, expected list of params: {}, "
"but found: {}".format(f, params_keys,
model_params_keys))
for k in params_keys:
p = model_params[k]
if isinstance(p, torch.HalfTensor):
p = p.float()
if k not in params_dict:
params_dict[k] = p.clone()
# NOTE: clone() is needed in case of p is a shared parameter
else:
params_dict[k] += p
averaged_params = OrderedDict()
for k, v in params_dict.items():
averaged_params[k] = v
if averaged_params[k].is_floating_point():
averaged_params[k].div_(num_models)
else:
averaged_params[k] //= num_models
new_state["model"] = averaged_params
return new_state
def store_model_weights(model,
checkpoint_path,
checkpoint_key='model',
strict=True):
"""
This method can be used to prepare weights files for new models. It receives as
input a model architecture and a checkpoint from the training script and produces
a file with the weights ready for release.
Examples:
from torchvision import models as M
# Classification
model = M.mobilenet_v3_large(pretrained=False)
print(store_model_weights(model, './class.pth'))
# Quantized Classification
model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
_ = torch.quantization.prepare_qat(model, inplace=True)
print(store_model_weights(model, './qat.pth'))
# Object Detection
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
print(store_model_weights(model, './obj.pth'))
# Segmentation
model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True)
print(store_model_weights(model, './segm.pth', strict=False))
Args:
model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
checkpoint_path (str): The path of the checkpoint we will load.
checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
Default: "model".
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
output_path (str): The location where the weights are saved.
"""
# Store the new model next to the checkpoint_path
checkpoint_path = os.path.abspath(checkpoint_path)
output_dir = os.path.dirname(checkpoint_path)
# Deep copy to avoid side-effects on the model object.
model = copy.deepcopy(model)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc)
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
tmp_path = os.path.join(output_dir, str(model.__hash__()))
torch.save(model.state_dict(), tmp_path)
sha256_hash = hashlib.sha256()
with open(tmp_path, "rb") as f:
# Read and update hash string value in blocks of 4K
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
hh = sha256_hash.hexdigest()
output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
os.replace(tmp_path, output_path)
return output_path
reprod-log
\ No newline at end of file
[2021/12/22 20:08:46] root INFO: acc_top1:
[2021/12/22 20:08:46] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/22 20:08:46] root INFO: acc_top5:
[2021/12/22 20:08:46] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/22 20:08:46] root INFO: diff check passed
[2021/12/23 17:49:27] root INFO: loss_0:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: False, value: 1.9073486328125e-06
[2021/12/23 17:49:27] root INFO: lr_0:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:49:27] root INFO: loss_1:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: False, value: 2.384185791015625e-06
[2021/12/23 17:49:27] root INFO: lr_1:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:49:27] root INFO: loss_2:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: False, value: 7.62939453125e-06
[2021/12/23 17:49:27] root INFO: lr_2:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:49:27] root INFO: loss_3:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: False, value: 0.002070903778076172
[2021/12/23 17:49:27] root INFO: lr_3:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:49:27] root INFO: loss_4:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: False, value: 0.002232074737548828
[2021/12/23 17:49:27] root INFO: lr_4:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:49:27] root INFO: loss_5:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: False, value: 0.03954291343688965
[2021/12/23 17:49:27] root INFO: lr_5:
[2021/12/23 17:49:27] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:49:27] root INFO: diff check failed
[2021/12/23 17:21:22] root INFO: length:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_0:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_1:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_2:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: dataloader_3:
[2021/12/23 17:21:22] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:21:22] root INFO: diff check passed
[2021/12/23 17:44:09] root INFO: logits:
[2021/12/23 17:44:09] root INFO: mean diff: check passed: False, value: 2.308018565599923e-06
[2021/12/23 17:44:09] root INFO: diff check failed
[2021/12/23 17:46:12] root INFO: loss:
[2021/12/23 17:46:12] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:46:12] root INFO: diff check passed
[2021/12/23 17:45:32] root INFO: acc_top1:
[2021/12/23 17:45:32] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:45:32] root INFO: acc_top5:
[2021/12/23 17:45:32] root INFO: mean diff: check passed: True, value: 0.0
[2021/12/23 17:45:32] root INFO: diff check passed
import numpy as np
def gen_fake_data():
fake_data = np.random.rand(1, 3, 224, 224).astype(np.float32) - 0.5
fake_label = np.arange(1).astype(np.int64)
np.save("fake_data.npy", fake_data)
np.save("fake_label.npy", fake_label)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册