PaddlePaddle / PaddleHub
Commit 6112bd38
Authored April 06, 2021 by haoyuying; committed by GitHub on April 06, 2021
Add deeplabv3 and hrnetw18.
Parent: 5b108603
Showing 17 changed files with 2822 additions and 8 deletions (+2822 -8)
demo/semantic_segmentation/N0007.jpg  +0 -0
demo/semantic_segmentation/README.md  +164 -0
demo/semantic_segmentation/predict.py  +6 -0
demo/semantic_segmentation/train.py  +16 -0
docs/docs_ch/reference/datasets.md  +15 -0
modules/image/semantic_segmentation/deeplabv3p_resnet50_voc/layers.py  +345 -0
modules/image/semantic_segmentation/deeplabv3p_resnet50_voc/module.py  +186 -0
modules/image/semantic_segmentation/deeplabv3p_resnet50_voc/resnet.py  +137 -0
modules/image/semantic_segmentation/ocrnet_hrnetw18_voc/hrnet.py  +612 -0
modules/image/semantic_segmentation/ocrnet_hrnetw18_voc/layers.py  +345 -0
modules/image/semantic_segmentation/ocrnet_hrnetw18_voc/module.py  +243 -0
paddlehub/datasets/__init__.py  +2 -0
paddlehub/datasets/base_seg_dataset.py  +141 -0
paddlehub/datasets/opticdiscseg.py  +78 -0
paddlehub/module/cv_module.py  +111 -2
paddlehub/vision/segmentation_transforms.py  +307 -0
paddlehub/vision/utils.py  +114 -6
demo/semantic_segmentation/N0007.jpg
0 → 100644
View file @
6112bd38
41.2 KB
demo/semantic_segmentation/README.md
0 → 100644
View file @ 6112bd38
# PaddleHub Image Segmentation
This example shows how to use PaddleHub to fine-tune a pretrained model and run prediction with it.
## How to Start Fine-tuning
After installing PaddlePaddle and PaddleHub, run `python train.py` to start fine-tuning the ocrnet_hrnetw18_voc model on datasets such as OpticDiscSeg.
## Code Steps
Fine-tuning with the PaddleHub Fine-tune API takes four steps.
### Step1: Define the data preprocessing
```python
from paddlehub.vision.segmentation_transforms import Compose, Resize, Normalize

transform = Compose([Resize(target_size=(512, 512)), Normalize()])
```
The `segmentation_transforms` data augmentation module defines a rich set of preprocessing operations for image segmentation data; swap in whichever ones your task needs.
### Step2: Download and use the dataset
```python
from paddlehub.datasets import OpticDiscSeg

train_reader = OpticDiscSeg(transform, mode='train')
```
* `transform`: the data preprocessing pipeline.
* `mode`: the dataset split to load; one of `train`, `test`, `val`. Defaults to `train`.

The dataset preparation code is in [opticdiscseg.py](../../paddlehub/datasets/opticdiscseg.py). `hub.datasets.OpticDiscSeg()` automatically downloads the dataset and extracts it to `$HOME/.paddlehub/dataset` under the user's home directory.
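### Step3: Load the pretrained model
```python
import paddlehub as hub

model = hub.Module(name='ocrnet_hrnetw18_voc', num_classes=2, pretrained=None)
```
* `name`: the name of the pretrained model.
* `num_classes`: the number of classes the segmentation model predicts.
* `pretrained`: path to parameters you trained yourself; if `None`, the default parameters shipped with the module are loaded.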
### Step4: Choose the optimization strategy and run configuration
```python
import paddle
from paddlehub.finetune.trainer import Trainer

scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=1000, power=0.9, end_lr=0.0001)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_ocr', use_gpu=True)
```
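To make the effect of the scheduler in Step4 concrete, the sketch below re-implements polynomial decay in plain Python. It assumes the non-cycling form `lr = (learning_rate - end_lr) * (1 - min(step, decay_steps) / decay_steps) ** power + end_lr`; the helper name `polynomial_decay` is made up for illustration and is not part of Paddle.

```python
def polynomial_decay(step, learning_rate=0.01, decay_steps=1000, power=0.9, end_lr=0.0001):
    """Hypothetical helper: polynomial decay without cycling (assumed formula)."""
    # Once decay_steps is reached, the rate stays at end_lr.
    step = min(step, decay_steps)
    return (learning_rate - end_lr) * (1 - step / decay_steps) ** power + end_lr


# Print the learning rate at a few steps, using the values from Step4.
for step in (0, 250, 500, 750, 1000, 1500):
    print(step, round(polynomial_decay(step), 6))
```

With the Step4 values, the rate starts at 0.01, falls monotonically, and stays at `end_lr` for every step from `decay_steps` onward.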
#### Optimization strategy
Paddle 2.0rc provides a range of optimizers, such as `SGD`, `Adam`, and `Adamax`; see the [optimizer documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc/api/paddle/optimizer/optimizer/Optimizer_cn.html) for details.
For `Adam`:
* `learning_rate`: the global learning rate.
* `parameters`: the model parameters to optimize.
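#### Run configuration
`Trainer` controls the fine-tuning run and accepts the following parameters:
* `model`: the model to optimize;
* `optimizer`: the optimizer to use;
* `use_gpu`: whether to use the GPU; defaults to False;
* `use_vdl`: whether to visualize the training process with VisualDL;
* `checkpoint_dir`: the directory where model parameters are saved;
* `compare_metrics`: the metric used to decide which model is the best one to save.

`trainer.train` controls the training loop itself and accepts the following parameters:
* `train_dataset`: the dataset used for training;
* `epochs`: the number of training epochs;
* `batch_size`: the training batch size; when using a GPU, adjust it to fit the available memory;
* `num_workers`: the number of data-loading workers; defaults to 0;
* `eval_dataset`: the validation dataset;
* `log_interval`: how often to print logs, measured in training batches;
* `save_interval`: how often to save the model, measured in training epochs.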
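The units of `log_interval` (batches) and `save_interval` (epochs) are easy to mix up, so here is a small arithmetic sketch using the values from train.py in this commit. It assumes the 267-image training split quoted in docs/docs_ch/reference/datasets.md, that the final partial batch of each epoch is kept, and that logging and saving fire on exact multiples of their intervals; the exact counting inside `Trainer` is not confirmed here.

```python
import math

num_samples = 267    # training-set size quoted in docs/docs_ch/reference/datasets.md
batch_size = 4
epochs = 20
log_interval = 10    # unit: training batches (steps)
save_interval = 4    # unit: epochs

# Assumption: the last, smaller batch of each epoch is kept rather than dropped.
steps_per_epoch = math.ceil(num_samples / batch_size)
total_steps = steps_per_epoch * epochs

# Assumption: one log line every `log_interval` steps, one checkpoint every `save_interval` epochs.
log_lines = total_steps // log_interval
save_epochs = [e for e in range(1, epochs + 1) if e % save_interval == 0]

print(steps_per_epoch, total_steps, log_lines, save_epochs)
```

If the scheduler is stepped once per batch, a run of this length goes past the `decay_steps=1000` set in Step4, so the learning rate would sit at `end_lr` for the final part of training.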
## Model Prediction
After fine-tuning finishes, the model that performed best on the validation set is saved under `${CHECKPOINT_DIR}/best_model`, where `${CHECKPOINT_DIR}` is the checkpoint directory chosen for fine-tuning.
We use that model for prediction. The predict.py script is as follows:
```python
import paddle
import cv2
import paddlehub as hub

if __name__ == '__main__':
    model = hub.Module(name='ocrnet_hrnetw18_voc', pretrained='/PATH/TO/CHECKPOINT')
    img = cv2.imread("/PATH/TO/IMAGE")
    model.predict(images=[img], visualization=True)
```
Once the parameters are set correctly, run the script with `python predict.py`.
**Args**
* `images`: paths to the original images, or images in BGR format;
* `visualization`: whether to save a visualization; defaults to True;
* `save_path`: where to save the results; defaults to 'seg_result'.

**NOTE:** For prediction, the module, checkpoint_dir, and dataset must be the same as those used for fine-tuning.
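## Service Deployment
PaddleHub Serving can deploy an online image segmentation service.
### Step1: Start PaddleHub Serving
Run the start command:
```shell
$ hub serving start -m ocrnet_hrnetw18_voc
```
This deploys an image segmentation service API; the default port is 8866.
**NOTE:** To predict on a GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise no setting is needed.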
### Step2: Send a prediction request
With the server running, the following few lines send a prediction request and retrieve the result:
```python
import requests
import json
import cv2
import base64

import numpy as np


def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    # tobytes() replaces the deprecated ndarray.tostring()
    return base64.b64encode(data.tobytes()).decode('utf8')


def base64_to_cv2(b64str):
    data = base64.b64decode(b64str.encode('utf8'))
    # np.frombuffer replaces the deprecated binary mode of np.fromstring
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data


# Send the HTTP request
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {'images': [cv2_to_base64(org_im)]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ocrnet_hrnetw18_voc"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
mask = base64_to_cv2(r.json()["results"][0])
```
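The request body is a JSON object whose `images` key holds base64 strings. The sketch below walks through that encode/decode round trip using only the standard library, with a few stand-in bytes in place of a real JPEG so it runs without OpenCV or a running server. The helper names `bytes_to_base64` and `base64_to_bytes` are hypothetical.

```python
import base64
import json


def bytes_to_base64(raw: bytes) -> str:
    """Hypothetical helper: raw bytes -> base64 text that is safe to embed in JSON."""
    return base64.b64encode(raw).decode('utf8')


def base64_to_bytes(b64str: str) -> bytes:
    """Hypothetical helper: inverse of bytes_to_base64."""
    return base64.b64decode(b64str.encode('utf8'))


# Stand-in payload; a real client would use the bytes produced by cv2.imencode.
fake_jpeg = bytes([0xFF, 0xD8, 0xFF, 0xE0]) + b'stand-in image payload'
payload = json.dumps({'images': [bytes_to_base64(fake_jpeg)]})

# What a receiver would do with the request body.
received = json.loads(payload)
restored = base64_to_bytes(received['images'][0])
print(restored == fake_jpeg)
```

Base64 is needed because JSON cannot carry raw binary data; the encoded string is longer than the original bytes but survives serialization unchanged.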
### Source Code
https://github.com/PaddlePaddle/PaddleSeg
### Dependencies
paddlepaddle >= 2.0.0rc
paddlehub >= 2.0.0
demo/semantic_segmentation/predict.py
0 → 100644
View file @ 6112bd38
import paddle
import paddlehub as hub

if __name__ == '__main__':
    model = hub.Module(name='ocrnet_hrnetw18_voc', num_classes=2, pretrained='/PATH/TO/CHECKPOINT')
    model.predict(images=["N0007.jpg"], visualization=True)
\ No newline at end of file
demo/semantic_segmentation/train.py
0 → 100644
View file @ 6112bd38
import paddle
import paddlehub as hub
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets import OpticDiscSeg
from paddlehub.vision.segmentation_transforms import Compose, Resize, Normalize

if __name__ == "__main__":
    transform = Compose([Resize(target_size=(512, 512)), Normalize()])
    train_reader = OpticDiscSeg(transform)

    model = hub.Module(name='ocrnet_hrnetw18_voc', num_classes=2)
    scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=1000, power=0.9, end_lr=0.0001)
    optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
    trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_ocr', use_gpu=True)
    trainer.train(train_reader, epochs=20, batch_size=4, eval_dataset=train_reader, log_interval=10, save_interval=4)
\ No newline at end of file
docs/docs_ch/reference/datasets.md
View file @ 6112bd38
...
@@ -39,3 +39,18 @@ Dataset for Style transfer. The dataset contains 2001 images for training set an
**Args**
* transforms(callmethod) : The method of preprocess images.
* mode(str): The mode for preparing dataset.

# Class `hub.datasets.OpticDiscSeg`

```python
hub.datasets.OpticDiscSeg(transforms: Callable, mode: str = 'train')
```

Dataset for semantic segmentation. The dataset contains 267 images for training set, 76 images for validation set and 38 images for testing set.

**Args**
* transforms(callmethod) : The method of preprocess images.
* mode(str): The mode for preparing dataset.
modules/image/semantic_segmentation/deeplabv3p_resnet50_voc/layers.py
0 → 100644
View file @
6112bd38
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.layer import activation
from paddle.nn import Conv2D, AvgPool2D


def SyncBatchNorm(*args, **kwargs):
    """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead"""
    if paddle.get_device() == 'cpu':
        return nn.BatchNorm2D(*args, **kwargs)
    else:
        return nn.SyncBatchNorm(*args, **kwargs)


class ConvBNLayer(nn.Layer):
    """Basic conv bn relu layer."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 is_vd_mode: bool = False,
                 act: str = None,
                 name: str = None):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = AvgPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self._conv = Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
            dilation=dilation,
            groups=groups,
            bias_attr=False)

        self._batch_norm = SyncBatchNorm(out_channels)
        self._act_op = Activation(act=act)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        if self.is_vd_mode:
            inputs = self._pool2d_avg(inputs)
        y = self._conv(inputs)
        y = self._batch_norm(y)
        y = self._act_op(y)

        return y


class BottleneckBlock(nn.Layer):
    """Residual bottleneck block"""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 shortcut: bool = True,
                 if_first: bool = False,
                 dilation: int = 1,
                 name: str = None):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, act='relu', name=name + "_branch2a")

        self.dilation = dilation

        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            dilation=dilation,
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            in_channels=out_channels, out_channels=out_channels * 4, kernel_size=1, act=None, name=name + "_branch2c")

        if not shortcut:
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first or stride == 1 else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        y = self.conv0(inputs)
        if self.dilation > 1:
            padding = self.dilation
            y = F.pad(y, [padding, padding, padding, padding])

        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y


class SeparableConvBNReLU(nn.Layer):
    """Depthwise Separable Convolution."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
        super(SeparableConvBNReLU, self).__init__()
        self.depthwise_conv = ConvBN(
            in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
            **kwargs)
        self.piontwise_conv = ConvBNReLU(in_channels, out_channels, kernel_size=1, groups=1)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self.depthwise_conv(x)
        x = self.piontwise_conv(x)
        return x


class ConvBN(nn.Layer):
    """Basic conv bn layer"""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
        super(ConvBN, self).__init__()
        self._conv = Conv2D(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        self._batch_norm = SyncBatchNorm(out_channels)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self._conv(x)
        x = self._batch_norm(x)
        return x


class ConvBNReLU(nn.Layer):
    """Basic conv bn relu layer."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
        super(ConvBNReLU, self).__init__()

        self._conv = Conv2D(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        self._batch_norm = SyncBatchNorm(out_channels)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self._conv(x)
        x = self._batch_norm(x)
        x = F.relu(x)
        return x


class Activation(nn.Layer):
    """
    The wrapper of activations.

    Args:
        act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu',
            'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid',
            'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax',
            'hsigmoid']. Default: None, means identical transformation.

    Returns:
        A callable object of Activation.

    Raises:
        KeyError: When parameter `act` is not in the optional range.

    Examples:

        from paddleseg.models.common.activation import Activation

        relu = Activation("relu")
        print(relu)
        # <class 'paddle.nn.layer.activation.ReLU'>

        sigmoid = Activation("sigmoid")
        print(sigmoid)
        # <class 'paddle.nn.layer.activation.Sigmoid'>

        not_exit_one = Activation("not_exit_one")
        # KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
        # 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
        # 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
    """

    def __init__(self, act: str = None):
        super(Activation, self).__init__()

        self._act = act
        upper_act_names = activation.__all__
        lower_act_names = [act.lower() for act in upper_act_names]
        act_dict = dict(zip(lower_act_names, upper_act_names))

        if act is not None:
            if act in act_dict.keys():
                act_name = act_dict[act]
                self.act_func = eval("activation.{}()".format(act_name))
            else:
                raise KeyError("{} does not exist in the current {}".format(act, act_dict.keys()))

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:

        if self._act is not None:
            return self.act_func(x)
        else:
            return x


class ASPPModule(nn.Layer):
    """
    Atrous Spatial Pyramid Pooling.

    Args:
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
        use_sep_conv (bool, optional): If using separable conv in ASPP module. Default: False.
        image_pooling (bool, optional): If augmented with image-level features. Default: False
    """

    def __init__(self,
                 aspp_ratios: tuple,
                 in_channels: int,
                 out_channels: int,
                 align_corners: bool,
                 use_sep_conv: bool = False,
                 image_pooling: bool = False):
        super().__init__()

        self.align_corners = align_corners
        self.aspp_blocks = nn.LayerList()

        for ratio in aspp_ratios:
            if use_sep_conv and ratio > 1:
                conv_func = SeparableConvBNReLU
            else:
                conv_func = ConvBNReLU

            block = conv_func(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1 if ratio == 1 else 3,
                dilation=ratio,
                padding=0 if ratio == 1 else ratio)
            self.aspp_blocks.append(block)

        out_size = len(self.aspp_blocks)

        if image_pooling:
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2D(output_size=(1, 1)),
                ConvBNReLU(in_channels, out_channels, kernel_size=1, bias_attr=False))
            out_size += 1
        self.image_pooling = image_pooling

        self.conv_bn_relu = ConvBNReLU(in_channels=out_channels * out_size, out_channels=out_channels, kernel_size=1)

        self.dropout = nn.Dropout(p=0.1)  # drop rate

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        outputs = []
        for block in self.aspp_blocks:
            y = block(x)
            y = F.interpolate(y, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
            outputs.append(y)

        if self.image_pooling:
            img_avg = self.global_avg_pool(x)
            img_avg = F.interpolate(img_avg, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
            outputs.append(img_avg)

        x = paddle.concat(outputs, axis=1)
        x = self.conv_bn_relu(x)
        x = self.dropout(x)

        return x
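
Review note on `Activation` in layers.py: the wrapper maps lowercase names to the class names listed in `activation.__all__` and instantiates the match through `eval`. The sketch below shows the same lookup done with `getattr` instead. It is only an illustration: `fake_activation` is a stand-in namespace so the example runs without Paddle, and `build_activation` is a hypothetical helper, not part of this commit.

```python
import math


class fake_activation:
    """Stand-in for paddle.nn.layer.activation (assumption, for illustration only)."""
    __all__ = ['ReLU', 'Sigmoid']

    class ReLU:
        def __call__(self, x):
            return max(x, 0.0)

    class Sigmoid:
        def __call__(self, x):
            return 1.0 / (1.0 + math.exp(-x))


def build_activation(act=None, namespace=fake_activation):
    """Hypothetical helper: resolve a lowercase activation name without eval."""
    if act is None:
        # Same convention as the wrapper: None means the identity transformation.
        return lambda x: x
    act_dict = {name.lower(): name for name in namespace.__all__}
    if act not in act_dict:
        raise KeyError("{} does not exist in the current {}".format(act, list(act_dict)))
    return getattr(namespace, act_dict[act])()


print(build_activation('relu')(-2.0), build_activation('sigmoid')(0.0), build_activation(None)(5))
```

`getattr` reaches the same class object as the `eval` call while avoiding evaluation of a formatted string.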
modules/image/semantic_segmentation/deeplabv3p_resnet50_voc/module.py
0 → 100644
View file @
6112bd38
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union, List, Tuple

import paddle
from paddle import nn
import paddle.nn.functional as F
import numpy as np
from paddlehub.module.module import moduleinfo
import paddlehub.vision.segmentation_transforms as T
from paddlehub.module.cv_module import ImageSegmentationModule

from deeplabv3p_resnet50_voc.resnet import ResNet50_vd
import deeplabv3p_resnet50_voc.layers as L


@moduleinfo(
    name="deeplabv3p_resnet50_voc",
    type="CV/semantic_segmentation",
    author="paddlepaddle",
    author_email="",
    summary="DeepLabV3PResnet50 is a segmentation model.",
    version="1.0.0",
    meta=ImageSegmentationModule)
class DeepLabV3PResnet50(nn.Layer):
    """
    The DeepLabV3PResnet50 implementation based on PaddlePaddle.

    The original article refers to
    Liang-Chieh Chen, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"
    (https://arxiv.org/abs/1802.02611)

    Args:
        num_classes (int): the unique number of target classes.
        backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
            the first index will be taken as a low-level feature in Decoder component;
            the second one will be taken as input of ASPP component.
            Usually backbone consists of four downsampling stage, and return an output of
            each stage, so we set default (0, 3), which means taking feature map of the first
            stage in backbone as low-level feature used in Decoder, and feature map of the fourth
            stage as input of ASPP.
        aspp_ratios (tuple): the dilation rates used in the ASPP module.
            if output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
            if output_stride=8, aspp_ratios is (1, 12, 24, 36).
        aspp_out_channels (int): the output channels of ASPP module.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
        pretrained (str): the path of pretrained model. Default to None.
    """

    def __init__(self,
                 num_classes: int = 21,
                 backbone_indices: Tuple[int] = (0, 3),
                 aspp_ratios: Tuple[int] = (1, 12, 24, 36),
                 aspp_out_channels: int = 256,
                 align_corners=False,
                 pretrained: str = None):
        super(DeepLabV3PResnet50, self).__init__()
        self.backbone = ResNet50_vd()
        backbone_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
        self.head = DeepLabV3PHead(num_classes, backbone_indices, backbone_channels, aspp_ratios, aspp_out_channels,
                                   align_corners)
        self.align_corners = align_corners
        self.transforms = T.Compose([T.Padding(target_size=(512, 512)), T.Normalize()])

        if pretrained is not None:
            model_dict = paddle.load(pretrained)
            self.set_dict(model_dict)
            print("load custom parameters success")

        else:
            checkpoint = os.path.join(self.directory, 'deeplabv3p_model.pdparams')
            model_dict = paddle.load(checkpoint)
            self.set_dict(model_dict)
            print("load pretrained parameters success")

    def transform(self, img: Union[np.ndarray, str]) -> Union[np.ndarray, str]:
        return self.transforms(img)

    def forward(self, x: paddle.Tensor) -> List[paddle.Tensor]:
        feat_list = self.backbone(x)
        logit_list = self.head(feat_list)
        return [
            F.interpolate(logit, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
            for logit in logit_list
        ]


class DeepLabV3PHead(nn.Layer):
    """
    The DeepLabV3PHead implementation based on PaddlePaddle.

    Args:
        num_classes (int): The unique number of target classes.
        backbone_indices (tuple): Two values in the tuple indicate the indices of output of backbone.
            the first index will be taken as a low-level feature in Decoder component;
            the second one will be taken as input of ASPP component.
            Usually backbone consists of four downsampling stage, and return an output of
            each stage. If we set it as (0, 3), it means taking feature map of the first
            stage in backbone as low-level feature used in Decoder, and feature map of the fourth
            stage as input of ASPP.
        backbone_channels (tuple): The same length with "backbone_indices". It indicates the channels of corresponding index.
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        aspp_out_channels (int): The output channels of ASPP module.
        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
    """

    def __init__(self, num_classes: int, backbone_indices: Tuple[paddle.Tensor],
                 backbone_channels: Tuple[paddle.Tensor], aspp_ratios: Tuple[float], aspp_out_channels: int,
                 align_corners: bool):
        super().__init__()

        self.aspp = L.ASPPModule(
            aspp_ratios, backbone_channels[1], aspp_out_channels, align_corners, use_sep_conv=True, image_pooling=True)
        self.decoder = Decoder(num_classes, backbone_channels[0], align_corners)
        self.backbone_indices = backbone_indices

    def forward(self, feat_list: List[paddle.Tensor]) -> List[paddle.Tensor]:
        logit_list = []
        low_level_feat = feat_list[self.backbone_indices[0]]
        x = feat_list[self.backbone_indices[1]]
        x = self.aspp(x)
        logit = self.decoder(x, low_level_feat)
        logit_list.append(logit)
        return logit_list


class Decoder(nn.Layer):
    """
    Decoder module of DeepLabV3P model

    Args:
        num_classes (int): The number of classes.
        in_channels (int): The number of input channels in decoder module.
        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
    """

    def __init__(self, num_classes: int, in_channels: int, align_corners: bool):
        super(Decoder, self).__init__()

        self.conv_bn_relu1 = L.ConvBNReLU(in_channels=in_channels, out_channels=48, kernel_size=1)

        self.conv_bn_relu2 = L.SeparableConvBNReLU(in_channels=304, out_channels=256, kernel_size=3, padding=1)
        self.conv_bn_relu3 = L.SeparableConvBNReLU(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv = nn.Conv2D(in_channels=256, out_channels=num_classes, kernel_size=1)

        self.align_corners = align_corners

    def forward(self, x: paddle.Tensor, low_level_feat: paddle.Tensor) -> paddle.Tensor:
        low_level_feat = self.conv_bn_relu1(low_level_feat)
        x = F.interpolate(x, low_level_feat.shape[2:], mode='bilinear', align_corners=self.align_corners)
        x = paddle.concat([x, low_level_feat], axis=1)
        x = self.conv_bn_relu2(x)
        x = self.conv_bn_relu3(x)
        x = self.conv(x)
        return x
\ No newline at end of file
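
Review note on module.py: the hard-coded `in_channels=304` in `Decoder` looks like the sum of the ASPP output channels and the 48-channel reduced low-level feature. The arithmetic sketch below checks that reading against the module defaults. It is plain Python with values copied from the files above, and the variable names are chosen for illustration.

```python
# Values copied from module.py / resnet.py defaults in this commit.
num_filters = (64, 128, 256, 512)
backbone_indices = (0, 3)
aspp_ratios = (1, 12, 24, 36)
aspp_out_channels = 256
image_pooling = True       # DeepLabV3PHead passes image_pooling=True
low_level_reduced = 48     # out_channels of Decoder.conv_bn_relu1

# ResNet50_vd.feat_channels, then the two entries picked by backbone_indices.
feat_channels = [c * 4 for c in num_filters]
backbone_channels = [feat_channels[i] for i in backbone_indices]

# ASPP concatenates one branch per ratio, plus the image-pooling branch.
num_branches = len(aspp_ratios) + (1 if image_pooling else 0)
aspp_concat_channels = aspp_out_channels * num_branches

# The decoder concatenates the ASPP output with the reduced low-level feature.
decoder_concat_channels = aspp_out_channels + low_level_reduced

print(backbone_channels, aspp_concat_channels, decoder_concat_channels)
```

If that reading is right, passing an `aspp_out_channels` other than 256 would leave the decoder's fixed 304 input width out of step with the concatenated tensor.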
modules/image/semantic_segmentation/deeplabv3p_resnet50_voc/resnet.py
0 → 100644
View file @
6112bd38
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

import deeplabv3p_resnet50_voc.layers as L


class BasicBlock(nn.Layer):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 shortcut: bool = True,
                 if_first: bool = False,
                 name: str = None):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = L.ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = L.ConvBNLayer(
            in_channels=out_channels, out_channels=out_channels, kernel_size=3, act=None, name=name + "_branch2b")

        if not shortcut:
            self.short = L.ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        y = self.conv0(inputs)
        conv1 = self.conv1(y)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        # The commit called paddle.elementwise_add(x=short, y=conv1, act='relu');
        # add + relu is the equivalent form and matches BottleneckBlock in layers.py.
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)

        return y


class ResNet50_vd(nn.Layer):
    def __init__(self, multi_grid: tuple = (1, 2, 4)):
        super(ResNet50_vd, self).__init__()
        depth = [3, 4, 6, 3]
        num_channels = [64, 256, 512, 1024]
        num_filters = [64, 128, 256, 512]
        self.feat_channels = [c * 4 for c in num_filters]
        dilation_dict = {2: 2, 3: 4}
        self.conv1_1 = L.ConvBNLayer(
            in_channels=3, out_channels=32, kernel_size=3, stride=2, act='relu', name="conv1_1")
        self.conv1_2 = L.ConvBNLayer(
            in_channels=32, out_channels=32, kernel_size=3, stride=1, act='relu', name="conv1_2")
        self.conv1_3 = L.ConvBNLayer(
            in_channels=32, out_channels=64, kernel_size=3, stride=1, act='relu', name="conv1_3")
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
        self.stage_list = []

        for block in range(len(depth)):
            shortcut = False
            block_list = []
            for i in range(depth[block]):
                conv_name = "res" + str(block + 2) + chr(97 + i)
                dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
                if block == 3:
                    dilation_rate = dilation_rate * multi_grid[i]
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    L.BottleneckBlock(
                        in_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
                        out_channels=num_filters[block],
                        stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1,
                        shortcut=shortcut,
                        if_first=block == i == 0,
                        name=conv_name,
                        dilation=dilation_rate))
                block_list.append(bottleneck_block)
                shortcut = True
            self.stage_list.append(block_list)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        y = self.conv1_1(inputs)
        y = self.conv1_2(y)
        y = self.conv1_3(y)
        y = self.pool2d_max(y)
        feat_list = []
        for stage in self.stage_list:
            for block in stage:
                y = block(y)
            feat_list.append(y)
        return feat_list
\ No newline at end of file
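
Review note on resnet.py: the nested loop in `ResNet50_vd.__init__` decides each bottleneck's stride and dilation from `dilation_dict` and `multi_grid`. The sketch below replays only that bookkeeping in plain Python, with no Paddle layers built, to show the resulting plan and the overall output stride. The names `plan` and `output_stride` are introduced for illustration.

```python
depth = [3, 4, 6, 3]
dilation_dict = {2: 2, 3: 4}
multi_grid = (1, 2, 4)

# The stem downsamples twice: conv1_1 (stride 2) and pool2d_max (stride 2).
output_stride = 2 * 2
plan = []
for block in range(len(depth)):
    stage = []
    for i in range(depth[block]):
        dilation_rate = dilation_dict.get(block, 1)
        if block == 3:
            dilation_rate *= multi_grid[i]
        # Same condition as in the constructor: only the first block of an
        # undilated stage after stage 0 downsamples.
        stride = 2 if i == 0 and block != 0 and dilation_rate == 1 else 1
        output_stride *= stride
        stage.append((stride, dilation_rate))
    plan.append(stage)

for block, stage in enumerate(plan):
    print(block, stage)
print(output_stride)
```

Only stage 1 downsamples here; stages 2 and 3 trade stride for dilation. The resulting output stride of 8 matches the `DeepLabV3PResnet50` docstring, which pairs `output_stride=8` with the default `aspp_ratios` of (1, 12, 24, 36).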
modules/image/semantic_segmentation/ocrnet_hrnetw18_voc/hrnet.py
0 → 100644
View file @
6112bd38
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

import ocrnet_hrnetw18_voc.layers as L


class HRNet_W18(nn.Layer):
    """
    The HRNet implementation based on PaddlePaddle.

    The original article refers to
    Jingdong Wang, et al. "HRNet: Deep High-Resolution Representation Learning for Visual Recognition"
    (https://arxiv.org/pdf/1908.07919.pdf).

    Args:
        pretrained (str, optional): The path of pretrained model.
        stage1_num_modules (int, optional): Number of modules for stage1. Default 1.
        stage1_num_blocks (list, optional): Number of blocks per module for stage1. Default (4).
        stage1_num_channels (list, optional): Number of channels per branch for stage1. Default (64).
        stage2_num_modules (int, optional): Number of modules for stage2. Default 1.
        stage2_num_blocks (list, optional): Number of blocks per module for stage2. Default (4, 4).
        stage2_num_channels (list, optional): Number of channels per branch for stage2. Default (18, 36).
        stage3_num_modules (int, optional): Number of modules for stage3. Default 4.
        stage3_num_blocks (list, optional): Number of blocks per module for stage3. Default (4, 4, 4).
        stage3_num_channels (list, optional): Number of channels per branch for stage3. Default (18, 36, 72).
        stage4_num_modules (int, optional): Number of modules for stage4. Default 3.
        stage4_num_blocks (list, optional): Number of blocks per module for stage4. Default (4, 4, 4, 4).
        stage4_num_channels (list, optional): Number of channels per branch for stage4. Default (18, 36, 72, 144).
        has_se (bool, optional): Whether to use Squeeze-and-Excitation module. Default False.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
    """

    def __init__(self,
                 pretrained: str = None,
                 stage1_num_modules: int = 1,
                 stage1_num_blocks: tuple = (4, ),
                 stage1_num_channels: tuple = (64, ),
                 stage2_num_modules: int = 1,
                 stage2_num_blocks: tuple = (4, 4),
                 stage2_num_channels: tuple = (18, 36),
                 stage3_num_modules: int = 4,
                 stage3_num_blocks: tuple = (4, 4, 4),
                 stage3_num_channels: tuple = (18, 36, 72),
                 stage4_num_modules: int = 3,
                 stage4_num_blocks: tuple = (4, 4, 4, 4),
                 stage4_num_channels: tuple = (18, 36, 72, 144),
                 has_se: bool = False,
                 align_corners: bool = False):
        super(HRNet_W18, self).__init__()
        self.pretrained = pretrained
        self.stage1_num_modules = stage1_num_modules
        self.stage1_num_blocks = stage1_num_blocks
        self.stage1_num_channels = stage1_num_channels
        self.stage2_num_modules = stage2_num_modules
        self.stage2_num_blocks = stage2_num_blocks
        self.stage2_num_channels = stage2_num_channels
        self.stage3_num_modules = stage3_num_modules
        self.stage3_num_blocks = stage3_num_blocks
        self.stage3_num_channels = stage3_num_channels
        self.stage4_num_modules = stage4_num_modules
        self.stage4_num_blocks = stage4_num_blocks
        self.stage4_num_channels = stage4_num_channels
        self.has_se = has_se
        self.align_corners = align_corners
        self.feat_channels = [sum(stage4_num_channels)]

        self.conv_layer1_1 = L.ConvBNReLU(
            in_channels=3, out_channels=64, kernel_size=3, stride=2, padding='same', bias_attr=False)

        self.conv_layer1_2 = L.ConvBNReLU(
            in_channels=64, out_channels=64, kernel_size=3, stride=2, padding='same', bias_attr=False)

        self.la1 = Layer1(
            num_channels=64,
            num_blocks=self.stage1_num_blocks[0],
            num_filters=self.stage1_num_channels[0],
            has_se=has_se,
            name="layer2")

        self.tr1 = TransitionLayer(
            in_channels=[self.stage1_num_channels[0] * 4], out_channels=self.stage2_num_channels, name="tr1")

        self.st2 = Stage(
            num_channels=self.stage2_num_channels,
            num_modules=self.stage2_num_modules,
            num_blocks=self.stage2_num_blocks,
            num_filters=self.stage2_num_channels,
            has_se=self.has_se,
            name="st2",
            align_corners=align_corners)
self
.
tr2
=
TransitionLayer
(
in_channels
=
self
.
stage2_num_channels
,
out_channels
=
self
.
stage3_num_channels
,
name
=
"tr2"
)
self
.
st3
=
Stage
(
num_channels
=
self
.
stage3_num_channels
,
num_modules
=
self
.
stage3_num_modules
,
num_blocks
=
self
.
stage3_num_blocks
,
num_filters
=
self
.
stage3_num_channels
,
has_se
=
self
.
has_se
,
name
=
"st3"
,
align_corners
=
align_corners
)
self
.
tr3
=
TransitionLayer
(
in_channels
=
self
.
stage3_num_channels
,
out_channels
=
self
.
stage4_num_channels
,
name
=
"tr3"
)
self
.
st4
=
Stage
(
num_channels
=
self
.
stage4_num_channels
,
num_modules
=
self
.
stage4_num_modules
,
num_blocks
=
self
.
stage4_num_blocks
,
num_filters
=
self
.
stage4_num_channels
,
has_se
=
self
.
has_se
,
name
=
"st4"
,
align_corners
=
align_corners
)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # Stem: two stride-2 convolutions.
        conv1 = self.conv_layer1_1(x)
        conv2 = self.conv_layer1_2(conv1)

        la1 = self.la1(conv2)

        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)

        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)

        tr3 = self.tr3(st3)
        st4 = self.st4(tr3)

        # Upsample the three lower-resolution branches to the size of branch 0, then concatenate.
        x0_h, x0_w = st4[0].shape[2:]
        x1 = F.interpolate(st4[1], (x0_h, x0_w), mode='bilinear', align_corners=self.align_corners)
        x2 = F.interpolate(st4[2], (x0_h, x0_w), mode='bilinear', align_corners=self.align_corners)
        x3 = F.interpolate(st4[3], (x0_h, x0_w), mode='bilinear', align_corners=self.align_corners)
        x = paddle.concat([st4[0], x1, x2, x3], axis=1)

        return [x]
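
The bookkeeping behind `feat_channels` and the final concatenation can be checked without PaddlePaddle. The sketch below is only an illustration and is not part of the commit: the helper names `fused_channels` and `stem_output_size` are hypothetical, and the size formula assumes that a stride-2 convolution with `'same'` padding yields `ceil(size / 2)`.

```python
import math


def fused_channels(stage4_num_channels):
    """Channel count after concatenating all stage-4 branches along axis 1."""
    return sum(stage4_num_channels)


def stem_output_size(height, width, num_stride2_convs=2):
    """Spatial size of the highest-resolution branch after the stride-2 stem convolutions.

    Assumption: each stride-2 conv with 'same' padding produces ceil(size / 2).
    """
    for _ in range(num_stride2_convs):
        height = math.ceil(height / 2)
        width = math.ceil(width / 2)
    return height, width


print(fused_channels((18, 36, 72, 144)))  # 270
print(stem_output_size(512, 512))         # (128, 128)
```

Under these assumptions a 512x512 input gives a single 270-channel feature map at 128x128, which is what the segmentation head receives.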


class Layer1(nn.Layer):
    def __init__(self, num_channels: int, num_filters: int, num_blocks: int, has_se: bool = False, name: str = None):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []

        for i in range(num_blocks):
            bottleneck_block = self.add_sublayer(
                "bb_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else num_filters * 4,
                    num_filters=num_filters,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        conv = x
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv


class TransitionLayer(nn.Layer):
    def __init__(self, in_channels: tuple, out_channels: tuple, name=None):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                # Existing branch: add a conv only when the channel count changes.
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        L.ConvBNReLU(
                            in_channels=in_channels[i],
                            out_channels=out_channels[i],
                            kernel_size=3,
                            padding='same',
                            bias_attr=False))
            else:
                # New branch: downsample the last existing branch with a stride-2 conv.
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    L.ConvBNReLU(
                        in_channels=in_channels[-1],
                        out_channels=out_channels[i],
                        kernel_size=3,
                        stride=2,
                        padding='same',
                        bias_attr=False))
            self.conv_bn_func_list.append(residual)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(x[idx])
            else:
                if idx < len(x):
                    outs.append(conv_bn_func(x[idx]))
                else:
                    outs.append(conv_bn_func(x[-1]))
        return outs
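
To see which branches a `TransitionLayer` passes through unchanged and which it builds, the constructor's branching can be replayed on plain channel tuples. This is a minimal sketch, not code from the commit: `transition_plan` is a hypothetical helper that returns `None` for a pass-through branch and a `(kernel_size, in_channels, out_channels, stride)` tuple for a convolution.

```python
def transition_plan(in_channels, out_channels):
    """Mirror the branch logic of TransitionLayer.__init__ without building layers."""
    plan = []
    for i, out_ch in enumerate(out_channels):
        if i < len(in_channels):
            if in_channels[i] != out_ch:
                # Existing branch whose width changes: 3x3 conv, stride 1.
                plan.append((3, in_channels[i], out_ch, 1))
            else:
                # Existing branch with matching width: passed through untouched.
                plan.append(None)
        else:
            # New lower-resolution branch, derived from the last input branch.
            plan.append((3, in_channels[-1], out_ch, 2))
    return plan


print(transition_plan([256], (18, 36)))
# [(3, 256, 18, 1), (3, 256, 36, 2)]
print(transition_plan((18, 36), (18, 36, 72)))
# [None, None, (3, 36, 72, 2)]
```

The first call corresponds to `tr1` with the default configuration (stage 1 outputs 64 * 4 = 256 channels), the second to `tr2`.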


class Branches(nn.Layer):
    def __init__(self, num_blocks: tuple, in_channels: tuple, out_channels: tuple, has_se: bool = False,
                 name: str = None):
        super(Branches, self).__init__()

        self.basic_block_list = []
        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(num_blocks[i]):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        outs = []
        for idx, input in enumerate(x):
            conv = input
            for basic_block_func in self.basic_block_list[idx]:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels: int,
                 num_filters: int,
                 has_se: bool,
                 stride: int = 1,
                 downsample: bool = False,
                 name: str = None):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = L.ConvBNReLU(
            in_channels=num_channels, out_channels=num_filters, kernel_size=1, padding='same', bias_attr=False)

        self.conv2 = L.ConvBNReLU(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding='same',
            bias_attr=False)

        self.conv3 = L.ConvBN(
            in_channels=num_filters, out_channels=num_filters * 4, kernel_size=1, padding='same', bias_attr=False)

        if self.downsample:
            self.conv_down = L.ConvBN(
                in_channels=num_channels, out_channels=num_filters * 4, kernel_size=1, padding='same', bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4, num_filters=num_filters * 4, reduction_ratio=16, name=name + '_fc')

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv3 = self.se(conv3)

        y = conv3 + residual
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels: int,
                 num_filters: int,
                 stride: int = 1,
                 has_se: bool = False,
                 downsample: bool = False,
                 name: str = None):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = L.ConvBNReLU(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=3,
            stride=stride,
            padding='same',
            bias_attr=False)
        self.conv2 = L.ConvBN(
            in_channels=num_filters, out_channels=num_filters, kernel_size=3, padding='same', bias_attr=False)

        if self.downsample:
            self.conv_down = L.ConvBNReLU(
                in_channels=num_channels, out_channels=num_filters, kernel_size=1, padding='same', bias_attr=False)

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters, num_filters=num_filters, reduction_ratio=16, name=name + '_fc')

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        residual = x
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(x)

        if self.has_se:
            conv2 = self.se(conv2)

        y = conv2 + residual
        y = F.relu(y)
        return y


class SELayer(nn.Layer):
    def __init__(self, num_channels: int, num_filters: int, reduction_ratio: int, name: str = None):
        super(SELayer, self).__init__()

        self.pool2d_gap = nn.AdaptiveAvgPool2D(1)

        self._num_channels = num_channels

        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = nn.Linear(
            num_channels, med_ch, weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(-stdv, stdv)))

        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = nn.Linear(
            med_ch, num_filters, weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(-stdv, stdv)))

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # Squeeze: global average pool to one value per channel.
        pool = self.pool2d_gap(x)
        pool = paddle.reshape(pool, shape=[-1, self._num_channels])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        # Excitation: per-channel gate in (0, 1), broadcast over H and W.
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.reshape(excitation, shape=[-1, self._num_channels, 1, 1])
        out = x * excitation
        return out
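
The hidden width and the initializer bounds that `SELayer` derives from its arguments are easy to tabulate by hand. The following is a small sketch under the same arithmetic as the constructor; `se_layer_sizes` is a hypothetical helper, not part of the commit.

```python
import math


def se_layer_sizes(num_channels, reduction_ratio):
    """Return (hidden width, squeeze init bound, excitation init bound) as computed in SELayer."""
    med_ch = int(num_channels / reduction_ratio)
    squeeze_bound = 1.0 / math.sqrt(num_channels)
    excitation_bound = 1.0 / math.sqrt(med_ch)
    return med_ch, squeeze_bound, excitation_bound


print(se_layer_sizes(64, 16))  # (4, 0.125, 0.5)
```

One consequence worth noting: with `reduction_ratio=16`, any `num_channels` below 16 gives `med_ch == 0`, and the second bound then divides by zero. The default branch widths (18 and up) stay above that threshold.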


class Stage(nn.Layer):
    def __init__(self,
                 num_channels: tuple,
                 num_modules: int,
                 num_blocks: tuple,
                 num_filters: tuple,
                 has_se: bool = False,
                 multi_scale_output: bool = True,
                 name: str = None,
                 align_corners: bool = False):
        super(Stage, self).__init__()

        self._num_modules = num_modules

        self.stage_func_list = []
        for i in range(num_modules):
            if i == num_modules - 1 and not multi_scale_output:
                # Only the last module may collapse to a single-scale output.
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        multi_scale_output=False,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners))
            else:
                stage_func = self.add_sublayer(
                    "stage_{}_{}".format(name, i + 1),
                    HighResolutionModule(
                        num_channels=num_channels,
                        num_blocks=num_blocks,
                        num_filters=num_filters,
                        has_se=has_se,
                        name=name + '_' + str(i + 1),
                        align_corners=align_corners))

            self.stage_func_list.append(stage_func)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        out = x
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out


class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels: tuple,
                 num_blocks: tuple,
                 num_filters: tuple,
                 has_se: bool = False,
                 multi_scale_output: bool = True,
                 name: str = None,
                 align_corners: bool = False):
        super(HighResolutionModule, self).__init__()

        self.branches_func = Branches(
            num_blocks=num_blocks, in_channels=num_channels, out_channels=num_filters, has_se=has_se, name=name)

        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            name=name,
            align_corners=align_corners)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        out = self.branches_func(x)
        out = self.fuse_func(out)
        return out


class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels: tuple,
                 out_channels: tuple,
                 multi_scale_output: bool = True,
                 name: str = None,
                 align_corners: bool = False):
        super(FuseLayers, self).__init__()

        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels
        self.align_corners = align_corners

        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                if j > i:
                    # Lower-resolution source: 1x1 conv here, upsampling happens in forward().
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        L.ConvBN(
                            in_channels=in_channels[j],
                            out_channels=out_channels[i],
                            kernel_size=1,
                            padding='same',
                            bias_attr=False))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    # Higher-resolution source: chain of (i - j) stride-2 3x3 convs.
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(name, i + 1, j + 1, k + 1),
                                L.ConvBN(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[i],
                                    kernel_size=3,
                                    stride=2,
                                    padding='same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(name, i + 1, j + 1, k + 1),
                                L.ConvBNReLU(
                                    in_channels=pre_num_filters,
                                    out_channels=out_channels[j],
                                    kernel_size=3,
                                    stride=2,
                                    padding='same',
                                    bias_attr=False))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = x[i]
            residual_shape = residual.shape[-2:]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](x[j])
                    residual_func_idx += 1

                    y = F.interpolate(y, residual_shape, mode='bilinear', align_corners=self.align_corners)
                    residual = residual + y
                elif j < i:
                    y = x[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1

                    residual = residual + y

            residual = F.relu(residual)
            outs.append(residual)

        return outs
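
`FuseLayers.forward` walks `residual_func_list` with a running index, so the constructor and the forward pass must visit the `(i, j, k)` combinations in the same order and the same number of times. A quick way to sanity-check that is to count the registered layers. The helper below is a hypothetical sketch that replays the constructor's loops on integers only.

```python
def count_fuse_layers(num_branches, multi_scale_output=True):
    """Number of sublayers FuseLayers registers for a given branch count."""
    actual_ch = num_branches if multi_scale_output else 1
    count = 0
    for i in range(actual_ch):
        for j in range(num_branches):
            if j > i:
                count += 1      # one 1x1 conv per lower-resolution source
            elif j < i:
                count += i - j  # one stride-2 conv per halving step
    return count


print(count_fuse_layers(2))         # 2
print(count_fuse_layers(4))         # 16
print(count_fuse_layers(4, False))  # 3
```

For the four-branch stage this gives 6 upsampling paths plus 10 downsampling convolutions per `HighResolutionModule`.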
\ No newline at end of file
modules/image/semantic_segmentation/ocrnet_hrnetw18_voc/layers.py
0 → 100644
View file @
6112bd38
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.layer import activation
from paddle.nn import Conv2D, AvgPool2D


def SyncBatchNorm(*args, **kwargs):
    """In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead"""
    if paddle.get_device() == 'cpu':
        return nn.BatchNorm2D(*args, **kwargs)
    else:
        return nn.SyncBatchNorm(*args, **kwargs)


class ConvBNLayer(nn.Layer):
    """Basic conv bn relu layer."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 is_vd_mode: bool = False,
                 act: str = None,
                 name: str = None):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = AvgPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self._conv = Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
            dilation=dilation,
            groups=groups,
            bias_attr=False)

        self._batch_norm = SyncBatchNorm(out_channels)
        self._act_op = Activation(act=act)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        if self.is_vd_mode:
            inputs = self._pool2d_avg(inputs)
        y = self._conv(inputs)
        y = self._batch_norm(y)
        y = self._act_op(y)

        return y


class BottleneckBlock(nn.Layer):
    """Residual bottleneck block"""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 shortcut: bool = True,
                 if_first: bool = False,
                 dilation: int = 1,
                 name: str = None):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, act='relu', name=name + "_branch2a")

        self.dilation = dilation

        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            dilation=dilation,
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            in_channels=out_channels, out_channels=out_channels * 4, kernel_size=1, act=None, name=name + "_branch2c")

        if not shortcut:
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first or stride == 1 else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        y = self.conv0(inputs)
        if self.dilation > 1:
            # ConvBNLayer uses padding=0 for dilated convs, so pad explicitly by the dilation rate.
            padding = self.dilation
            y = F.pad(y, [padding, padding, padding, padding])

        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y
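
The padding split between `ConvBNLayer` (which passes `padding=0` when `dilation > 1`) and `BottleneckBlock.forward` (which pads by `dilation` on every side) can be verified with the standard convolution output-size formula. A minimal sketch follows; `conv_output_size` is a hypothetical helper and only restates that formula.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Standard output-size formula for a convolution along one spatial axis."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# dilation == 1: ConvBNLayer pads by (kernel_size - 1) // 2, so a 3x3 conv keeps the size.
print(conv_output_size(33, 3, padding=1))              # 33
# dilation == 2 with the explicit F.pad of `dilation` pixels per side: size is kept as well.
print(conv_output_size(33, 3, padding=2, dilation=2))  # 33
# dilation == 2 without that explicit padding: the feature map would shrink.
print(conv_output_size(33, 3, padding=0, dilation=2))  # 29
```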


class SeparableConvBNReLU(nn.Layer):
    """Depthwise Separable Convolution."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
        super(SeparableConvBNReLU, self).__init__()
        self.depthwise_conv = ConvBN(
            in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
            **kwargs)
        # Attribute name spelled as in the original source.
        self.piontwise_conv = ConvBNReLU(in_channels, out_channels, kernel_size=1, groups=1)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self.depthwise_conv(x)
        x = self.piontwise_conv(x)
        return x


class ConvBN(nn.Layer):
    """Basic conv bn layer"""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
        super(ConvBN, self).__init__()
        self._conv = Conv2D(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        self._batch_norm = SyncBatchNorm(out_channels)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self._conv(x)
        x = self._batch_norm(x)
        return x


class ConvBNReLU(nn.Layer):
    """Basic conv bn relu layer."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
        super(ConvBNReLU, self).__init__()

        self._conv = Conv2D(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        self._batch_norm = SyncBatchNorm(out_channels)

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self._conv(x)
        x = self._batch_norm(x)
        x = F.relu(x)
        return x


class Activation(nn.Layer):
    """
    The wrapper of activations.

    Args:
        act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu',
            'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid',
            'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax',
            'hsigmoid']. Default: None, means identical transformation.

    Returns:
        A callable object of Activation.

    Raises:
        KeyError: When parameter `act` is not in the optional range.

    Examples:

        from paddleseg.models.common.activation import Activation

        relu = Activation("relu")
        print(relu)
        # <class 'paddle.nn.layer.activation.ReLU'>

        sigmoid = Activation("sigmoid")
        print(sigmoid)
        # <class 'paddle.nn.layer.activation.Sigmoid'>

        not_exit_one = Activation("not_exit_one")
        # KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
        # 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
        # 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
    """

    def __init__(self, act: str = None):
        super(Activation, self).__init__()

        self._act = act
        upper_act_names = activation.__all__
        lower_act_names = [act.lower() for act in upper_act_names]
        act_dict = dict(zip(lower_act_names, upper_act_names))

        if act is not None:
            if act in act_dict.keys():
                act_name = act_dict[act]
                self.act_func = eval("activation.{}()".format(act_name))
            else:
                raise KeyError("{} does not exist in the current {}".format(act, act_dict.keys()))

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:

        if self._act is not None:
            return self.act_func(x)
        else:
            return x
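
The name lookup inside `Activation.__init__` (lowercase key to class name, then instantiate) can be expressed with `getattr` instead of `eval`. The sketch below shows that pattern in isolation; it is an assumption-laden illustration, not the commit's code. `fake_activation`, `ReLU`, `Sigmoid`, and `resolve_activation` are stand-ins defined here so the example runs without PaddlePaddle.

```python
from types import SimpleNamespace


class ReLU:
    """Stand-in for an activation layer class (hypothetical)."""


class Sigmoid:
    """Stand-in for an activation layer class (hypothetical)."""


# Stand-in for a module that lists its public class names in __all__.
fake_activation = SimpleNamespace(__all__=["ReLU", "Sigmoid"], ReLU=ReLU, Sigmoid=Sigmoid)


def resolve_activation(act, module=fake_activation):
    """Map a lowercase name to an instance of the matching class, or None for identity."""
    if act is None:
        return None
    act_dict = {name.lower(): name for name in module.__all__}
    if act not in act_dict:
        raise KeyError("{} does not exist in the current {}".format(act, act_dict.keys()))
    # getattr avoids evaluating a formatted string as code.
    return getattr(module, act_dict[act])()


print(type(resolve_activation("relu")).__name__)  # ReLU
```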


class ASPPModule(nn.Layer):
    """
    Atrous Spatial Pyramid Pooling.

    Args:
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        align_corners (bool): An argument of F.interpolate. Set it to False when the output size of the feature
            is even, e.g. 1024x512, and to True when it is odd, e.g. 769x769.
        use_sep_conv (bool, optional): Whether to use separable conv in the ASPP module. Default: False.
        image_pooling (bool, optional): Whether to augment with image-level features. Default: False.
    """

    def __init__(self, aspp_ratios, in_channels, out_channels, align_corners, use_sep_conv=False,
                 image_pooling=False):
        super().__init__()

        self.align_corners = align_corners
        self.aspp_blocks = nn.LayerList()

        for ratio in aspp_ratios:
            if use_sep_conv and ratio > 1:
                conv_func = SeparableConvBNReLU
            else:
                conv_func = ConvBNReLU

            block = conv_func(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1 if ratio == 1 else 3,
                dilation=ratio,
                padding=0 if ratio == 1 else ratio)
            self.aspp_blocks.append(block)

        out_size = len(self.aspp_blocks)

        if image_pooling:
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2D(output_size=(1, 1)),
                ConvBNReLU(in_channels, out_channels, kernel_size=1, bias_attr=False))
            out_size += 1
        self.image_pooling = image_pooling

        self.conv_bn_relu = ConvBNReLU(in_channels=out_channels * out_size, out_channels=out_channels, kernel_size=1)

        self.dropout = nn.Dropout(p=0.1)  # drop rate

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        outputs = []
        for block in self.aspp_blocks:
            y = block(x)
            y = F.interpolate(y, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
            outputs.append(y)

        if self.image_pooling:
            img_avg = self.global_avg_pool(x)
            img_avg = F.interpolate(img_avg, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
            outputs.append(img_avg)

        x = paddle.concat(outputs, axis=1)
        x = self.conv_bn_relu(x)
        x = self.dropout(x)

        return x
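
How `ASPPModule` sizes its branches and its final 1x1 projection follows directly from the constructor arguments. The sketch below restates that arithmetic; both helper names are hypothetical, and the dilation rates in the usage lines are example values rather than settings taken from this commit.

```python
def aspp_branch_config(ratio):
    """Kernel size, dilation, and padding ASPPModule picks for one dilation rate."""
    return {
        "kernel_size": 1 if ratio == 1 else 3,
        "dilation": ratio,
        "padding": 0 if ratio == 1 else ratio,
    }


def aspp_concat_channels(aspp_ratios, out_channels, image_pooling=False):
    """Input width of the final 1x1 conv: one out_channels slice per branch."""
    num_branches = len(aspp_ratios) + (1 if image_pooling else 0)
    return out_channels * num_branches


print(aspp_branch_config(12))
# {'kernel_size': 3, 'dilation': 12, 'padding': 12}
print(aspp_concat_channels((1, 12, 24, 36), 256, image_pooling=True))  # 1280
```

Setting the padding equal to the dilation rate keeps a 3x3 dilated convolution size-preserving, which is why the branches can be concatenated along the channel axis.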
\ No newline at end of file
modules/image/semantic_segmentation/ocrnet_hrnetw18_voc/module.py
0 → 100644
View file @
6112bd38
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List

import paddle
import numpy as np
import paddle.nn as nn
import paddle.nn.functional as F
from paddlehub.module.module import moduleinfo
import paddlehub.vision.segmentation_transforms as T
from paddlehub.module.cv_module import ImageSegmentationModule

import ocrnet_hrnetw18_voc.layers as L
from ocrnet_hrnetw18_voc.hrnet import HRNet_W18


@moduleinfo(
    name="ocrnet_hrnetw18_voc",
    type="CV/semantic_segmentation",
    author="paddlepaddle",
    author_email="",
    summary="OCRNetHRNetW18 is a segmentation model pretrained on Pascal VOC.",
    version="1.0.0",
    meta=ImageSegmentationModule)
class OCRNetHRNetW18(nn.Layer):
    """
    The OCRNet implementation based on PaddlePaddle.
    The original article refers to
    Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
    (https://arxiv.org/pdf/1909.11065.pdf)

    Args:
        num_classes (int): The number of target classes. Default: 21.
        backbone_indices (list): Indices into the list of backbone outputs. It holds one or two values.
            With two values, the first selects the deep-supervision feature for the auxiliary layer and
            the second selects the input of the pixel representation. With one value, the same feature
            is used for both. Default: [0].
        ocr_mid_channels (int, optional): The number of middle channels in OCRHead. Default: 512.
        ocr_key_channels (int, optional): The number of key channels in ObjectAttentionBlock. Default: 256.
        align_corners (bool): An argument of F.interpolate. Set it to False when the output size of the feature
            is even, e.g. 1024x512, and to True when it is odd, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of pretrained model. Default: None.
    """

    def __init__(self,
                 num_classes: int = 21,
                 backbone_indices: List[int] = [0],
                 ocr_mid_channels: int = 512,
                 ocr_key_channels: int = 256,
                 align_corners: bool = False,
                 pretrained: str = None):
        super(OCRNetHRNetW18, self).__init__()
        self.backbone = HRNet_W18()
        self.backbone_indices = backbone_indices
        in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
        self.head = OCRHead(
            num_classes=num_classes,
            in_channels=in_channels,
            ocr_mid_channels=ocr_mid_channels,
            ocr_key_channels=ocr_key_channels)
        self.align_corners = align_corners
        self.transforms = T.Compose([T.Padding(target_size=(512, 512)), T.Normalize()])

        if pretrained is not None:
            # User-supplied checkpoint.
            model_dict = paddle.load(pretrained)
            self.set_dict(model_dict)
            print("load custom parameters success")

        else:
            # Checkpoint shipped with the module directory.
            checkpoint = os.path.join(self.directory, 'ocrnet_hrnetw18.pdparams')
            model_dict = paddle.load(checkpoint)
            self.set_dict(model_dict)
            print("load pretrained parameters success")

    def transform(self, img: np.ndarray) -> np.ndarray:
        return self.transforms(img)

    def forward(self, x: paddle.Tensor) -> List[paddle.Tensor]:
        feats = self.backbone(x)
        feats = [feats[i] for i in self.backbone_indices]
        logit_list = self.head(feats)
        # Resize every logit map (main and auxiliary) back to the input resolution.
        logit_list = [
            F.interpolate(logit, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
            for logit in logit_list
        ]
        return logit_list


class OCRHead(nn.Layer):
    """
    The Object contextual representation head.

    Args:
        num_classes (int): The number of target classes.
        in_channels (list): The number of channels of each input feature map.
        ocr_mid_channels (int, optional): The number of middle channels in OCRHead. Default: 512.
        ocr_key_channels (int, optional): The number of key channels in ObjectAttentionBlock. Default: 256.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: List[int],
                 ocr_mid_channels: int = 512,
                 ocr_key_channels: int = 256):
        super().__init__()

        self.num_classes = num_classes
        self.spatial_gather = SpatialGatherBlock()
        self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels, ocr_mid_channels)

        # With a single input feature, it serves as both the shallow and the deep feature.
        self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]

        self.conv3x3_ocr = L.ConvBNReLU(in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
        self.cls_head = nn.Conv2D(ocr_mid_channels, self.num_classes, 1)
        self.aux_head = nn.Sequential(
            L.ConvBNReLU(in_channels[self.indices[0]], in_channels[self.indices[0]], 1),
            nn.Conv2D(in_channels[self.indices[0]], self.num_classes, 1))

    def forward(self, feat_list: List[paddle.Tensor]) -> List[paddle.Tensor]:
        feat_shallow, feat_deep = feat_list[self.indices[0]], feat_list[self.indices[1]]

        soft_regions = self.aux_head(feat_shallow)
        pixels = self.conv3x3_ocr(feat_deep)

        object_regions = self.spatial_gather(pixels, soft_regions)
        ocr = self.spatial_ocr(pixels, object_regions)

        logit = self.cls_head(ocr)
        return [logit, soft_regions]
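
The way `OCRHead` chooses its shallow (auxiliary) and deep (pixel) features depends only on how many input feature maps it receives. A tiny sketch of that rule, with the hypothetical helper name `ocr_head_indices`:

```python
def ocr_head_indices(in_channels):
    """Indices of the (shallow, deep) features, as selected in OCRHead.__init__."""
    return [-2, -1] if len(in_channels) > 1 else [-1, -1]


# Default module configuration: backbone_indices=[0] and a single 270-channel backbone output.
print(ocr_head_indices([270]))      # [-1, -1]
# Hypothetical two-feature setup: the last two features are used.
print(ocr_head_indices([48, 270]))  # [-2, -1]
```

With the defaults in this commit, the auxiliary head and the pixel-representation branch therefore both read the same concatenated HRNet feature map.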


class SpatialGatherBlock(nn.Layer):
    """Aggregation layer to compute the pixel-region representation."""

    def forward(self, pixels: paddle.Tensor, regions: paddle.Tensor) -> paddle.Tensor:
        n, c, h, w = pixels.shape
        _, k, _, _ = regions.shape

        # pixels: from (n, c, h, w) to (n, h*w, c)
        pixels = paddle.reshape(pixels, (n, c, h * w))
        pixels = paddle.transpose(pixels, [0, 2, 1])

        # regions: from (n, k, h, w) to (n, k, h*w)
        regions = paddle.reshape(regions, (n, k, h * w))
        regions = F.softmax(regions, axis=2)

        # feats: from (n, k, c) to (n, c, k, 1)
        feats = paddle.bmm(regions, pixels)
        feats = paddle.transpose(feats, [0, 2, 1])
        feats = paddle.unsqueeze(feats, axis=-1)

        return feats
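
What `SpatialGatherBlock` computes is, for each class region, a softmax-weighted average of the pixel features over all spatial positions. The pure-Python sketch below reproduces that for a single sample on nested lists. It is an illustration only: `softmax` and `spatial_gather` are hypothetical helpers, and the result is returned in `(k, c)` layout, i.e. before the final transpose and unsqueeze done in the layer.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def spatial_gather(pixels, regions):
    """pixels: c rows of h*w features; regions: k rows of h*w logits. Returns k rows of c values."""
    weights = [softmax(row) for row in regions]  # normalise each region over spatial positions
    return [[sum(w * p for w, p in zip(weight_row, channel)) for channel in pixels]
            for weight_row in weights]


pixels = [[1.0, 3.0], [2.0, 2.0]]  # c = 2 channels, h*w = 2 positions
regions = [[0.0, 0.0]]             # k = 1 region with uniform logits
print(spatial_gather(pixels, regions))  # [[2.0, 2.0]]
```

Uniform logits reduce the operation to a plain mean over positions; a region that strongly prefers one position returns (almost exactly) that position's feature vector.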


class SpatialOCRModule(nn.Layer):
    """Aggregate the global object representation to update the representation for each pixel."""

    def __init__(self, in_channels: int, key_channels: int, out_channels: int, dropout_rate: float = 0.1):
        super().__init__()

        self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
        self.conv1x1 = nn.Sequential(L.ConvBNReLU(2 * in_channels, out_channels, 1), nn.Dropout2D(dropout_rate))

    def forward(self, pixels: paddle.Tensor, regions: paddle.Tensor) -> paddle.Tensor:
        context = self.attention_block(pixels, regions)
        feats = paddle.concat([context, pixels], axis=1)
        feats = self.conv1x1(feats)

        return feats


class ObjectAttentionBlock(nn.Layer):
    """A self-attention module."""

    def __init__(self, in_channels: int, key_channels: int):
        super().__init__()

        self.in_channels = in_channels
        self.key_channels = key_channels

        self.f_pixel = nn.Sequential(
            L.ConvBNReLU(in_channels, key_channels, 1), L.ConvBNReLU(key_channels, key_channels, 1))

        self.f_object = nn.Sequential(
            L.ConvBNReLU(in_channels, key_channels, 1), L.ConvBNReLU(key_channels, key_channels, 1))

        self.f_down = L.ConvBNReLU(in_channels, key_channels, 1)

        self.f_up = L.ConvBNReLU(key_channels, in_channels, 1)

    def forward(self, x: paddle.Tensor, proxy: paddle.Tensor) -> paddle.Tensor:
        n, _, h, w = x.shape

        # query: from (n, c1, h1, w1) to (n, h1*w1, key_channels)
        query = self.f_pixel(x)
        query = paddle.reshape(query, (n, self.key_channels, -1))
        query = paddle.transpose(query, [0, 2, 1])

        # key: from (n, c2, h2, w2) to (n, key_channels, h2*w2)
        key = self.f_object(proxy)
        key = paddle.reshape(key, (n, self.key_channels, -1))

        # value: from (n, c2, h2, w2) to (n, h2*w2, key_channels)
        value = self.f_down(proxy)
        value = paddle.reshape(value, (n, self.key_channels, -1))
        value = paddle.transpose(value, [0, 2, 1])

        # sim_map: (n, h1*w1, h2*w2)
        sim_map = paddle.bmm(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, axis=-1)

        # context: from (n, h1*w1, key_channels) to (n, out_channels, h1, w1)
        context = paddle.bmm(sim_map, value)
        context = paddle.transpose(context, [0, 2, 1])
        context = paddle.reshape(context, (n, self.key_channels, h, w))
        context = self.f_up(context)

        return context
\ No newline at end of file
paddlehub/datasets/__init__.py
View file @ 6112bd38
@@ -18,3 +18,5 @@ from paddlehub.datasets.minicoco import MiniCOCO
 from paddlehub.datasets.chnsenticorp import ChnSentiCorp
 from paddlehub.datasets.msra_ner import MSRA_NER
 from paddlehub.datasets.lcqmc import LCQMC
+from paddlehub.datasets.base_seg_dataset import SegDataset
+from paddlehub.datasets.opticdiscseg import OpticDiscSeg
paddlehub/datasets/base_seg_dataset.py
0 → 100644
View file @ 6112bd38
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Tuple, Callable

import paddle
import numpy as np
from PIL import Image


class SegDataset(paddle.io.Dataset):
    """
    Pass in a custom dataset that conforms to the format.

    Args:
        transforms (Callable): Transforms for image.
        dataset_root (str): The dataset directory.
        num_classes (int): Number of classes.
        mode (str, optional): Which part of dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
        train_path (str, optional): The train dataset file. When mode is 'train', train_path is necessary.
            The contents of train_path file are as follows:
            image1.jpg ground_truth1.png
            image2.jpg ground_truth2.png
        val_path (str, optional): The evaluation dataset file. When mode is 'val', val_path is necessary.
            The contents are the same as in train_path.
        test_path (str, optional): The test dataset file. When mode is 'test', test_path is necessary.
            The annotation file is not necessary in test_path file.
        separator (str, optional): The separator of dataset list. Default: ' '.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """

    def __init__(self,
                 transforms: Callable,
                 dataset_root: str,
                 num_classes: int,
                 mode: str = 'train',
                 train_path: str = None,
                 val_path: str = None,
                 test_path: str = None,
                 separator: str = ' ',
                 ignore_index: int = 255,
                 edge: bool = False):
        self.dataset_root = dataset_root
        self.transforms = transforms
        self.file_list = list()
        mode = mode.lower()
        self.mode = mode
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.edge = edge

        if mode.lower() not in ['train', 'val', 'test']:
            raise ValueError("mode should be 'train', 'val' or 'test', but got {}.".format(mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        self.dataset_root = dataset_root
        if not os.path.exists(self.dataset_root):
            raise FileNotFoundError('there is not `dataset_root`: {}.'.format(self.dataset_root))

        if mode == 'train':
            if train_path is None:
                raise ValueError('When `mode` is "train", `train_path` is necessary, but it is None.')
            elif not os.path.exists(train_path):
                raise FileNotFoundError('`train_path` is not found: {}'.format(train_path))
            else:
                file_path = train_path
        elif mode == 'val':
            if val_path is None:
                raise ValueError('When `mode` is "val", `val_path` is necessary, but it is None.')
            elif not os.path.exists(val_path):
                raise FileNotFoundError('`val_path` is not found: {}'.format(val_path))
            else:
                file_path = val_path
        else:
            if test_path is None:
                raise ValueError('When `mode` is "test", `test_path` is necessary, but it is None.')
            elif not os.path.exists(test_path):
                raise FileNotFoundError('`test_path` is not found: {}'.format(test_path))
            else:
                file_path = test_path

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split(separator)
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise ValueError("File list format incorrect! In training or evaluation task it should be"
                                         " image_name{}label_name\\n".format(separator))
                    image_path = os.path.join(self.dataset_root, items[0])
                    label_path = None
                else:
                    image_path = os.path.join(self.dataset_root, items[0])
                    label_path = os.path.join(self.dataset_root, items[1])
                self.file_list.append([image_path, label_path])

    def __getitem__(self, idx: int) -> Tuple[np.ndarray]:
        image_path, label_path = self.file_list[idx]
        if self.mode == 'test':
            im, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            return im, image_path
        elif self.mode == 'val':
            im, _ = self.transforms(im=image_path)
            label = np.asarray(Image.open(label_path))
            label = label[np.newaxis, :, :]
            return im, label
        else:
            im, label = self.transforms(im=image_path, label=label_path)
            return im, label

    def __len__(self) -> int:
        return len(self.file_list)
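
Note on the list-file format: the `SegDataset` docstring describes one `image label` pair per line, with the label optional only in `'test'` mode. The sketch below mirrors that parsing loop so the convention can be checked without Paddle installed. The helper name `parse_file_list` and the sample file names are hypothetical and are not part of this commit.

```python
import os
from typing import List, Optional


def parse_file_list(lines: List[str],
                    dataset_root: str,
                    mode: str = 'train',
                    separator: str = ' ') -> List[List[Optional[str]]]:
    """Hypothetical helper that mirrors the list-file loop in SegDataset.__init__."""
    file_list = []
    for line in lines:
        items = line.strip().split(separator)
        if len(items) != 2:
            # A label is mandatory for training and evaluation.
            if mode in ('train', 'val'):
                raise ValueError('expected "image_name{}label_name" per line'.format(separator))
            # In 'test' mode a line may carry only the image name.
            file_list.append([os.path.join(dataset_root, items[0]), None])
        else:
            file_list.append([os.path.join(dataset_root, items[0]), os.path.join(dataset_root, items[1])])
    return file_list


train_pairs = parse_file_list(['image1.jpg ground_truth1.png\n'], 'data', mode='train')
test_pairs = parse_file_list(['image2.jpg\n'], 'data', mode='test')
```

A training line without a label raises `ValueError`, which matches the behaviour of the committed loop for `'train'` and `'val'`.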
paddlehub/datasets/opticdiscseg.py
0 → 100644
View file @ 6112bd38
# coding:utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Callable

import paddle
import numpy as np
from PIL import Image

import paddlehub.env as hubenv
from paddlehub.utils.download import download_data
from paddlehub.datasets.base_seg_dataset import SegDataset


@download_data(url='https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip')
class OpticDiscSeg(SegDataset):
    """
    OpticDiscSeg dataset is extracted from iChallenge-AMD
    (https://ai.baidu.com/broad/subordinate?dataset=amd).

    Args:
        transforms (Callable): Transforms for image.
        mode (str, optional): Which part of dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
        edge (bool, optional): Whether to compute edge while training. Default: False
    """

    def __init__(self, transforms: Callable = None, mode: str = 'train'):
        self.transforms = transforms
        mode = mode.lower()
        self.mode = mode
        self.file_list = list()
        self.num_classes = 2
        self.ignore_index = 255

        if mode not in ['train', 'val', 'test']:
            raise ValueError("`mode` should be 'train', 'val' or 'test', but got {}.".format(mode))

        if self.transforms is None:
            raise ValueError("`transforms` is necessary, but it is None.")

        if mode == 'train':
            file_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', 'train_list.txt')
        elif mode == 'test':
            file_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', 'test_list.txt')
        else:
            file_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', 'val_list.txt')

        with open(file_path, 'r') as f:
            for line in f:
                items = line.strip().split()
                if len(items) != 2:
                    if mode == 'train' or mode == 'val':
                        raise Exception("File list format incorrect! It should be"
                                        " image_name label_name\\n")
                    image_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', items[0])
                    grt_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', items[1])
                self.file_list.append([image_path, grt_path])
\ No newline at end of file
paddlehub/module/cv_module.py
View file @ 6112bd38
@@ -17,7 +17,7 @@ import time
 import os
 import base64
 import argparse
-from typing import List, Union
+from typing import List, Union, Tuple
 from collections import OrderedDict

 import cv2
@@ -630,3 +630,112 @@ class StyleTransferModule(RunModule, ImageServing):
...
@@ -630,3 +630,112 @@ class StyleTransferModule(RunModule, ImageServing):
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
'--input_path'
,
type
=
str
,
help
=
"path to image."
)
self
.
arg_input_group
.
add_argument
(
self
.
arg_input_group
.
add_argument
(
'--style_path'
,
type
=
str
,
help
=
"path to style image."
)
'--style_path'
,
type
=
str
,
help
=
"path to style image."
)
class
ImageSegmentationModule
(
ImageServing
,
RunModule
):
def
training_step
(
self
,
batch
:
List
[
paddle
.
Tensor
],
batch_idx
:
int
)
->
dict
:
'''
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images, ground truth boxes, labels and scores.
batch_idx(int): The index of batch.
Returns:
results(dict): The model outputs, such as loss.
'''
return
self
.
validation_step
(
batch
,
batch_idx
)
def
validation_step
(
self
,
batch
:
List
[
paddle
.
Tensor
],
batch_idx
:
int
)
->
dict
:
"""
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
"""
label
=
batch
[
1
].
astype
(
'int64'
)
criterionCE
=
nn
.
loss
.
CrossEntropyLoss
()
logits
=
self
(
batch
[
0
])
loss
=
0
for
i
in
range
(
len
(
logits
)):
logit
=
logits
[
i
]
if
logit
.
shape
[
-
2
:]
!=
label
.
shape
[
-
2
:]:
logit
=
F
.
resize_bilinear
(
logit
,
label
.
shape
[
-
2
:])
logit
=
logit
.
transpose
([
0
,
2
,
3
,
1
])
loss_ce
=
criterionCE
(
logit
,
label
)
loss
+=
loss_ce
/
len
(
logits
)
return
{
"loss"
:
loss
}
def
predict
(
self
,
images
:
Union
[
str
,
np
.
ndarray
],
batch_size
:
int
=
1
,
visualization
:
bool
=
True
,
save_path
:
str
=
'seg_result'
)
->
List
[
np
.
ndarray
]:
'''
Obtain segmentation results.
Args:
images(list[str|np.array]): Content image path or BGR image.
batch_size(int): Batch size for prediciton.
visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images.
Returns:
output(list[np.ndarray]) : The segmentation mask.
'''
self
.
eval
()
result
=
[]
total_num
=
len
(
images
)
loop_num
=
int
(
np
.
ceil
(
total_num
/
batch_size
))
for
iter_id
in
range
(
loop_num
):
batch_data
=
[]
handle_id
=
iter_id
*
batch_size
for
image_id
in
range
(
batch_size
):
try
:
image
,
_
=
self
.
transform
(
images
[
handle_id
+
image_id
])
batch_data
.
append
(
image
)
except
:
pass
batch_image
=
np
.
array
(
batch_data
).
astype
(
'float32'
)
pred
=
self
(
paddle
.
to_tensor
(
batch_image
))
pred
=
paddle
.
argmax
(
pred
[
0
],
axis
=
1
,
keepdim
=
True
,
dtype
=
'int32'
)
for
num
in
range
(
pred
.
shape
[
0
]):
if
isinstance
(
images
[
handle_id
+
num
],
str
):
image
=
cv2
.
imread
(
images
[
handle_id
+
num
])
else
:
image
=
images
[
handle_id
+
num
]
h
,
w
,
c
=
image
.
shape
pred_final
=
utils
.
reverse_transform
(
pred
[
num
:
num
+
1
],
(
h
,
w
),
self
.
transforms
.
transforms
)
pred_final
=
paddle
.
squeeze
(
pred_final
)
pred_final
=
pred_final
.
numpy
().
astype
(
'uint8'
)
if
visualization
:
added_image
=
utils
.
visualize
(
images
[
handle_id
+
num
],
pred_final
,
weight
=
0.6
)
pred_mask
=
utils
.
get_pseudo_color_map
(
pred_final
)
pred_image_path
=
os
.
path
.
join
(
save_path
,
'image'
,
str
(
time
.
time
())
+
".png"
)
pred_mask_path
=
os
.
path
.
join
(
save_path
,
'mask'
,
str
(
time
.
time
())
+
".png"
)
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
pred_image_path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
pred_image_path
))
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
pred_mask_path
)):
os
.
makedirs
(
os
.
path
.
dirname
(
pred_mask_path
))
cv2
.
imwrite
(
pred_image_path
,
added_image
)
pred_mask
.
save
(
pred_mask_path
)
result
.
append
(
pred_final
)
return
result
@
serving
def
serving_method
(
self
,
images
:
List
[
str
],
**
kwargs
):
"""
Run as a service.
"""
images_decode
=
[
base64_to_cv2
(
image
)
for
image
in
images
]
visual
=
self
.
predict
(
images
=
images_decode
,
**
kwargs
)
final
=
[]
for
mask
in
visual
:
final
.
append
(
cv2_to_base64
(
mask
))
return
final
\ No newline at end of file
paddlehub/vision/segmentation_transforms.py
0 → 100644
View file @ 6112bd38
# coding: utf8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from typing import Callable, Union, List, Tuple

import cv2
import numpy as np
from PIL import Image

import paddlehub.vision.functional as F


class Compose:
    """
    Do transformation on input data with corresponding pre-processing and augmentation operations.
    The shape of input data to all operations is [height, width, channels].

    Args:
        transforms (list): A list contains data pre-processing or augmentation.
        to_rgb (bool, optional): If converting image to RGB color space. Default: True.

    Raises:
        TypeError: When 'transforms' is not a list.
        ValueError: when the length of 'transforms' is less than 1.
    """

    def __init__(self, transforms: Callable, to_rgb: bool = True):
        if not isinstance(transforms, list):
            raise TypeError('The transforms must be a list!')
        if len(transforms) < 1:
            raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
        self.transforms = transforms
        self.to_rgb = to_rgb

    def __call__(self, im: Union[np.ndarray, str], label: Union[np.ndarray, str] = None) -> Tuple:
        """
        Args:
            im (str|np.ndarray): It is either image path or image object.
            label (str|np.ndarray): It is either label path or label ndarray.

        Returns:
            (tuple). A tuple including the image (transposed to channels-first) and the label after transformation.
        """
        if isinstance(im, str):
            im = cv2.imread(im).astype('float32')
        if isinstance(label, str):
            label = np.asarray(Image.open(label))
        if im is None:
            raise ValueError('Can\'t read The image file {}!'.format(im))
        if self.to_rgb:
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

        for op in self.transforms:
            outputs = op(im, label)
            im = outputs[0]
            if len(outputs) == 2:
                label = outputs[1]
        im = np.transpose(im, (2, 0, 1))
        return (im, label)


class ColorMap:
    "Calculate color map for mapping segmentation result."

    def __init__(self, num_classes: int = 256):
        self.num_classes = num_classes + 1

    def __call__(self) -> np.ndarray:
        color_map = self.num_classes * [0, 0, 0]
        for i in range(0, self.num_classes):
            j = 0
            lab = i
            while lab:
                color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
                color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
                color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
                j += 1
                lab >>= 3
        color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
        color_map = color_map[1:]
        return color_map


class SegmentVisual:
    """Visualize the segmentation result.

    Args:
        weight(float): weight of original image in combining image, default is 0.6.
    """

    def __init__(self, weight: float = 0.6):
        self.weight = weight
        self.get_color_map_list = ColorMap(256)

    def __call__(self, image: str, result: np.ndarray, save_dir: str) -> np.ndarray:
        color_map = self.get_color_map_list()
        color_map = np.array(color_map).astype("uint8")
        # Use OpenCV LUT for color mapping
        c1 = cv2.LUT(result, color_map[:, 0])
        c2 = cv2.LUT(result, color_map[:, 1])
        c3 = cv2.LUT(result, color_map[:, 2])
        pseudo_img = np.dstack((c1, c2, c3))
        im = cv2.imread(image)
        vis_result = cv2.addWeighted(im, self.weight, pseudo_img, 1 - self.weight, 0)

        if save_dir is not None:
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            image_name = os.path.split(image)[-1]
            out_path = os.path.join(save_dir, image_name)
            cv2.imwrite(out_path, vis_result)

        return vis_result


class Padding:
    """
    Add bottom-right padding to a raw image or annotation image.

    Args:
        target_size (list|tuple): The target size after padding.
        im_padding_value (list|tuple, optional): The padding value of raw image.
            Default: (128, 128, 128).
        label_padding_value (int, optional): The padding value of annotation image. Default: 255.

    Raises:
        TypeError: When target_size is neither list nor tuple.
        ValueError: When the length of target_size is not 2.
    """

    def __init__(self,
                 target_size: Union[List[int], Tuple[int], int],
                 im_padding_value: Union[List[int], Tuple[int], int] = (128, 128, 128),
                 label_padding_value: int = 255):
        if isinstance(target_size, list) or isinstance(target_size, tuple):
            if len(target_size) != 2:
                raise ValueError('`target_size` should include 2 elements, but it is {}'.format(target_size))
        else:
            raise TypeError("Type of target_size is invalid. It should be list or tuple, now is {}".format(
                type(target_size)))
        self.target_size = target_size
        self.im_padding_value = im_padding_value
        self.label_padding_value = label_padding_value

    def __call__(self, im: np.ndarray, label: np.ndarray = None) -> Tuple:
        """
        Args:
            im (np.ndarray): The Image data.
            label (np.ndarray, optional): The label data. Default: None.

        Returns:
            (tuple). When label is None, it returns (im, ), otherwise it returns (im, label).
        """
        im_height, im_width = im.shape[0], im.shape[1]
        if isinstance(self.target_size, int):
            target_height = self.target_size
            target_width = self.target_size
        else:
            target_height = self.target_size[1]
            target_width = self.target_size[0]
        pad_height = target_height - im_height
        pad_width = target_width - im_width
        if pad_height < 0 or pad_width < 0:
            raise ValueError(
                'The size of image should be less than `target_size`, but the size of image ({}, {}) is larger than `target_size` ({}, {})'
                .format(im_width, im_height, target_width, target_height))
        else:
            im = cv2.copyMakeBorder(im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=self.im_padding_value)
            if label is not None:
                label = cv2.copyMakeBorder(
                    label, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=self.label_padding_value)
        if label is None:
            return (im, )
        else:
            return (im, label)


class Normalize:
    """
    Normalize an image.

    Args:
        mean (list|tuple): The mean value of a data set. Default: [0.5, 0.5, 0.5].
        std (list|tuple): The standard deviation of a data set. Default: [0.5, 0.5, 0.5].

    Raises:
        ValueError: When mean/std is not list or any value in std is 0.
    """

    def __init__(self,
                 mean: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5),
                 std: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5)):
        self.mean = mean
        self.std = std
        if not (isinstance(self.mean, (list, tuple)) and isinstance(self.std, (list, tuple))):
            raise ValueError("{}: input type is invalid. It should be list or tuple".format(self))
        from functools import reduce
        if reduce(lambda x, y: x * y, self.std) == 0:
            raise ValueError('{}: std is invalid!'.format(self))

    def __call__(self, im: np.ndarray, label: np.ndarray = None) -> Tuple:
        """
        Args:
            im (np.ndarray): The Image data.
            label (np.ndarray, optional): The label data. Default: None.

        Returns:
            (tuple). When label is None, it returns (im, ), otherwise it returns (im, label).
        """
        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
        std = np.array(self.std)[np.newaxis, np.newaxis, :]
        im = F.normalize(im, mean, std)

        if label is None:
            return (im, )
        else:
            return (im, label)


class Resize:
    """
    Resize an image.

    Args:
        target_size (list|tuple, optional): The target size of image. Default: (512, 512).
        interp (str, optional): The interpolation mode of resize is consistent with opencv.
            ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']. Note that when it is
            'RANDOM', a random interpolation mode would be specified. Default: "LINEAR".

    Raises:
        TypeError: When 'target_size' type is neither list nor tuple.
        ValueError: When "interp" is out of pre-defined methods ('NEAREST', 'LINEAR', 'CUBIC',
            'AREA', 'LANCZOS4', 'RANDOM').
    """

    # The interpolation mode
    interp_dict = {
        'NEAREST': cv2.INTER_NEAREST,
        'LINEAR': cv2.INTER_LINEAR,
        'CUBIC': cv2.INTER_CUBIC,
        'AREA': cv2.INTER_AREA,
        'LANCZOS4': cv2.INTER_LANCZOS4
    }

    def __init__(self, target_size: Union[List[int], Tuple[int]] = (512, 512), interp: str = 'LINEAR'):
        self.interp = interp
        if not (interp == "RANDOM" or interp in self.interp_dict):
            raise ValueError("`interp` should be one of {}".format(self.interp_dict.keys()))
        if isinstance(target_size, list) or isinstance(target_size, tuple):
            if len(target_size) != 2:
                raise ValueError('`target_size` should include 2 elements, but it is {}'.format(target_size))
        else:
            raise TypeError("Type of `target_size` is invalid. It should be list or tuple, but it is {}".format(
                type(target_size)))

        self.target_size = target_size

    def __call__(self, im: np.ndarray, label: np.ndarray = None) -> Tuple:
        """
        Args:
            im (np.ndarray): The Image data.
            label (np.ndarray, optional): The label data. Default: None.

        Returns:
            (tuple). When label is None, it returns (im, ), otherwise it returns (im, label),

        Raises:
            TypeError: When the 'img' type is not numpy.
            ValueError: When the length of "im" shape is not 3.
        """
        if not isinstance(im, np.ndarray):
            raise TypeError("Resize: image type is not numpy.")
        if len(im.shape) != 3:
            raise ValueError('Resize: image is not 3-dimensional.')
        if self.interp == "RANDOM":
            interp = random.choice(list(self.interp_dict.keys()))
        else:
            interp = self.interp
        im = F.resize(im, self.target_size, self.interp_dict[interp])
        if label is not None:
            label = F.resize(label, self.target_size, cv2.INTER_NEAREST)

        if label is None:
            return (im, )
        else:
            return (im, label)
\ No newline at end of file
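
Note on the colour map: `ColorMap.__call__` above builds each class colour by spreading the bits of the class index across the R, G and B channels, starting from the most significant bit of each channel. The sketch below isolates that bit interleaving for a single index. The function name `label_color` is hypothetical and is not part of this commit; it also does not apply the offset that `ColorMap` introduces by computing `num_classes + 1` entries and dropping the first one.

```python
from typing import Tuple


def label_color(label: int) -> Tuple[int, int, int]:
    """Hypothetical helper: colour for one index, using the same bit spreading as ColorMap."""
    r = g = b = 0
    j = 0
    lab = label
    while lab:
        # Bits 0, 1 and 2 of the current group go to R, G and B,
        # placed at bit position (7 - j) so early groups are brightest.
        r |= ((lab >> 0) & 1) << (7 - j)
        g |= ((lab >> 1) & 1) << (7 - j)
        b |= ((lab >> 2) & 1) << (7 - j)
        j += 1
        lab >>= 3
    return (r, g, b)


first_colors = [label_color(i) for i in range(4)]
# first_colors == [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0)]
```

Because `ColorMap` discards the first entry, class 0 in the committed code ends up with the colour this sketch assigns to index 1.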
paddlehub/vision/utils.py
View file @ 6112bd38
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,11 +13,14 @@
 # limitations under the License.

 import os
+from typing import Callable, Union, List, Tuple

+import cv2
 import paddle
 import PIL
 import numpy as np
 import matplotlib as plt
+import paddle.nn.functional as F


 def is_image_file(filename: str) -> bool:
@@ -26,7 +29,7 @@ def is_image_file(filename: str) -> bool:
     return ext in ['.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff']


-def get_img_file(dir_name: str) -> list:
+def get_img_file(dir_name: str) -> List[str]:
     '''Get all image file paths in several directories which have the same parent directory.'''
     images = []
     for parent, _, filenames in os.walk(dir_name):
@@ -39,7 +42,7 @@ def get_img_file(dir_name: str) -> list:
     return images


-def box_crop(boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: list, img_shape: list):
+def box_crop(boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: List[int], img_shape: List[int]) -> Tuple:
     """Crop the boxes ,labels, scores according to the given shape"""

     x, y, w, h = map(float, crop)
@@ -99,7 +102,7 @@ def draw_boxes_on_image(image_path: str,
                         boxes: np.ndarray,
                         scores: np.ndarray,
                         labels: np.ndarray,
-                        label_names: list,
+                        label_names: List[str],
                         score_thresh: float = 0.5,
                         save_path: str = 'result'):
     """Draw boxes on images."""
@@ -145,7 +148,7 @@ def draw_boxes_on_image(image_path: str,
     plt.close('all')


-def get_label_infos(file_list: str):
+def get_label_infos(file_list: str) -> str:
     """Get label names by corresponding category ids."""
     from pycocotools.coco import COCO
     map_label = COCO(file_list)
@@ -175,10 +178,115 @@ def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
     return gram


-def npmax(array: np.ndarray):
+def npmax(array: np.ndarray) -> Tuple[int]:
     """Get max value and index."""
     arrayindex = array.argmax(1)
     arrayvalue = array.max(1)
     i = arrayvalue.argmax()
     j = arrayindex[i]
     return i, j
+
+
+def visualize(image: Union[np.ndarray, str], result: np.ndarray, weight: float = 0.6) -> np.ndarray:
+    """
+    Convert segmentation result to color image, and save added image.
+
+    Args:
+        image (str|np.ndarray): The path of origin image or bgr image.
+        result (np.ndarray): The predict result of image.
+        weight (float): The image weight of visual image, and the result weight is (1 - weight). Default: 0.6
+
+    Returns:
+        vis_result (np.ndarray): return the visualized result.
+    """
+    color_map = get_color_map_list(256)
+    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
+    color_map = np.array(color_map).astype("uint8")
+    # Use OpenCV LUT for color mapping
+    c1 = cv2.LUT(result, color_map[:, 0])
+    c2 = cv2.LUT(result, color_map[:, 1])
+    c3 = cv2.LUT(result, color_map[:, 2])
+    pseudo_img = np.dstack((c1, c2, c3))
+    if isinstance(image, str):
+        im = cv2.imread(image)
+    else:
+        im = image
+    vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
+
+    return vis_result
+
+
+def get_pseudo_color_map(pred: np.ndarray) -> PIL.Image.Image:
+    '''visualization the segmentation mask.'''
+    pred_mask = PIL.Image.fromarray(pred.astype(np.uint8), mode='P')
+    color_map = get_color_map_list(256)
+    pred_mask.putpalette(color_map)
+    return pred_mask
+
+
+def get_color_map_list(num_classes: int) -> List[int]:
+    """
+    Returns the color map for visualizing the segmentation mask,
+    which can support arbitrary number of classes.
+
+    Args:
+        num_classes (int): Number of classes.
+
+    Returns:
+        (list). The color map.
+    """
+    num_classes += 1
+    color_map = num_classes * [0, 0, 0]
+    for i in range(0, num_classes):
+        j = 0
+        lab = i
+        while lab:
+            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+            j += 1
+            lab >>= 3
+    color_map = color_map[3:]
+    return color_map
+
+
+def get_reverse_list(ori_shape: List[int], transforms: List[Callable]) -> List[tuple]:
+    """
+    get reverse list of transform.
+
+    Args:
+        ori_shape (list): Origin shape of image.
+        transforms (list): List of transform.
+
+    Returns:
+        list: List of tuple, there are two format:
+            ('resize', (h, w)) The image shape before resize,
+            ('padding', (h, w)) The image shape before padding.
+    """
+    reverse_list = []
+    h, w = ori_shape[0], ori_shape[1]
+    for op in transforms:
+        if op.__class__.__name__ in ['Resize', 'ResizeByLong']:
+            reverse_list.append(('resize', (h, w)))
+            h, w = op.target_size[0], op.target_size[1]
+        if op.__class__.__name__ in ['Padding']:
+            reverse_list.append(('padding', (h, w)))
+            w, h = op.target_size[0], op.target_size[1]
+    return reverse_list
+
+
+def reverse_transform(pred: paddle.Tensor, ori_shape: List[int], transforms: List[int]) -> paddle.Tensor:
+    """recover pred to origin shape"""
+    reverse_list = get_reverse_list(ori_shape, transforms)
+    for item in reverse_list[::-1]:
+        if item[0] == 'resize':
+            h, w = item[1][0], item[1][1]
+            pred = F.interpolate(pred, (h, w), mode='nearest')
+        elif item[0] == 'padding':
+            h, w = item[1][0], item[1][1]
+            pred = pred[:, :, 0:h, 0:w]
+        else:
+            raise Exception("Unexpected info '{}' in im_info".format(item[0]))
+    return pred
\ No newline at end of file
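
Note on `get_reverse_list`: the function records, for each shape-changing transform, the shape the image had before that transform, so that `reverse_transform` can undo the steps in reverse order. The sketch below replays only that bookkeeping in plain Python. The `Resize` and `Padding` classes here are hypothetical stand-ins that carry just a `target_size` attribute, and the sample sizes are illustrative.

```python
from typing import List, Sequence, Tuple


class Resize:
    """Hypothetical stand-in exposing only the attribute the bookkeeping reads."""

    def __init__(self, target_size: Sequence[int]):
        self.target_size = target_size


class Padding:
    """Hypothetical stand-in exposing only the attribute the bookkeeping reads."""

    def __init__(self, target_size: Sequence[int]):
        self.target_size = target_size


def reverse_steps(ori_shape: Sequence[int], transforms: list) -> List[Tuple[str, Tuple[int, int]]]:
    """Record the (h, w) seen before each shape-changing transform."""
    steps = []
    h, w = ori_shape[0], ori_shape[1]
    for op in transforms:
        name = op.__class__.__name__
        if name == 'Resize':
            steps.append(('resize', (h, w)))
            h, w = op.target_size[0], op.target_size[1]
        if name == 'Padding':
            steps.append(('padding', (h, w)))
            # Same unpacking order as the committed code: target_size is read as (w, h) here.
            w, h = op.target_size[0], op.target_size[1]
    return steps


steps = reverse_steps((300, 400), [Resize((512, 512)), Padding((600, 600))])
```

Walking `steps` from last to first gives the crop size for the padding step, then the interpolation size for the resize step, which is the order `reverse_transform` applies.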