PaddlePaddle / PaddleHub
Commit 270cc958 (unverified)
Authored Oct 18, 2022 by jm_12138; committed by GitHub on Oct 18, 2022
add fbcnn_color module (#2065)

* add fbcnn_color module
* update example
* fix save name
* fix a cls
Parent: 21545f0c
Showing 4 changed files with 773 additions and 0 deletions (+773 -0)
modules/image/Image_editing/enhancement/fbcnn_color/README.md (+166 -0)
modules/image/Image_editing/enhancement/fbcnn_color/fbcnn.py (+422 -0)
modules/image/Image_editing/enhancement/fbcnn_color/module.py (+126 -0)
modules/image/Image_editing/enhancement/fbcnn_color/test.py (+59 -0)
modules/image/Image_editing/enhancement/fbcnn_color/README.md (new file, 0 → 100644)
# fbcnn_color
|Model name|fbcnn_color|
| :--- | :---: |
|Category|Image - image editing|
|Network|FBCNN|
|Dataset|-|
|Fine-tuning supported|No|
|Model size|288MB|
|Metrics|-|
|Last updated|2022-10-08|
## 1. Basic model information

- ### Demo

    - Network architecture:
        <p align="center">
        <img src="https://ai-studio-static-online.cdn.bcebos.com/08afa15df2e54adeb39587cd7aaa9b60fc82d349bda34f51993d6304776fd374" hspace='10'/> <br />
        </p>

    - Sample results:
        <p align="center">
        <img src="https://ai-studio-static-online.cdn.bcebos.com/f486da7c9d5e4cac8b7ff252b5a4c17633f44f28745c4e489f31e6b78caea005" hspace='10'/>
        </p>

- ### Model introduction

    - FBCNN is a convolutional neural network for removing JPEG compression artifacts. It predicts an adjustable quality factor that controls the trade-off between artifact removal and detail preservation.
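The quality-factor control can be pictured as a per-channel scale and shift applied to a residual branch, which is how the `QFAttention` block in this commit's `fbcnn.py` combines its inputs (`x + gamma * res(x) + beta`). The sketch below is a minimal illustration in plain Python with one value per channel; `residual_branch` is a hypothetical stand-in for the learned convolutions, and the numbers are made up.

```python
def residual_branch(x):
    # Hypothetical stand-in for the learned conv layers: halves each value.
    return [v * 0.5 for v in x]


def qf_modulated_block(x, gamma, beta):
    """Return x + gamma * residual_branch(x) + beta, channel by channel."""
    res = residual_branch(x)
    return [xi + g * r + b for xi, g, r, b in zip(x, gamma, res, beta)]


features = [2.0, 4.0]
# In the network, gamma comes from a sigmoid and beta from a tanh;
# 0.0 and 1.0 are used here as limiting values for gamma.
weak = qf_modulated_block(features, gamma=[0.0, 0.0], beta=[0.0, 0.0])
strong = qf_modulated_block(features, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(weak)    # residual branch fully suppressed
print(strong)  # residual branch fully applied
```

Because `gamma` and `beta` are derived from the quality-factor embedding, changing the quality factor changes how strongly each decoder block corrects its input.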
## 2. Installation

- ### 1. Dependencies

    - paddlepaddle >= 2.0.0

    - paddlehub >= 2.0.0

- ### 2. Installation

    - ```shell
      $ hub install fbcnn_color
      ```

    - If you run into problems during installation, see: [Windows quick start](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Linux quick start](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS quick start](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## 3. Model API prediction

- ### 1. Command-line prediction
```shell
$ hub run fbcnn_color \
--input_path "/PATH/TO/IMAGE" \
--quality_factor -1 \
--output_dir "fbcnn_color_output"
```
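The `--quality_factor -1` in the command above lies outside the documented 0.0 - 1.0 range, so it presumably selects the adaptive path, in which the model estimates the quality factor itself. The sketch below mirrors the three flags with the standard `argparse` module to show how such a negative value is parsed; it is a stand-alone illustration and does not invoke PaddleHub.

```python
import argparse

# Same three flags as the `hub run fbcnn_color` command line.
parser = argparse.ArgumentParser(prog="hub run fbcnn_color")
parser.add_argument("--input_path", type=str, help="Path to image.")
parser.add_argument("--quality_factor", type=float, default=None)
parser.add_argument("--output_dir", type=str, default="fbcnn_color_output")

# "-1" is accepted as a value because no option of this parser looks like a negative number.
args = parser.parse_args(["--input_path", "/PATH/TO/IMAGE", "--quality_factor", "-1"])
print(args.input_path, args.quality_factor, args.output_dir)
```

Omitting `--quality_factor` altogether leaves the value at `None`, which is the documented way to request adaptive behaviour.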
- ### 2. Prediction code example
```python
import paddlehub as hub
import cv2

module = hub.Module(name="fbcnn_color")
result = module.artifacts_removal(
    image=cv2.imread('/PATH/TO/IMAGE'),
    quality_factor=None,
    visualization=True,
    output_dir='fbcnn_color_output'
)
```
- ### 3. API
```python
def artifacts_removal(
    image: Union[str, numpy.ndarray],
    quality_factor: float = None,
    visualization: bool = True,
    output_dir: str = "fbcnn_color_output"
) -> numpy.ndarray
```
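- Artifact removal API

- **Parameters**

    * image (Union\[str, numpy.ndarray\]): image data, either a file path or an ndarray of shape \[H, W, C\] in BGR format;
    * quality_factor (float): custom quality factor (0.0 - 1.0); the default None lets the model estimate it adaptively;
    * visualization (bool): whether to save the processed result as an image file;
    * output\_dir (str): directory in which the results are saved.

- **Returns**

    * res (numpy.ndarray): the image after artifact removal (BGR).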
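One subtlety of the `quality_factor` parameter is that 0.0 is a valid value but is falsy in Python, so a plain truthiness test would silently treat it as "not given". The helper below is a minimal sketch of a selection rule that avoids this; `select_quality_factor` is a hypothetical name, not part of the module's API.

```python
from typing import Optional


def select_quality_factor(quality_factor: Optional[float]) -> Optional[float]:
    """Return the value to feed the network, or None for adaptive prediction."""
    # Compare against None explicitly so that 0.0 is kept as a real value.
    if quality_factor is not None and 0.0 <= quality_factor <= 1.0:
        return float(quality_factor)
    # Missing or out-of-range values fall back to the model's own estimate.
    return None


for value in (None, 0.0, 0.5, 1.0, -1, 1.5):
    print(value, '->', select_quality_factor(value))
```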
## 4. Service deployment

- PaddleHub Serving can deploy an online service for image artifact removal.

- ### Step 1: Start PaddleHub Serving

    - Run the start command:
```shell
$ hub serving start -m fbcnn_color
```
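    - This deploys the artifact removal model as a service API; the default port is 8866.

- ### Step 2: Send a prediction request

    - Once the server is configured, the following few lines of code send a prediction request and retrieve the result: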
```python
import requests
import json
import base64

import cv2
import numpy as np

def cv2_to_base64(image):
    # Encode the image as JPEG, then as a base64 string
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')

def base64_to_cv2(b64str):
    # Decode a base64 string back into a BGR image
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data

# Send the HTTP request
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {
    'image': cv2_to_base64(org_im)
}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/fbcnn_color"
r = requests.post(url=url, headers=headers, data=json.dumps(data))

# Convert the result
results = r.json()['results']
results = base64_to_cv2(results)

# Save the result
cv2.imwrite('output.jpg', results)
```
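The request body is plain JSON carrying the image as a base64 string. The sketch below shows that round trip with the standard library only, using a few placeholder bytes in place of a real JPEG. Including `quality_factor` in the payload is an assumption, based on `serving_method` in this commit forwarding extra keyword arguments to `artifacts_removal`.

```python
import base64
import json


def bytes_to_base64(raw: bytes) -> str:
    return base64.b64encode(raw).decode('utf8')


def base64_to_bytes(b64str: str) -> bytes:
    return base64.b64decode(b64str.encode('utf8'))


# Placeholder for the bytes produced by cv2.imencode('.jpg', image)[1].tobytes().
fake_jpeg = b'\xff\xd8\xff\xe0 not a real image \xff\xd9'

# 'quality_factor' is an assumed optional key (see the note above).
payload = json.dumps({'image': bytes_to_base64(fake_jpeg), 'quality_factor': 0.5})

# What the server side would recover from the request body.
decoded = json.loads(payload)
restored = base64_to_bytes(decoded['image'])
print(restored == fake_jpeg, decoded['quality_factor'])
```

Base64 is needed because JSON cannot carry raw bytes; the encoded string is roughly a third larger than the original image data.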
## 5. References

* Paper: [Towards Flexible Blind JPEG Artifacts Removal](https://arxiv.org/abs/2109.14573)

* Official implementation: [jiaxi-jiang/FBCNN](https://github.com/jiaxi-jiang/FBCNN)
## 6. Release history

* 1.0.0

  Initial release

  ```shell
  $ hub install fbcnn_color==1.0.0
  ```
modules/image/Image_editing/enhancement/fbcnn_color/fbcnn.py (new file, 0 → 100644)

from collections import OrderedDict

import numpy as np
import paddle.nn as nn
'''
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''


def sequential(*args):
    """Advanced nn.Sequential.
    Args:
        nn.Sequential, nn.Layer
    Returns:
        nn.Sequential
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Layer):
            modules.append(module)
    return nn.Sequential(*modules)


# --------------------------------------------
# return nn.Sequential of (Conv + BN + ReLU)
# --------------------------------------------
def conv(in_channels=64,
         out_channels=64,
         kernel_size=3,
         stride=1,
         padding=1,
         bias=True,
         mode='CBR',
         negative_slope=0.2):
    # Each character of `mode` appends one layer, in order.
    L = []
    for t in mode:
        if t == 'C':
            L.append(
                nn.Conv2D(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding,
                          bias_attr=bias))
        elif t == 'T':
            L.append(
                nn.Conv2DTranspose(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding,
                                   bias_attr=bias))
        elif t == 'B':
            # paddle.nn.BatchNorm2D takes `epsilon`; its scale and shift are learnable by default.
            L.append(nn.BatchNorm2D(out_channels, momentum=0.9, epsilon=1e-04))
        elif t == 'I':
            # paddle.nn.InstanceNorm2D has no `affine` flag; scale and shift are learnable by default.
            L.append(nn.InstanceNorm2D(out_channels))
        elif t == 'R':
            L.append(nn.ReLU())
        elif t == 'r':
            L.append(nn.ReLU())
        elif t == 'L':
            L.append(nn.LeakyReLU(negative_slope=negative_slope))
        elif t == 'l':
            L.append(nn.LeakyReLU(negative_slope=negative_slope))
        elif t == '2':
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'v':
            L.append(nn.Upsample(scale_factor=4, mode='nearest'))
        elif t == 'M':
            L.append(nn.MaxPool2D(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == 'A':
            L.append(nn.AvgPool2D(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            raise NotImplementedError('Undefined type: {}'.format(t))
    return sequential(*L)


# --------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# --------------------------------------------
class ResBlock(nn.Layer):

    def __init__(self,
                 in_channels=64,
                 out_channels=64,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 mode='CRC',
                 negative_slope=0.2):
        super(ResBlock, self).__init__()

        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x):
        res = self.res(x)
        return x + res


# --------------------------------------------
# conv + subp (+ relu)
# --------------------------------------------
def upsample_pixelshuffle(in_channels=64,
                          out_channels=3,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=True,
                          mode='2R',
                          negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    up1 = conv(in_channels,
               out_channels * (int(mode[0])**2),
               kernel_size,
               stride,
               padding,
               bias,
               mode='C' + mode,
               negative_slope=negative_slope)
    return up1


# --------------------------------------------
# nearest_upsample + conv (+ R)
# --------------------------------------------
def upsample_upconv(in_channels=64,
                    out_channels=3,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=True,
                    mode='2R',
                    negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
    if mode[0] == '2':
        uc = 'UC'
    elif mode[0] == '3':
        uc = 'uC'
    elif mode[0] == '4':
        uc = 'vC'
    mode = mode.replace(mode[0], uc)
    up1 = conv(in_channels,
               out_channels,
               kernel_size,
               stride,
               padding,
               bias,
               mode=mode,
               negative_slope=negative_slope)
    return up1


# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(in_channels=64,
                           out_channels=3,
                           kernel_size=2,
                           stride=2,
                           padding=0,
                           bias=True,
                           mode='2R',
                           negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], 'T')
    up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
    return up1


'''
# --------------------------------------------
# Downsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# downsample_strideconv
# downsample_maxpool
# downsample_avgpool
# --------------------------------------------
'''


# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(in_channels=64,
                          out_channels=64,
                          kernel_size=2,
                          stride=2,
                          padding=0,
                          bias=True,
                          mode='2R',
                          negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], 'C')
    down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)
    return down1


# --------------------------------------------
# maxpooling + conv (+ relu)
# --------------------------------------------
def downsample_maxpool(in_channels=64,
                       out_channels=64,
                       kernel_size=3,
                       stride=1,
                       padding=0,
                       bias=True,
                       mode='2R',
                       negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], 'MC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
    pool_tail = conv(in_channels,
                     out_channels,
                     kernel_size,
                     stride,
                     padding,
                     bias,
                     mode=mode[1:],
                     negative_slope=negative_slope)
    return sequential(pool, pool_tail)


# --------------------------------------------
# averagepooling + conv (+ relu)
# --------------------------------------------
def downsample_avgpool(in_channels=64,
                       out_channels=64,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       bias=True,
                       mode='2R',
                       negative_slope=0.2):
    assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], 'AC')
    pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope)
    pool_tail = conv(in_channels,
                     out_channels,
                     kernel_size,
                     stride,
                     padding,
                     bias,
                     mode=mode[1:],
                     negative_slope=negative_slope)
    return sequential(pool, pool_tail)


class QFAttention(nn.Layer):

    def __init__(self,
                 in_channels=64,
                 out_channels=64,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 mode='CRC',
                 negative_slope=0.2):
        super(QFAttention, self).__init__()

        assert in_channels == out_channels, 'Only support in_channels==out_channels.'
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope)

    def forward(self, x, gamma, beta):
        # Broadcast the per-channel scale and shift over the spatial dimensions.
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)
        res = (gamma) * self.res(x) + beta
        return x + res


class FBCNN(nn.Layer):

    def __init__(self,
                 in_nc=3,
                 out_nc=3,
                 nc=[64, 128, 256, 512],
                 nb=4,
                 act_mode='R',
                 downsample_mode='strideconv',
                 upsample_mode='convtranspose'):
        super(FBCNN, self).__init__()

        self.m_head = conv(in_nc, nc[0], bias=True, mode='C')
        self.nb = nb
        self.nc = nc
        # downsample
        if downsample_mode == 'avgpool':
            downsample_block = downsample_avgpool
        elif downsample_mode == 'maxpool':
            downsample_block = downsample_maxpool
        elif downsample_mode == 'strideconv':
            downsample_block = downsample_strideconv
        else:
            raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))

        self.m_down1 = sequential(
            *[ResBlock(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
            downsample_block(nc[0], nc[1], bias=True, mode='2'))
        self.m_down2 = sequential(
            *[ResBlock(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
            downsample_block(nc[1], nc[2], bias=True, mode='2'))
        self.m_down3 = sequential(
            *[ResBlock(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
            downsample_block(nc[2], nc[3], bias=True, mode='2'))

        self.m_body_encoder = sequential(
            *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])

        self.m_body_decoder = sequential(
            *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)])

        # upsample
        if upsample_mode == 'upconv':
            upsample_block = upsample_upconv
        elif upsample_mode == 'pixelshuffle':
            upsample_block = upsample_pixelshuffle
        elif upsample_mode == 'convtranspose':
            upsample_block = upsample_convtranspose
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        self.m_up3 = nn.LayerList([
            upsample_block(nc[3], nc[2], bias=True, mode='2'),
            *[QFAttention(nc[2], nc[2], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
        ])

        self.m_up2 = nn.LayerList([
            upsample_block(nc[2], nc[1], bias=True, mode='2'),
            *[QFAttention(nc[1], nc[1], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
        ])

        self.m_up1 = nn.LayerList([
            upsample_block(nc[1], nc[0], bias=True, mode='2'),
            *[QFAttention(nc[0], nc[0], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)]
        ])

        self.m_tail = conv(nc[0], out_nc, bias=True, mode='C')

        # Predicts a quality factor in (0, 1) from the encoder features.
        self.qf_pred = sequential(
            *[ResBlock(nc[3], nc[3], bias=True, mode='C' + act_mode + 'C') for _ in range(nb)],
            nn.AdaptiveAvgPool2D((1, 1)), nn.Flatten(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512),
            nn.ReLU(), nn.Linear(512, 1), nn.Sigmoid())

        # Embeds the (predicted or user-supplied) quality factor.
        self.qf_embed = sequential(nn.Linear(1, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(),
                                   nn.Linear(512, 512), nn.ReLU())

        self.to_gamma_3 = sequential(nn.Linear(512, nc[2]), nn.Sigmoid())
        self.to_beta_3 = sequential(nn.Linear(512, nc[2]), nn.Tanh())
        self.to_gamma_2 = sequential(nn.Linear(512, nc[1]), nn.Sigmoid())
        self.to_beta_2 = sequential(nn.Linear(512, nc[1]), nn.Tanh())
        self.to_gamma_1 = sequential(nn.Linear(512, nc[0]), nn.Sigmoid())
        self.to_beta_1 = sequential(nn.Linear(512, nc[0]), nn.Tanh())

    def forward(self, x, qf_input=None):

        # Pad height and width up to the next multiple of 8.
        h, w = x.shape[-2:]
        paddingBottom = int(np.ceil(h / 8) * 8 - h)
        paddingRight = int(np.ceil(w / 8) * 8 - w)
        x = nn.functional.pad(x, (0, paddingRight, 0, paddingBottom), mode='reflect')

        x1 = self.m_head(x)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body_encoder(x4)
        qf = self.qf_pred(x)
        x = self.m_body_decoder(x)
        # Use the supplied quality factor if given, otherwise the predicted one.
        qf_embedding = self.qf_embed(qf_input) if qf_input is not None else self.qf_embed(qf)
        gamma_3 = self.to_gamma_3(qf_embedding)
        beta_3 = self.to_beta_3(qf_embedding)

        gamma_2 = self.to_gamma_2(qf_embedding)
        beta_2 = self.to_beta_2(qf_embedding)

        gamma_1 = self.to_gamma_1(qf_embedding)
        beta_1 = self.to_beta_1(qf_embedding)

        x = x + x4
        x = self.m_up3[0](x)
        for i in range(self.nb):
            x = self.m_up3[i + 1](x, gamma_3, beta_3)

        x = x + x3

        x = self.m_up2[0](x)
        for i in range(self.nb):
            x = self.m_up2[i + 1](x, gamma_2, beta_2)
        x = x + x2

        x = self.m_up1[0](x)
        for i in range(self.nb):
            x = self.m_up1[i + 1](x, gamma_1, beta_1)

        x = x + x1
        x = self.m_tail(x)
        # Crop the padding away again.
        x = x[..., :h, :w]

        return x, qf
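
The `forward` method pads the input so that height and width are multiples of 8, which matches the three stride-2 downsampling stages (2 × 2 × 2 = 8), and crops the output back to the original size. The arithmetic can be sketched on its own in plain Python; `pad_to_multiple` is a hypothetical helper name used only for this illustration.

```python
import math


def pad_to_multiple(size: int, multiple: int = 8) -> int:
    """Return how many rows or columns to append so that `size` becomes a multiple."""
    return math.ceil(size / multiple) * multiple - size


for h, w in [(64, 64), (65, 100), (250, 333)]:
    pad_bottom, pad_right = pad_to_multiple(h), pad_to_multiple(w)
    padded = (h + pad_bottom, w + pad_right)
    # After the network runs, slicing [:h, :w] restores the original size.
    print((h, w), 'padded to', padded)
```

Sizes that are already multiples of 8 need no padding, so the pad call in `forward` leaves such inputs unchanged.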
modules/image/Image_editing/enhancement/fbcnn_color/module.py (new file, 0 → 100644)

import argparse
import base64
import os
import time
from typing import Union

import cv2
import numpy as np
import paddle
import paddle.nn as nn

from .fbcnn import FBCNN
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving


def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')


def base64_to_cv2(b64str):
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data


@moduleinfo(
    name='fbcnn_color',
    version='1.0.0',
    type="CV/image_editing",
    author="",
    author_email="",
    summary="Flexible JPEG Artifacts Removal.",
)
class FBCNNColor(nn.Layer):

    def __init__(self):
        super(FBCNNColor, self).__init__()
        self.default_pretrained_model_path = os.path.join(self.directory, 'ckpts', 'fbcnn_color.pdparams')
        self.fbcnn = FBCNN()
        state_dict = paddle.load(self.default_pretrained_model_path)
        self.fbcnn.set_state_dict(state_dict)
        self.fbcnn.eval()

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        # BGR HWC uint8 -> RGB CHW float32 in [0, 1]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.transpose((2, 0, 1))
        img = img / 255.0
        return img.astype(np.float32)

    def postprocess(self, img: np.ndarray) -> np.ndarray:
        # RGB CHW float in [0, 1] -> BGR HWC uint8
        img = img.clip(0, 1)
        img = img * 255.0
        img = img.transpose((1, 2, 0))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        return img.astype(np.uint8)

    def artifacts_removal(self,
                          image: Union[str, np.ndarray],
                          quality_factor: float = None,
                          visualization: bool = True,
                          output_dir: str = "fbcnn_color_output") -> np.ndarray:
        if isinstance(image, str):
            _, file_name = os.path.split(image)
            save_name, _ = os.path.splitext(file_name)
            save_name = save_name + '_' + str(int(time.time())) + '.jpg'
            image = cv2.imdecode(np.fromfile(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        elif isinstance(image, np.ndarray):
            save_name = str(int(time.time())) + '.jpg'
            image = image
        else:
            raise Exception("image should be a str / np.ndarray")

        with paddle.no_grad():
            img_input = self.preprocess(image)
            img_input = paddle.to_tensor(img_input[None, ...], dtype=paddle.float32)
            # Check against None explicitly so that a quality factor of 0.0 is honoured.
            if quality_factor is not None and 0 <= quality_factor <= 1:
                qf_input = paddle.to_tensor([[quality_factor]], dtype=paddle.float32)
            else:
                qf_input = None
            img_output, _ = self.fbcnn(img_input, qf_input)
            img_output = img_output.numpy()[0]
            img_output = self.postprocess(img_output)

        if visualization:
            if not os.path.isdir(output_dir):
                os.makedirs(output_dir)
            save_path = os.path.join(output_dir, save_name)
            cv2.imwrite(save_path, img_output)

        return img_output

    @runnable
    def run_cmd(self, argvs):
        """
        Run as a command.
        """
        self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
                                              prog='hub run {}'.format(self.name),
                                              usage='%(prog)s',
                                              add_help=True)
        self.parser.add_argument('--input_path', type=str, help="Path to image.")
        self.parser.add_argument('--quality_factor',
                                 type=float,
                                 default=None,
                                 help="Image quality factor (0.0-1.0).")
        self.parser.add_argument('--output_dir',
                                 type=str,
                                 default='fbcnn_color_output',
                                 help="The directory to save output images.")
        args = self.parser.parse_args(argvs)
        self.artifacts_removal(image=args.input_path,
                               quality_factor=args.quality_factor,
                               visualization=True,
                               output_dir=args.output_dir)
        return 'Artifacts removal results are saved in %s' % args.output_dir

    @serving
    def serving_method(self, image, **kwargs):
        """
        Run as a service.
        """
        image = base64_to_cv2(image)
        img_output = self.artifacts_removal(image=image, **kwargs)
        return cv2_to_base64(img_output)

modules/image/Image_editing/enhancement/fbcnn_color/test.py (new file, 0 → 100644)

import os
import shutil
import unittest

import cv2
import numpy as np
import requests

import paddlehub as hub

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class TestHubModule(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        img_url = 'https://unsplash.com/photos/mJaD10XeD7w/download?ixid=MnwxMjA3fDB8MXxzZWFyY2h8M3x8Y2F0fGVufDB8fHx8MTY2MzczNDc3Mw&force=true&w=640'
        if not os.path.exists('tests'):
            os.makedirs('tests')
        response = requests.get(img_url)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="fbcnn_color")

    @classmethod
    def tearDownClass(cls) -> None:
        shutil.rmtree('tests')
        shutil.rmtree('fbcnn_color_output')

    def test_artifacts_removal1(self):
        results = self.module.artifacts_removal(image='tests/test.jpg', quality_factor=None, visualization=False)

        self.assertIsInstance(results, np.ndarray)

    def test_artifacts_removal2(self):
        results = self.module.artifacts_removal(image=cv2.imread('tests/test.jpg'),
                                                quality_factor=None,
                                                visualization=True)

        self.assertIsInstance(results, np.ndarray)

    def test_artifacts_removal3(self):
        results = self.module.artifacts_removal(image=cv2.imread('tests/test.jpg'),
                                                quality_factor=0.5,
                                                visualization=True)

        self.assertIsInstance(results, np.ndarray)

    def test_artifacts_removal4(self):
        self.assertRaises(Exception, self.module.artifacts_removal, image=['tests/test.jpg'])

    def test_artifacts_removal5(self):
        self.assertRaises(FileNotFoundError, self.module.artifacts_removal, image='no.jpg')


if __name__ == "__main__":
    unittest.main()