Unverified commit 270cc958, authored by jm_12138, committed by GitHub

add fbcnn_color module (#2065)

* add fbcnn_color module

* update example

* fix save name

* fix a cls
Parent 21545f0c
# fbcnn_color
|Model Name|fbcnn_color|
| :--- | :---: |
|Category|Image - Image Editing|
|Network|FBCNN|
|Dataset|-|
|Fine-tuning Supported|No|
|Model Size|288MB|
|Metrics|-|
|Last Updated|2022-10-08|
## I. Basic Information
- ### Application Effect Display
  - Network architecture:
  <p align="center">
  <img src="https://ai-studio-static-online.cdn.bcebos.com/08afa15df2e54adeb39587cd7aaa9b60fc82d349bda34f51993d6304776fd374" hspace='10'/> <br />
  </p>
  - Sample results:
  <p align="center">
  <img src="https://ai-studio-static-online.cdn.bcebos.com/f486da7c9d5e4cac8b7ff252b5a4c17633f44f28745c4e489f31e6b78caea005" hspace='10'/>
  </p>
- ### Module Introduction
  - FBCNN is a convolutional-neural-network-based JPEG artifact removal model. It predicts an adjustable quality factor that controls the trade-off between artifact removal and detail preservation.
## II. Installation
- ### 1. Environment Dependencies
  - paddlepaddle >= 2.0.0
  - paddlehub >= 2.0.0
- ### 2. Installation
  - ```shell
    $ hub install fbcnn_color
    ```
  - If you encounter problems during installation, please refer to: [Windows quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Linux quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Command Line Prediction
```shell
$ hub run fbcnn_color \
--input_path "/PATH/TO/IMAGE" \
--quality_factor -1 \
--output_dir "fbcnn_color_output"
```
- ### 2. Prediction Code Example
```python
import paddlehub as hub
import cv2
module = hub.Module(name="fbcnn_color")
result = module.artifacts_removal(
image=cv2.imread('/PATH/TO/IMAGE'),
quality_factor=None,
visualization=True,
output_dir='fbcnn_color_output'
)
```
- ### 3. API
```python
def artifacts_removal(
image: Union[str, numpy.ndarray],
quality_factor: float = None,
visualization: bool = True,
output_dir: str = "fbcnn_color_output"
) -> numpy.ndarray
```
- The JPEG artifact removal API.
- **Parameters**
  * image (Union\[str, numpy.ndarray\]): image data, either a file path or an ndarray of shape \[H, W, C\] in BGR format;
  * quality_factor (float): custom quality factor in \[0.0, 1.0\]; the default None enables adaptive quality-factor prediction;
  * visualization (bool): whether to save the result as an image file;
  * output\_dir (str): directory in which to save the results.
- **Return**
  * res (numpy.ndarray): the artifact removal result (BGR).
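- A lower quality factor tells the model to assume heavier compression, removing artifacts more aggressively at the cost of fine detail, while a higher one preserves more detail; values outside \[0.0, 1.0\] (such as the `-1` used in the command line example above) fall back to adaptive prediction. A minimal comparison sketch (`/PATH/TO/IMAGE` is a placeholder):
  ```python
  import cv2
  import paddlehub as hub

  module = hub.Module(name="fbcnn_color")
  img = cv2.imread('/PATH/TO/IMAGE')

  # Aggressive artifact removal (smoother result, fewer preserved details).
  low = module.artifacts_removal(image=img, quality_factor=0.1, visualization=False)
  # Conservative removal (more detail, more residual artifacts).
  high = module.artifacts_removal(image=img, quality_factor=0.9, visualization=False)
  ```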
## IV. Server Deployment
- PaddleHub Serving can deploy an online artifact removal service.
- ### Step 1: Start PaddleHub Serving
  - Run the start command:
    ```shell
    $ hub serving start -m fbcnn_color
    ```
  - This deploys the artifact removal service API, listening on port 8866 by default.
- ### Step 2: Send a Prediction Request
  - With the server configured, the few lines of code below send a prediction request and retrieve the result:
```python
import requests
import json
import base64
import cv2
import numpy as np
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tobytes()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
# Send an HTTP request
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {
'image': cv2_to_base64(org_im)
}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/fbcnn_color"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# Convert the result
results = r.json()['results']
results = base64_to_cv2(results)
# Save the result
cv2.imwrite('output.jpg', results)
```
## V. References
* Paper: [Towards Flexible Blind JPEG Artifacts Removal](https://arxiv.org/abs/2109.14573)
* Official implementation: [jiaxi-jiang/FBCNN](https://github.com/jiaxi-jiang/FBCNN)
## VI. Release Note
* 1.0.0
  First release
```shell
$ hub install fbcnn_color==1.0.0
```
from collections import OrderedDict
import numpy as np
import paddle.nn as nn
'''
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
def sequential(*args):
"""Advanced nn.Sequential.
Args:
nn.Sequential, nn.Layer
Returns:
nn.Sequential
"""
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError('sequential does not support OrderedDict input.')
return args[0] # No sequential is needed.
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Layer):
modules.append(module)
return nn.Sequential(*modules)
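# Example: sequential(nn.Conv2D(3, 8, 3), nn.Sequential(nn.ReLU(), nn.Conv2D(8, 8, 3)))
# returns a single flat nn.Sequential containing all three sublayers.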
# --------------------------------------------
# return nn.Sequential of (Conv + BN + ReLU)
# --------------------------------------------
def conv(in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
mode='CBR',
negative_slope=0.2):
L = []
for t in mode:
if t == 'C':
L.append(
nn.Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias_attr=bias))
elif t == 'T':
L.append(
nn.Conv2DTranspose(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias_attr=bias))
elif t == 'B':
            L.append(nn.BatchNorm2D(out_channels, momentum=0.9, epsilon=1e-04))
elif t == 'I':
            L.append(nn.InstanceNorm2D(out_channels))
        elif t in ('R', 'r'):
            # In the PyTorch original, lowercase selects the in-place variant;
            # Paddle has no in-place flag, so both map to the same layer.
            L.append(nn.ReLU())
        elif t in ('L', 'l'):
            L.append(nn.LeakyReLU(negative_slope=negative_slope))
elif t == '2':
L.append(nn.PixelShuffle(upscale_factor=2))
elif t == '3':
L.append(nn.PixelShuffle(upscale_factor=3))
elif t == '4':
L.append(nn.PixelShuffle(upscale_factor=4))
elif t == 'U':
L.append(nn.Upsample(scale_factor=2, mode='nearest'))
elif t == 'u':
L.append(nn.Upsample(scale_factor=3, mode='nearest'))
elif t == 'v':
L.append(nn.Upsample(scale_factor=4, mode='nearest'))
elif t == 'M':
L.append(nn.MaxPool2D(kernel_size=kernel_size, stride=stride, padding=0))
elif t == 'A':
L.append(nn.AvgPool2D(kernel_size=kernel_size, stride=stride, padding=0))
else:
            raise NotImplementedError('Undefined type: {}'.format(t))
return sequential(*L)
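# Example: conv(64, 64, mode='CBR') yields Conv2D -> BatchNorm2D -> ReLU,
# and conv(64, 64, mode='CL') yields Conv2D -> LeakyReLU(0.2).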
# --------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# --------------------------------------------
class ResBlock(nn.Layer):
def __init__(self,
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
mode='CRC',
negative_slope=0.2):
super(ResBlock, self).__init__()
assert in_channels == out_channels, 'Only support in_channels==out_channels.'
if mode[0] in ['R', 'L']:
mode = mode[0].lower() + mode[1:]
self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
def forward(self, x):
res = self.res(x)
return x + res
# --------------------------------------------
# conv + subp (+ relu)
# --------------------------------------------
def upsample_pixelshuffle(in_channels=64,
out_channels=3,
kernel_size=3,
stride=1,
padding=1,
bias=True,
mode='2R',
negative_slope=0.2):
assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
up1 = conv(in_channels,
out_channels * (int(mode[0])**2),
kernel_size,
stride,
padding,
bias,
mode='C' + mode,
negative_slope=negative_slope)
return up1
# --------------------------------------------
# nearest_upsample + conv (+ R)
# --------------------------------------------
def upsample_upconv(in_channels=64,
out_channels=3,
kernel_size=3,
stride=1,
padding=1,
bias=True,
mode='2R',
negative_slope=0.2):
assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
if mode[0] == '2':
uc = 'UC'
elif mode[0] == '3':
uc = 'uC'
elif mode[0] == '4':
uc = 'vC'
mode = mode.replace(mode[0], uc)
up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope)
return up1
# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(in_channels=64,
out_channels=3,
kernel_size=2,
stride=2,
padding=0,
bias=True,
mode='2R',
negative_slope=0.2):
assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
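    # The leading digit of mode ('2', '3' or '4') sets both kernel_size and
    # stride of the transposed conv, overriding the arguments above.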
kernel_size = int(mode[0])
stride = int(mode[0])
mode = mode.replace(mode[0], 'T')
up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
return up1
'''
# --------------------------------------------
# Downsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# downsample_strideconv
# downsample_maxpool
# downsample_avgpool
# --------------------------------------------
'''
# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(in_channels=64,
out_channels=64,
kernel_size=2,
stride=2,
padding=0,
bias=True,
mode='2R',
negative_slope=0.2):
assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
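    # As in upsample_convtranspose, the leading digit of mode sets both
    # kernel_size and stride of the strided conv.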
kernel_size = int(mode[0])
stride = int(mode[0])
mode = mode.replace(mode[0], 'C')
down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
return down1
# --------------------------------------------
# maxpooling + conv (+ relu)
# --------------------------------------------
def downsample_maxpool(in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=0,
bias=True,
mode='2R',
negative_slope=0.2):
assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
kernel_size_pool = int(mode[0])
stride_pool = int(mode[0])
mode = mode.replace(mode[0], 'MC')
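    # After the replace, mode[0] is the pooling op ('M') and mode[1:] is the
    # conv (+ optional activation) tail.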
pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
pool_tail = conv(in_channels,
out_channels,
kernel_size,
stride,
padding,
bias,
mode=mode[1:],
negative_slope=negative_slope)
return sequential(pool, pool_tail)
# --------------------------------------------
# averagepooling + conv (+ relu)
# --------------------------------------------
def downsample_avgpool(in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
mode='2R',
negative_slope=0.2):
assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
kernel_size_pool = int(mode[0])
stride_pool = int(mode[0])
mode = mode.replace(mode[0], 'AC')
pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
pool_tail = conv(in_channels,
out_channels,
kernel_size,
stride,
padding,
bias,
mode=mode[1:],
negative_slope=negative_slope)
return sequential(pool, pool_tail)
class QFAttention(nn.Layer):
def __init__(self,
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=True,
mode='CRC',
negative_slope=0.2):
super(QFAttention, self).__init__()
assert in_channels == out_channels, 'Only support in_channels==out_channels.'
if mode[0] in ['R', 'L']:
mode = mode[0].lower() + mode[1:]
self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
def forward(self, x, gamma, beta):
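        # FiLM-style modulation: broadcast the per-channel scale (gamma) and
        # shift (beta) over the spatial dimensions of the residual branch.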
gamma = gamma.unsqueeze(-1).unsqueeze(-1)
beta = beta.unsqueeze(-1).unsqueeze(-1)
res = (gamma) * self.res(x) + beta
return x + res
class FBCNN(nn.Layer):
def __init__(self,
in_nc=3,
out_nc=3,
nc=[64, 128, 256, 512],
nb=4,
act_mode='R',
downsample_mode='strideconv',
upsample_mode='convtranspose'):
super(FBCNN, self).__init__()
self.m_head = conv(in_nc, nc[0], bias=True, mode='C')
self.nb = nb
self.nc = nc
# downsample
if downsample_mode == 'avgpool':
downsample_block = downsample_avgpool
elif downsample_mode == 'maxpool':
downsample_block = downsample_maxpool
elif downsample_mode == 'strideconv':
downsample_block = downsample_strideconv
else:
raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
self.m_down1 = sequential(*[ResBlock(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
downsample_block(nc[0], nc[1], bias=True, mode='2'))
self.m_down2 = sequential(*[ResBlock(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
downsample_block(nc[1], nc[2], bias=True, mode='2'))
self.m_down3 = sequential(*[ResBlock(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
downsample_block(nc[2], nc[3], bias=True, mode='2'))
self.m_body_encoder = sequential(
*[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])
self.m_body_decoder = sequential(
*[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])
# upsample
if upsample_mode == 'upconv':
upsample_block = upsample_upconv
elif upsample_mode == 'pixelshuffle':
upsample_block = upsample_pixelshuffle
elif upsample_mode == 'convtranspose':
upsample_block = upsample_convtranspose
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
self.m_up3 = nn.LayerList([
upsample_block(nc[3], nc[2], bias=True, mode='2'),
*[QFAttention(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
])
self.m_up2 = nn.LayerList([
upsample_block(nc[2], nc[1], bias=True, mode='2'),
*[QFAttention(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
])
self.m_up1 = nn.LayerList([
upsample_block(nc[1], nc[0], bias=True, mode='2'),
*[QFAttention(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
])
self.m_tail = conv(nc[0], out_nc, bias=True, mode='C')
self.qf_pred = sequential(*[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
nn.AdaptiveAvgPool2D((1, 1)), nn.Flatten(), nn.Linear(512, 512), nn.ReLU(),
nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1), nn.Sigmoid())
self.qf_embed = sequential(nn.Linear(1, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512),
nn.ReLU())
self.to_gamma_3 = sequential(nn.Linear(512, nc[2]), nn.Sigmoid())
self.to_beta_3 = sequential(nn.Linear(512, nc[2]), nn.Tanh())
self.to_gamma_2 = sequential(nn.Linear(512, nc[1]), nn.Sigmoid())
self.to_beta_2 = sequential(nn.Linear(512, nc[1]), nn.Tanh())
self.to_gamma_1 = sequential(nn.Linear(512, nc[0]), nn.Sigmoid())
self.to_beta_1 = sequential(nn.Linear(512, nc[0]), nn.Tanh())
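        # Heads for the modulation parameters: Sigmoid keeps gamma in (0, 1)
        # and Tanh keeps beta in (-1, 1); both come from the 512-d QF embedding.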
def forward(self, x, qf_input=None):
h, w = x.shape[-2:]
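        # Reflect-pad H and W up to multiples of 8 so the three stride-2
        # downsample stages divide evenly; the output is cropped back at the end.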
paddingBottom = int(np.ceil(h / 8) * 8 - h)
paddingRight = int(np.ceil(w / 8) * 8 - w)
x = nn.functional.pad(x, (0, paddingRight, 0, paddingBottom), mode='reflect')
x1 = self.m_head(x)
x2 = self.m_down1(x1)
x3 = self.m_down2(x2)
x4 = self.m_down3(x3)
x = self.m_body_encoder(x4)
qf = self.qf_pred(x)
x = self.m_body_decoder(x)
qf_embedding = self.qf_embed(qf_input) if qf_input is not None else self.qf_embed(qf)
gamma_3 = self.to_gamma_3(qf_embedding)
beta_3 = self.to_beta_3(qf_embedding)
gamma_2 = self.to_gamma_2(qf_embedding)
beta_2 = self.to_beta_2(qf_embedding)
gamma_1 = self.to_gamma_1(qf_embedding)
beta_1 = self.to_beta_1(qf_embedding)
x = x + x4
x = self.m_up3[0](x)
for i in range(self.nb):
x = self.m_up3[i + 1](x, gamma_3, beta_3)
x = x + x3
x = self.m_up2[0](x)
for i in range(self.nb):
x = self.m_up2[i + 1](x, gamma_2, beta_2)
x = x + x2
x = self.m_up1[0](x)
for i in range(self.nb):
x = self.m_up1[i + 1](x, gamma_1, beta_1)
x = x + x1
x = self.m_tail(x)
x = x[..., :h, :w]
return x, qf
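# Illustrative shape check (a sketch, assuming a working paddle install):
#   net = FBCNN()
#   out, qf = net(paddle.randn([1, 3, 100, 120]))
#   # out: [1, 3, 100, 120] restored image; qf: [1, 1] predicted quality factor in (0, 1)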
import argparse
import base64
import os
import time
from typing import Union
import cv2
import numpy as np
import paddle
import paddle.nn as nn
from .fbcnn import FBCNN
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tobytes()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
@moduleinfo(
name='fbcnn_color',
version='1.0.0',
type="CV/image_editing",
author="",
author_email="",
summary="Flexible JPEG Artifacts Removal.",
)
class FBCNNColor(nn.Layer):
def __init__(self):
super(FBCNNColor, self).__init__()
self.default_pretrained_model_path = os.path.join(self.directory, 'ckpts', 'fbcnn_color.pdparams')
self.fbcnn = FBCNN()
state_dict = paddle.load(self.default_pretrained_model_path)
self.fbcnn.set_state_dict(state_dict)
self.fbcnn.eval()
def preprocess(self, img: np.ndarray) -> np.ndarray:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.transpose((2, 0, 1))
img = img / 255.0
return img.astype(np.float32)
def postprocess(self, img: np.ndarray) -> np.ndarray:
img = img.clip(0, 1)
img = img * 255.0
img = img.transpose((1, 2, 0))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img.astype(np.uint8)
def artifacts_removal(self,
image: Union[str, np.ndarray],
quality_factor: float = None,
visualization: bool = True,
output_dir: str = "fbcnn_color_output") -> np.ndarray:
if isinstance(image, str):
_, file_name = os.path.split(image)
save_name, _ = os.path.splitext(file_name)
save_name = save_name + '_' + str(int(time.time())) + '.jpg'
image = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        elif isinstance(image, np.ndarray):
            save_name = str(int(time.time())) + '.jpg'
else:
raise Exception("image should be a str / np.ndarray")
with paddle.no_grad():
img_input = self.preprocess(image)
img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)
            if quality_factor is not None and 0 <= quality_factor <= 1:
qf_input = paddle.to_tensor([[quality_factor]], dtype=paddle.float32)
else:
qf_input = None
img_output, _ = self.fbcnn(img_input, qf_input)
img_output = img_output.numpy()[0]
img_output = self.postprocess(img_output)
if visualization:
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
save_path = os.path.join(output_dir, save_name)
cv2.imwrite(save_path, img_output)
return img_output
@runnable
def run_cmd(self, argvs):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.parser.add_argument('--input_path', type=str, help="Path to image.")
self.parser.add_argument('--quality_factor', type=float, default=None, help="Image quality factor (0.0-1.0).")
self.parser.add_argument('--output_dir',
type=str,
default='fbcnn_color_output',
help="The directory to save output images.")
args = self.parser.parse_args(argvs)
self.artifacts_removal(image=args.input_path,
quality_factor=args.quality_factor,
visualization=True,
output_dir=args.output_dir)
return 'Artifacts removal results are saved in %s' % args.output_dir
@serving
def serving_method(self, image, **kwargs):
"""
Run as a service.
"""
image = base64_to_cv2(image)
img_output = self.artifacts_removal(image=image, **kwargs)
return cv2_to_base64(img_output)
import os
import shutil
import unittest
import cv2
import numpy as np
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/mJaD10XeD7w/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8M3x8Y2F0fGVufDB8fHx8MTY2MzczNDc3Mw&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="fbcnn_color")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('fbcnn_color_output')
def test_artifacts_removal1(self):
results = self.module.artifacts_removal(image='tests/test.jpg', quality_factor=None, visualization=False)
self.assertIsInstance(results, np.ndarray)
def test_artifacts_removal2(self):
results = self.module.artifacts_removal(image=cv2.imread('tests/test.jpg'),
quality_factor=None,
visualization=True)
self.assertIsInstance(results, np.ndarray)
def test_artifacts_removal3(self):
results = self.module.artifacts_removal(image=cv2.imread('tests/test.jpg'),
quality_factor=0.5,
visualization=True)
self.assertIsInstance(results, np.ndarray)
def test_artifacts_removal4(self):
self.assertRaises(Exception, self.module.artifacts_removal, image=['tests/test.jpg'])
def test_artifacts_removal5(self):
self.assertRaises(FileNotFoundError, self.module.artifacts_removal, image='no.jpg')
if __name__ == "__main__":
unittest.main()