未验证 提交 beec7ed2 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

add swin2sr_real_sr_x4 (#2085)

上级 36ce4789
# swin2sr_real_sr_x4
|模型名称|swin2sr_real_sr_x4|
| :--- | :---: |
|类别|图像-图像编辑|
|网络|Swin2SR|
|数据集|DIV2K / Flickr2K|
|是否支持Fine-tuning|否|
|模型大小|68.4MB|
|指标|-|
|最新更新日期|2022-10-25|
## 一、模型基本信息
- ### 应用效果展示
- 网络结构:
<p align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/884d4d4472b44bf1879606374ed64a7e8d2fec0bcf034285a5cecfc582e8cd65" hspace='10'/> <br />
</p>
- 样例结果示例:
<p align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/c5517af6c3f944c4b281aedc417a4f8c02c0a969d0dd494c9106c4ff2709fc2f" hspace='10'/>
<img src="https://ai-studio-static-online.cdn.bcebos.com/183c5821029f45bbb78d1700ab8297baabba15f82ab4467e88414bbed056ccf0" hspace='10'/>
</p>
- ### 模型介绍
- Swin2SR 是一个基于 Swin Transformer v2 的图像超分辨率模型。swin2sr_real_sr_x4 是基于 Swin2SR 的 4 倍现实图像超分辨率模型。
## 二、安装
- ### 1、环境依赖
- paddlepaddle >= 2.0.0
- paddlehub >= 2.0.0
- ### 2.安装
- ```shell
$ hub install swin2sr_real_sr_x4
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## 三、模型API预测
- ### 1、命令行预测
```shell
$ hub run swin2sr_real_sr_x4 \
--input_path "/PATH/TO/IMAGE" \
--output_dir "swin2sr_real_sr_x4_output"
```
- ### 2、预测代码示例
```python
import paddlehub as hub
import cv2
module = hub.Module(name="swin2sr_real_sr_x4")
result = module.real_sr(
image=cv2.imread('/PATH/TO/IMAGE'),
visualization=True,
output_dir='swin2sr_real_sr_x4_output'
)
```
- ### 3、API
```python
def real_sr(
image: Union[str, numpy.ndarray],
visualization: bool = True,
output_dir: str = "swin2sr_real_sr_x4_output"
) -> numpy.ndarray
```
- 超分辨率 API
- **参数**
* image (Union\[str, numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
* visualization (bool): 是否将识别结果保存为图片文件;
* output\_dir (str): 保存处理结果的文件目录。
- **返回**
* res (numpy.ndarray): 图像超分辨率结果 (BGR);
## 四、服务部署
- PaddleHub Serving 可以部署一个图像超分辨率的在线服务。
- ### 第一步:启动PaddleHub Serving
- 运行启动命令:
```shell
$ hub serving start -m swin2sr_real_sr_x4
```
- 这样就完成了一个图像超分辨率服务化API的部署,默认端口号为8866。
- ### 第二步:发送预测请求
- 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
```python
import requests
import json
import base64
import cv2
import numpy as np
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tobytes()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
# 发送HTTP请求
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {
'image': cv2_to_base64(org_im)
}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/swin2sr_real_sr_x4"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 结果转换
results = r.json()['results']
results = base64_to_cv2(results)
# 保存结果
cv2.imwrite('output.jpg', results)
```
## 五、参考资料
* 论文:[Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://arxiv.org/abs/2209.11345)
* 官方实现:[mv-lab/swin2sr](https://github.com/mv-lab/swin2sr/)
## 六、更新历史
* 1.0.0
初始发布
```shell
$ hub install swin2sr_real_sr_x4==1.0.0
```
import argparse
import base64
import os
import time
from typing import Union
import cv2
import numpy as np
import paddle
import paddle.nn as nn
from .swin2sr import Swin2SR
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tobytes()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
@moduleinfo(
name='swin2sr_real_sr_x4',
version='1.0.0',
type="CV/image_editing",
author="",
author_email="",
summary="SwinV2 Transformer for Compressed Image Super-Resolution and Restoration.",
)
class SwinIRMRealSR(nn.Layer):
def __init__(self):
super(SwinIRMRealSR, self).__init__()
self.default_pretrained_model_path = os.path.join(self.directory,
'Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pdparams')
self.swin2sr = Swin2SR(upscale=4,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler='nearest+conv',
resi_connection='1conv')
state_dict = paddle.load(self.default_pretrained_model_path)
self.swin2sr.set_state_dict(state_dict)
self.swin2sr.eval()
def preprocess(self, img: np.ndarray) -> np.ndarray:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.transpose((2, 0, 1))
img = img / 255.0
return img.astype(np.float32)
def postprocess(self, img: np.ndarray) -> np.ndarray:
img = img.clip(0, 1)
img = img * 255.0
img = img.transpose((1, 2, 0))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img.astype(np.uint8)
def real_sr(self,
image: Union[str, np.ndarray],
visualization: bool = True,
output_dir: str = "swin2sr_real_sr_x4_output") -> np.ndarray:
if isinstance(image, str):
_, file_name = os.path.split(image)
save_name, _ = os.path.splitext(file_name)
save_name = save_name + '_' + str(int(time.time())) + '.jpg'
image = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_COLOR)
elif isinstance(image, np.ndarray):
save_name = str(int(time.time())) + '.jpg'
image = image
else:
raise Exception("image should be a str / np.ndarray")
with paddle.no_grad():
img_input = self.preprocess(image)
img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)
img_output = self.swin2sr(img_input)
img_output = img_output.numpy()[0]
img_output = self.postprocess(img_output)
if visualization:
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
save_path = os.path.join(output_dir, save_name)
cv2.imwrite(save_path, img_output)
return img_output
@runnable
def run_cmd(self, argvs):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.parser.add_argument('--input_path', type=str, help="Path to image.")
self.parser.add_argument('--output_dir',
type=str,
default='swin2sr_real_sr_x4_output',
help="The directory to save output images.")
args = self.parser.parse_args(argvs)
self.real_sr(image=args.input_path, visualization=True, output_dir=args.output_dir)
return 'Results are saved in %s' % args.output_dir
@serving
def serving_method(self, image, **kwargs):
"""
Run as a service.
"""
image = base64_to_cv2(image)
img_output = self.real_sr(image=image, **kwargs)
return cv2_to_base64(img_output)
import collections.abc
import math
from itertools import repeat
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
class Mlp(nn.Layer):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.reshape((B, H // window_size, window_size, W // window_size, window_size, C))
windows = x.transpose((0, 1, 3, 2, 4, 5)).reshape((-1, window_size, window_size, C))
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape((B, H // window_size, W // window_size, window_size, window_size, -1))
x = x.transpose((0, 1, 3, 2, 4, 5)).reshape((B, H, W, -1))
return x
class WindowAttention(nn.Layer):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
"""
def __init__(self,
dim,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
pretrained_window_size=[0, 0]):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.num_heads = num_heads
self.logit_scale = self.create_parameter(shape=(num_heads, 1, 1),
dtype=paddle.float32,
default_initializer=nn.initializer.Assign(
paddle.log(10 * paddle.ones((num_heads, 1, 1)))))
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias_attr=True), nn.ReLU(),
nn.Linear(512, num_heads, bias_attr=False))
# get relative_coords_table
relative_coords_h = paddle.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=paddle.float32)
relative_coords_w = paddle.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=paddle.float32)
relative_coords_table = paddle.stack(paddle.meshgrid([relative_coords_h, relative_coords_w])).transpose(
(1, 2, 0)).unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
else:
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = paddle.sign(relative_coords_table) * paddle.log2(
paddle.abs(relative_coords_table) + 1.0) / np.log2(8)
self.register_buffer("relative_coords_table", relative_coords_table)
# get pair-wise relative position index for each token inside the window
coords_h = paddle.arange(self.window_size[0])
coords_w = paddle.arange(self.window_size[1])
coords = paddle.stack(paddle.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - \
coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.transpose((1, 2, 0)) # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - \
1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias_attr=False)
if qkv_bias:
self.q_bias = self.create_parameter(shape=(dim, ),
dtype=paddle.float32,
default_initializer=nn.initializer.Constant(0.0))
self.v_bias = self.create_parameter(shape=(dim, ),
dtype=paddle.float32,
default_initializer=nn.initializer.Constant(0.0))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(axis=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = paddle.concat((self.q_bias, paddle.zeros_like(self.v_bias), self.v_bias))
qkv = F.linear(x=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape((B_, N, 3, self.num_heads, -1)).transpose((2, 0, 3, 1, 4))
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
# cosine attention
attn = (F.normalize(q, axis=-1) @ F.normalize(k, axis=-1).transpose((0, 1, 3, 2)))
logit_scale = paddle.clip(self.logit_scale, max=paddle.log(paddle.to_tensor(1. / 0.01))).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).reshape((-1, self.num_heads))
relative_position_bias = relative_position_bias_table[self.relative_position_index.reshape((-1, ))].reshape(
(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
-1)) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.transpose((2, 0, 1)) # nH, Wh*Ww, Wh*Ww
relative_position_bias = 16 * \
nn.functional.sigmoid(relative_position_bias)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.reshape((B_ // nW, nW, self.num_heads, N, N)) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.reshape((-1, self.num_heads, N, N))
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose((0, 2, 1, 3)).reshape((B_, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, ' \
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Layer):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pre-training.
"""
def __init__(self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
pretrained_window_size=0):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size))
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = paddle.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size), slice(-self.window_size,
-self.shift_size if self.shift_size else None),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size), slice(-self.window_size,
-self.shift_size if self.shift_size else None),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.reshape((-1, self.window_size * self.window_size))
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
_h = paddle.full_like(attn_mask, -100.0, dtype=paddle.float32)
_z = paddle.full_like(attn_mask, 0.0, dtype=paddle.float32)
attn_mask = paddle.where(attn_mask != 0, _h, _z)
# attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
#assert L == H * W, "input feature has wrong size"
shortcut = x
x = x.reshape((B, H, W, C))
# cyclic shift
if self.shift_size > 0:
shifted_x = paddle.roll(x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
else:
shifted_x = x
# partition windows
# nW*B, window_size, window_size, C
x_windows = window_partition(shifted_x, self.window_size)
# nW*B, window_size*window_size, C
x_windows = x_windows.reshape((-1, self.window_size * self.window_size, C))
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
if self.input_resolution == x_size:
# nW*B, window_size*window_size, C
attn_windows = self.attn(x_windows, mask=self.attn_mask)
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size))
# merge windows
attn_windows = attn_windows.reshape((-1, self.window_size, self.window_size, C))
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = paddle.roll(shifted_x, shifts=(self.shift_size, self.shift_size), axis=(1, 2))
else:
x = shifted_x
x = x.reshape((B, H * W, C))
x = shortcut + self.drop_path(self.norm1(x))
# FFN
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Layer):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
self.norm = norm_layer(2 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.reshape((B, H, W, C))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = paddle.concat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.rehsape((B, -1, 4 * C)) # B H/2*W/2 4*C
x = self.reduction(x)
x = self.norm(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
return flops
class BasicLayer(nn.Layer):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
pretrained_window_size (int): Local window size in pre-training.
"""
def __init__(self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
pretrained_window_size=0):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.LayerList([
SwinTransformerBlock(dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size) for i in range(depth)
])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Layer):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Layer, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1],
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose((0, 2, 1)) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * \
(self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class RSTB(nn.Layer):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
img_size=224,
patch_size=4,
resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2D(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2D(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, ),
nn.Conv2D(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, ),
nn.Conv2D(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(img_size=img_size,
patch_size=patch_size,
in_chans=dim,
embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(img_size=img_size,
patch_size=patch_size,
in_chans=dim,
embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchUnEmbed(nn.Layer):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Layer, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose((0, 2, 1)).reshape((B, self.embed_dim, x_size[0], x_size[1])) # B Ph*Pw C
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2D(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2D(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. '
'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
class Upsample_hf(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2D(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2D(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. '
'Supported scales: 2^n and 3.')
super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2D(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class Swin2SR(nn.Layer):
r""" Swin2SR
A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self,
img_size=64,
patch_size=1,
in_chans=3,
embed_dim=96,
depths=[6, 6, 6, 6],
num_heads=[6, 6, 6, 6],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
upscale=2,
img_range=1.,
upsampler='',
resi_connection='1conv',
**kwargs):
super(Swin2SR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = paddle.to_tensor(rgb_mean).reshape((1, 3, 1, 1))
else:
self.mean = paddle.zeros((1, 1, 1, 1))
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2D(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(img_size=img_size,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim),
dtype=paddle.float32,
default_initializer=nn.initializer.Constant(0.0))
# trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.LayerList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection)
self.layers.append(layer)
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.LayerList()
for i_layer in range(self.num_layers):
layer = RSTB(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection)
self.layers_hf.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2D(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(nn.Conv2D(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, ),
nn.Conv2D(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, ),
nn.Conv2D(embed_dim // 4, embed_dim, 3, 1, 1))
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU())
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_aux':
self.conv_bicubic = nn.Conv2D(num_in_ch, num_feat, 3, 1, 1)
self.conv_before_upsample = nn.Sequential(nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU())
self.conv_aux = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(nn.Conv2D(3, num_feat, 3, 1, 1), nn.LeakyReLU())
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU())
self.upsample = Upsample(upscale, num_feat)
self.upsample_hf = Upsample_hf(upscale, num_feat)
self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
self.conv_first_hf = nn.Sequential(nn.Conv2D(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU())
self.conv_after_body_hf = nn.Conv2D(embed_dim, embed_dim, 3, 1, 1)
self.conv_before_upsample_hf = nn.Sequential(nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU())
self.conv_last_hf = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
assert self.upscale == 4, 'only support x4 now.'
self.conv_before_upsample = nn.Sequential(nn.Conv2D(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU())
self.conv_up1 = nn.Conv2D(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2D(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2D(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2D(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, )
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2D(embed_dim, num_out_ch, 3, 1, 1)
def check_image_size(self, x):
_, _, h, w = x.shape
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers_hf:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.cast(x.dtype)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffle_aux':
bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
bicubic = self.conv_bicubic(bicubic)
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
aux = self.conv_aux(x) # b, 3, LR_H, LR_W
x = self.conv_after_aux(aux)
x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + \
bicubic[:, :, :H * self.upscale, :W * self.upscale]
x = self.conv_last(x)
aux = aux / self.img_range + self.mean
elif self.upsampler == 'pixelshuffle_hf':
# for classical SR with HF
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
x = x_out + x_hf
x_hf = x_hf / self.img_range + self.mean
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.lrelu(self.conv_up2(nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H * self.upscale, :W * self.upscale], aux
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H * self.upscale, :W *
self.upscale], x[:, :, :H * self.upscale, :W *
self.upscale], x_hf[:, :, :H * self.upscale, :W * self.upscale]
else:
return x[:, :, :H * self.upscale, :W * self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
import os
import shutil
import unittest
import cv2
import numpy as np
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/mJaD10XeD7w/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8M3x8Y2F0fGVufDB8fHx8MTY2MzczNDc3Mw&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
img = cv2.imread('tests/test.jpg')
img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
cv2.imwrite('tests/test.jpg', img)
cls.module = hub.Module(name="swin2sr_real_sr_x4")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('swin2sr_real_sr_x4_output')
def test_real_sr1(self):
results = self.module.real_sr(image='tests/test.jpg', visualization=False)
self.assertIsInstance(results, np.ndarray)
def test_real_sr2(self):
results = self.module.real_sr(image=cv2.imread('tests/test.jpg'), visualization=True)
self.assertIsInstance(results, np.ndarray)
def test_real_sr3(self):
results = self.module.real_sr(image=cv2.imread('tests/test.jpg'), visualization=True)
self.assertIsInstance(results, np.ndarray)
def test_real_sr4(self):
self.assertRaises(Exception, self.module.real_sr, image=['tests/test.jpg'])
def test_real_sr5(self):
self.assertRaises(FileNotFoundError, self.module.real_sr, image='no.jpg')
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册