Unverified commit 1e55536f, authored by haoyuying, committed by GitHub

add semantic segmentation task

Parent d76d727e
# PaddleHub Image Segmentation
This example shows how to use PaddleHub to fine-tune a pretrained model and run prediction with it.
## How to Start Fine-tuning
After installing PaddlePaddle and PaddleHub, run `python train.py` to start fine-tuning the ocrnet_hrnetw18_voc model on datasets such as OpticDiscSeg.
## Code Steps
Fine-tuning with the PaddleHub Fine-tune API takes four steps.
### Step1: Define the Data Preprocessing Pipeline
```python
from paddlehub.vision.segmentation_transforms import Compose, Resize, Normalize
transform = Compose([Resize(target_size=(512, 512)), Normalize()])
```
The `segmentation_transforms` module provides a rich set of preprocessing operations for image segmentation data; you can replace them with whatever preprocessing your task needs, as sketched below.
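For example, a custom operation only needs to follow the `op(im, label)` calling convention that `Compose` uses (each op returns `(im,)` or `(im, label)`; see the `Compose` implementation later in this commit). The `RandomHorizontalFlip` class below is a hypothetical sketch, not an op shipped with PaddleHub:
```python
import numpy as np

class RandomHorizontalFlip:
    """Hypothetical custom op: flip image and mask together with probability `prob`."""

    def __init__(self, prob: float = 0.5):
        self.prob = prob

    def __call__(self, im, label=None):
        if np.random.random() < self.prob:
            im = im[:, ::-1, :]         # flip the HWC image along the width axis
            if label is not None:
                label = label[:, ::-1]  # keep the HW mask aligned with the image
        return (im,) if label is None else (im, label)

transform = Compose([RandomHorizontalFlip(), Resize(target_size=(512, 512)), Normalize()])
```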
### Step2: Download and Use the Dataset
```python
from paddlehub.datasets import OpticDiscSeg
train_reader = OpticDiscSeg(transform, mode='train')
```
* `transform`: the data preprocessing pipeline.
* `mode`: the dataset split to use; options are `train`, `test`, and `val`, defaults to `train`.
The dataset preparation code is in [opticdiscseg.py](../../paddlehub/datasets/opticdiscseg.py). `hub.datasets.OpticDiscSeg()` automatically downloads the dataset and extracts it to the `$HOME/.paddlehub/dataset` directory under the user's home.
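The same constructor builds the other splits; for example:
```python
val_reader = OpticDiscSeg(transform, mode='val')
test_reader = OpticDiscSeg(transform, mode='test')
```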
### Step3: Load the Pretrained Model
```python
import paddlehub as hub

model = hub.Module(name='ocrnet_hrnetw18_voc', num_classes=2, pretrained=None)
```
* `name`: the name of the pretrained model.
* `num_classes`: the number of segmentation classes.
* `pretrained`: path to your own trained weights; if None, the model's default pretrained parameters are loaded.
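To continue fine-tuning from weights you trained yourself, pass their path instead (the path below is a placeholder):
```python
model = hub.Module(name='ocrnet_hrnetw18_voc', num_classes=2, pretrained='/PATH/TO/CHECKPOINT')
```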
### Step4: Choose the Optimization Strategy and Runtime Configuration
```python
import paddle
from paddlehub.finetune.trainer import Trainer

scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=1000, power=0.9, end_lr=0.0001)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_ocr', use_gpu=True)
```
#### Optimization Strategy
Paddle 2.0rc provides a variety of optimizers, such as `SGD`, `Adam`, and `Adamax`; see the [optimizer documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc/api/paddle/optimizer/optimizer/Optimizer_cn.html) for details.
For `Adam`:
* `learning_rate`: the global learning rate.
* `parameters`: the model parameters to optimize.
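Swapping in another optimizer only changes one line; a sketch using `paddle.optimizer.Momentum` with the scheduler defined above:
```python
optimizer = paddle.optimizer.Momentum(
    learning_rate=scheduler, momentum=0.9, parameters=model.parameters())
```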
#### Runtime Configuration
`Trainer` controls the overall fine-tuning run and accepts the following parameters:
* `model`: the model to be optimized;
* `optimizer`: the optimizer to use;
* `use_gpu`: whether to use the GPU, defaults to False;
* `use_vdl`: whether to use VisualDL to visualize the training process;
* `checkpoint_dir`: the directory in which model parameters are saved;
* `compare_metrics`: the metric used to select the best model;
`trainer.train` controls the training loop itself and accepts the following parameters (a complete call is sketched after this list):
* `train_dataset`: the dataset used for training;
* `epochs`: the number of training epochs;
* `batch_size`: the training batch size; if training on GPU, adjust batch_size to the available memory;
* `num_workers`: the number of workers, defaults to 0;
* `eval_dataset`: the validation dataset;
* `log_interval`: the logging interval, measured in training steps.
* `save_interval`: the checkpoint-saving interval, measured in epochs.
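Putting the pieces together, a training call mirroring `train.py` from this commit (note that `train.py` passes the training set as `eval_dataset`; the `val_reader` built in Step2 is used here instead):
```python
trainer.train(
    train_reader,
    epochs=20,
    batch_size=4,
    eval_dataset=val_reader,  # train.py in this commit reuses train_reader here
    log_interval=10,
    save_interval=4)
```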
## Model Prediction
When fine-tuning is complete, the model that performed best on the validation set is saved under `${CHECKPOINT_DIR}/best_model`, where `${CHECKPOINT_DIR}` is the checkpoint directory chosen for fine-tuning.
We use that model for prediction. The predict.py script is as follows:
```python
import paddle
import cv2
import paddlehub as hub

if __name__ == '__main__':
    model = hub.Module(name='ocrnet_hrnetw18_voc', pretrained='/PATH/TO/CHECKPOINT')
    img = cv2.imread("/PATH/TO/IMAGE")
    model.predict(images=[img], visualization=True)
```
After the parameters are configured correctly, run the script with `python predict.py`.
**Args**
* `images`: image paths or BGR-format images;
* `visualization`: whether to visualize the results, defaults to True;
* `save_path`: the directory in which results are saved, defaults to 'seg_result'.
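For example, to write the results to a different directory (the directory name below is arbitrary):
```python
model.predict(images=[img], visualization=True, save_path='optic_disc_output')
```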
**NOTE:** at prediction time, the module, checkpoint_dir, and dataset must be the same as those used for fine-tuning.
## Service Deployment
PaddleHub Serving can deploy an online image segmentation service.
### Step1: Start PaddleHub Serving
Run the start command:
```shell
$ hub serving start -m ocrnet_hrnetw18_voc
```
This deploys an image segmentation API service; the default port is 8866.
**NOTE:** to predict on GPU, set the CUDA_VISIBLE_DEVICES environment variable (e.g. `export CUDA_VISIBLE_DEVICES=0`) before starting the service; otherwise it does not need to be set.
### Step2: Send a Prediction Request
With the server configured, the few lines of code below send a prediction request and fetch the result:
```python
import requests
import json
import cv2
import base64
import numpy as np
def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')

def base64_to_cv2(b64str):
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data

# Send the HTTP request
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {'images':[cv2_to_base64(org_im)]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ocrnet_hrnetw18_voc"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
mask = base64_to_cv2(r.json()["results"][0])
```
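The decoded `mask` is an ordinary OpenCV image array, so it can be saved or post-processed directly, e.g.:
```python
cv2.imwrite('mask.png', mask)
```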
### View Code
https://github.com/PaddlePaddle/PaddleSeg
### Dependencies
paddlepaddle >= 2.0.0rc
paddlehub >= 2.0.0
import paddle
import paddlehub as hub
if __name__ == '__main__':
    model = hub.Module(name='ocrnet_hrnetw18_voc', num_classes=2, pretrained='/PATH/TO/CHECKPOINT')
    model.predict(images=["N0007.jpg"], visualization=True)
\ No newline at end of file
import paddle
import paddlehub as hub
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets import OpticDiscSeg
from paddlehub.vision.segmentation_transforms import Compose, Resize, Normalize
if __name__ == "__main__":
    transform = Compose([Resize(target_size=(512, 512)), Normalize()])
    train_reader = OpticDiscSeg(transform)
    model = hub.Module(name='ocrnet_hrnetw18_voc', num_classes=2)
    scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=1000, power=0.9, end_lr=0.0001)
    optimizer = paddle.optimizer.Adam(learning_rate=scheduler, parameters=model.parameters())
    trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_ocr', use_gpu=True)
    trainer.train(train_reader, epochs=20, batch_size=4, eval_dataset=train_reader, log_interval=10, save_interval=4)
\ No newline at end of file
@@ -39,3 +39,18 @@ Dataset for Style transfer. The dataset contains 2001 images for training set an
**Args**
* transforms(callmethod): The method to preprocess images.
* mode(str): The mode for preparing the dataset.
# Class `hub.datasets.OpticDiscSeg`
```python
hub.datasets.OpticDiscSeg(
transforms: Callable,
mode: str = 'train')
```
Dataset for semantic segmentation. The dataset contains 267 images for the training set, 76 images for the validation set, and 38 images for the test set.
**Args**
* transforms(callmethod): The method to preprocess images.
* mode(str): The mode for preparing the dataset.
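A minimal usage sketch; the transform choice follows the README in this commit:
```python
import paddlehub as hub
from paddlehub.vision.segmentation_transforms import Compose, Resize, Normalize

transform = Compose([Resize(target_size=(512, 512)), Normalize()])
train_data = hub.datasets.OpticDiscSeg(transform, mode='train')
```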
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.layer import activation
from paddle.nn import Conv2D, AvgPool2D
def SyncBatchNorm(*args, **kwargs):
"""In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead"""
if paddle.get_device() == 'cpu':
return nn.BatchNorm2D(*args, **kwargs)
else:
return nn.SyncBatchNorm(*args, **kwargs)
class ConvBNLayer(nn.Layer):
"""Basic conv bn relu layer."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
is_vd_mode: bool = False,
act: str = None,
name: str = None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
dilation=dilation,
groups=groups,
bias_attr=False)
self._batch_norm = SyncBatchNorm(out_channels)
self._act_op = Activation(act=act)
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
y = self._act_op(y)
return y
class BottleneckBlock(nn.Layer):
"""Residual bottleneck block"""
def __init__(self,
in_channels: int,
out_channels: int,
stride: int,
shortcut: bool = True,
if_first: bool = False,
dilation: int = 1,
name: str = None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.dilation = dilation
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
dilation=dilation,
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first or stride == 1 else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
y = self.conv0(inputs)
if self.dilation > 1:
padding = self.dilation
y = F.pad(y, [padding, padding, padding, padding])
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class SeparableConvBNReLU(nn.Layer):
"""Depthwise Separable Convolution."""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = 'same',
**kwargs: dict):
super(SeparableConvBNReLU, self).__init__()
self.depthwise_conv = ConvBN(
in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
groups=in_channels,
**kwargs)
        self.pointwise_conv = ConvBNReLU(
            in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
return x
class ConvBN(nn.Layer):
"""Basic conv bn layer"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = 'same',
**kwargs: dict):
super(ConvBN, self).__init__()
self._conv = Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self._conv(x)
x = self._batch_norm(x)
return x
class ConvBNReLU(nn.Layer):
"""Basic conv bn relu layer."""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = 'same',
**kwargs: dict):
super(ConvBNReLU, self).__init__()
self._conv = Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self._conv(x)
x = self._batch_norm(x)
x = F.relu(x)
return x
class Activation(nn.Layer):
"""
The wrapper of activations.
Args:
act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu',
'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid',
'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax',
'hsigmoid']. Default: None, means identical transformation.
Returns:
A callable object of Activation.
Raises:
KeyError: When parameter `act` is not in the optional range.
Examples:
from paddleseg.models.common.activation import Activation
relu = Activation("relu")
print(relu)
# <class 'paddle.nn.layer.activation.ReLU'>
sigmoid = Activation("sigmoid")
print(sigmoid)
# <class 'paddle.nn.layer.activation.Sigmoid'>
not_exit_one = Activation("not_exit_one")
# KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
# 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
# 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
"""
def __init__(self, act: str = None):
super(Activation, self).__init__()
self._act = act
upper_act_names = activation.__all__
lower_act_names = [act.lower() for act in upper_act_names]
act_dict = dict(zip(lower_act_names, upper_act_names))
if act is not None:
if act in act_dict.keys():
act_name = act_dict[act]
                self.act_func = getattr(activation, act_name)()
else:
raise KeyError("{} does not exist in the current {}".format(
act, act_dict.keys()))
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
if self._act is not None:
return self.act_func(x)
else:
return x
class ASPPModule(nn.Layer):
"""
Atrous Spatial Pyramid Pooling.
Args:
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
use_sep_conv (bool, optional): If using separable conv in ASPP module. Default: False.
image_pooling (bool, optional): If augmented with image-level features. Default: False
"""
def __init__(self,
aspp_ratios: tuple,
in_channels: int,
out_channels: int,
align_corners: bool,
use_sep_conv: bool= False,
image_pooling: bool = False):
super().__init__()
self.align_corners = align_corners
self.aspp_blocks = nn.LayerList()
for ratio in aspp_ratios:
if use_sep_conv and ratio > 1:
conv_func = SeparableConvBNReLU
else:
conv_func = ConvBNReLU
block = conv_func(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1 if ratio == 1 else 3,
dilation=ratio,
padding=0 if ratio == 1 else ratio)
self.aspp_blocks.append(block)
out_size = len(self.aspp_blocks)
if image_pooling:
self.global_avg_pool = nn.Sequential(
nn.AdaptiveAvgPool2D(output_size=(1, 1)),
ConvBNReLU(in_channels, out_channels, kernel_size=1, bias_attr=False))
out_size += 1
self.image_pooling = image_pooling
self.conv_bn_relu = ConvBNReLU(
in_channels=out_channels * out_size,
out_channels=out_channels,
kernel_size=1)
self.dropout = nn.Dropout(p=0.1) # drop rate
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
outputs = []
for block in self.aspp_blocks:
y = block(x)
y = F.interpolate(
y,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
outputs.append(y)
if self.image_pooling:
img_avg = self.global_avg_pool(x)
img_avg = F.interpolate(
img_avg,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
outputs.append(img_avg)
x = paddle.concat(outputs, axis=1)
x = self.conv_bn_relu(x)
x = self.dropout(x)
return x
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Union, List, Tuple
import paddle
from paddle import nn
import paddle.nn.functional as F
import numpy as np
from paddlehub.module.module import moduleinfo
import paddlehub.vision.segmentation_transforms as T
from paddlehub.module.cv_module import ImageSegmentationModule
from deeplabv3p_resnet50.resnet import ResNet50_vd
import deeplabv3p_resnet50.layers as L
@moduleinfo(
name="deeplabv3p_resnet50_voc",
type="CV/semantic_segmentation",
author="paddlepaddle",
author_email="",
summary="DeepLabV3PResnet50 is a segmentation model.",
version="1.0.0",
meta=ImageSegmentationModule)
class DeepLabV3PResnet50(nn.Layer):
"""
The DeepLabV3PResnet50 implementation based on PaddlePaddle.
The original article refers to
Liang-Chieh Chen, et, al. "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"
(https://arxiv.org/abs/1802.02611)
Args:
num_classes (int): the unique number of target classes.
backbone_indices (tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a low-level feature in Decoder component;
the second one will be taken as input of ASPP component.
            Usually the backbone consists of four downsampling stages, each returning an output,
            so we set the default to (0, 3), which means taking the feature map of the first
            stage in the backbone as the low-level feature used in the Decoder, and the feature map of the fourth
            stage as the input of ASPP.
        aspp_ratios (tuple): the dilation rates used in the ASPP module.
            If output_stride=16, aspp_ratios should be set to (1, 6, 12, 18).
            If output_stride=8, aspp_ratios is (1, 12, 24, 36).
aspp_out_channels (int): the output channels of ASPP module.
align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
pretrained (str): the path of pretrained model. Default to None.
"""
def __init__(self,
num_classes: int = 21,
backbone_indices: Tuple[int] = (0, 3),
aspp_ratios: Tuple[int] = (1, 12, 24, 36),
aspp_out_channels: int = 256,
align_corners=False,
pretrained: str = None):
super(DeepLabV3PResnet50, self).__init__()
self.backbone = ResNet50_vd()
backbone_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
self.head = DeepLabV3PHead(num_classes, backbone_indices,
backbone_channels, aspp_ratios,
aspp_out_channels, align_corners)
self.align_corners = align_corners
self.transforms = T.Compose([T.Padding(target_size=(512, 512)), T.Normalize()])
if pretrained is not None:
model_dict = paddle.load(pretrained)
self.set_dict(model_dict)
print("load custom parameters success")
else:
checkpoint = os.path.join(self.directory, 'deeplabv3p_model.pdparams')
model_dict = paddle.load(checkpoint)
self.set_dict(model_dict)
print("load pretrained parameters success")
def transform(self, img: Union[np.ndarray, str]) -> Union[np.ndarray, str]:
return self.transforms(img)
def forward(self, x: paddle.Tensor) -> List[paddle.Tensor]:
feat_list = self.backbone(x)
logit_list = self.head(feat_list)
return [
F.interpolate(
logit,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners) for logit in logit_list]
class DeepLabV3PHead(nn.Layer):
"""
The DeepLabV3PHead implementation based on PaddlePaddle.
Args:
num_classes (int): The unique number of target classes.
backbone_indices (tuple): Two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a low-level feature in Decoder component;
the second one will be taken as input of ASPP component.
            Usually the backbone consists of four downsampling stages, each returning an output.
            Setting it to (0, 3) means taking the feature map of the first
            stage in the backbone as the low-level feature used in the Decoder, and the feature map of the fourth
            stage as the input of ASPP.
backbone_channels (tuple): The same length with "backbone_indices". It indicates the channels of corresponding index.
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
aspp_out_channels (int): The output channels of ASPP module.
align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
"""
def __init__(self, num_classes: int, backbone_indices: Tuple[paddle.Tensor], backbone_channels: Tuple[paddle.Tensor],
aspp_ratios: Tuple[float], aspp_out_channels: int, align_corners: bool):
super().__init__()
self.aspp = L.ASPPModule(
aspp_ratios,
backbone_channels[1],
aspp_out_channels,
align_corners,
use_sep_conv=True,
image_pooling=True)
self.decoder = Decoder(num_classes, backbone_channels[0], align_corners)
self.backbone_indices = backbone_indices
def forward(self, feat_list: List[paddle.Tensor]) -> List[paddle.Tensor]:
logit_list = []
low_level_feat = feat_list[self.backbone_indices[0]]
x = feat_list[self.backbone_indices[1]]
x = self.aspp(x)
logit = self.decoder(x, low_level_feat)
logit_list.append(logit)
return logit_list
class Decoder(nn.Layer):
"""
Decoder module of DeepLabV3P model
Args:
num_classes (int): The number of classes.
in_channels (int): The number of input channels in decoder module.
align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
"""
def __init__(self, num_classes: int, in_channels: int, align_corners: bool):
super(Decoder, self).__init__()
self.conv_bn_relu1 = L.ConvBNReLU(
in_channels=in_channels, out_channels=48, kernel_size=1)
self.conv_bn_relu2 = L.SeparableConvBNReLU(
in_channels=304, out_channels=256, kernel_size=3, padding=1)
self.conv_bn_relu3 = L.SeparableConvBNReLU(
in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv = nn.Conv2D(
in_channels=256, out_channels=num_classes, kernel_size=1)
self.align_corners = align_corners
def forward(self, x: paddle.Tensor, low_level_feat: paddle.Tensor) -> paddle.Tensor:
low_level_feat = self.conv_bn_relu1(low_level_feat)
x = F.interpolate(
x,
low_level_feat.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
x = paddle.concat([x, low_level_feat], axis=1)
x = self.conv_bn_relu2(x)
x = self.conv_bn_relu3(x)
x = self.conv(x)
return x
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import deeplabv3p_resnet50.layers as L
class BasicBlock(nn.Layer):
def __init__(self,
in_channels: int,
out_channels: int,
stride: int,
shortcut: bool = True,
if_first: bool = False,
name: str = None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = L.ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = L.ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = L.ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
        # paddle.elementwise_add was removed in Paddle 2.0; use paddle.add + F.relu,
        # matching the BottleneckBlock implementation in layers.py.
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)
return y
class ResNet50_vd(nn.Layer):
def __init__(self,
multi_grid: tuple = (1, 2, 4)):
super(ResNet50_vd, self).__init__()
depth = [3, 4, 6, 3]
num_channels = [64, 256, 512, 1024]
num_filters = [64, 128, 256, 512]
self.feat_channels = [c * 4 for c in num_filters]
dilation_dict = {2: 2, 3: 4}
self.conv1_1 = L.ConvBNLayer(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = L.ConvBNLayer(
in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = L.ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stage_list = []
for block in range(len(depth)):
shortcut = False
block_list = []
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
dilation_rate = dilation_dict[
block] if dilation_dict and block in dilation_dict else 1
if block == 3:
dilation_rate = dilation_rate * multi_grid[i]
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
L.BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0
and dilation_rate == 1 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name,
dilation=dilation_rate))
block_list.append(bottleneck_block)
shortcut = True
self.stage_list.append(block_list)
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
feat_list = []
for stage in self.stage_list:
for block in stage:
y = block(y)
feat_list.append(y)
return feat_list
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.layer import activation
from paddle.nn import Conv2D, AvgPool2D
def SyncBatchNorm(*args, **kwargs):
"""In cpu environment nn.SyncBatchNorm does not have kernel so use nn.BatchNorm2D instead"""
if paddle.get_device() == 'cpu':
return nn.BatchNorm2D(*args, **kwargs)
else:
return nn.SyncBatchNorm(*args, **kwargs)
class ConvBNLayer(nn.Layer):
"""Basic conv bn relu layer."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
is_vd_mode: bool = False,
act: str = None,
name: str = None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
dilation=dilation,
groups=groups,
bias_attr=False)
self._batch_norm = SyncBatchNorm(out_channels)
self._act_op = Activation(act=act)
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
y = self._act_op(y)
return y
class BottleneckBlock(nn.Layer):
"""Residual bottleneck block"""
def __init__(self,
in_channels: int,
out_channels: int,
stride: int,
shortcut: bool = True,
if_first: bool = False,
dilation: int = 1,
name: str = None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.dilation = dilation
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
dilation=dilation,
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first or stride == 1 else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
y = self.conv0(inputs)
if self.dilation > 1:
padding = self.dilation
y = F.pad(y, [padding, padding, padding, padding])
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class SeparableConvBNReLU(nn.Layer):
"""Depthwise Separable Convolution."""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = 'same',
**kwargs: dict):
super(SeparableConvBNReLU, self).__init__()
self.depthwise_conv = ConvBN(
in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
groups=in_channels,
**kwargs)
        self.pointwise_conv = ConvBNReLU(
            in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
return x
class ConvBN(nn.Layer):
"""Basic conv bn layer"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = 'same',
**kwargs: dict):
super(ConvBN, self).__init__()
self._conv = Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self._conv(x)
x = self._batch_norm(x)
return x
class ConvBNReLU(nn.Layer):
"""Basic conv bn relu layer."""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = 'same',
**kwargs: dict):
super(ConvBNReLU, self).__init__()
self._conv = Conv2D(
in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self._conv(x)
x = self._batch_norm(x)
x = F.relu(x)
return x
class Activation(nn.Layer):
"""
The wrapper of activations.
Args:
act (str, optional): The activation name in lowercase. It must be one of ['elu', 'gelu',
'hardshrink', 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid',
'softmax', 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax',
'hsigmoid']. Default: None, means identical transformation.
Returns:
A callable object of Activation.
Raises:
KeyError: When parameter `act` is not in the optional range.
Examples:
from paddleseg.models.common.activation import Activation
relu = Activation("relu")
print(relu)
# <class 'paddle.nn.layer.activation.ReLU'>
sigmoid = Activation("sigmoid")
print(sigmoid)
# <class 'paddle.nn.layer.activation.Sigmoid'>
not_exit_one = Activation("not_exit_one")
# KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
# 'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
# 'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
"""
def __init__(self, act: str = None):
super(Activation, self).__init__()
self._act = act
upper_act_names = activation.__all__
lower_act_names = [act.lower() for act in upper_act_names]
act_dict = dict(zip(lower_act_names, upper_act_names))
if act is not None:
if act in act_dict.keys():
act_name = act_dict[act]
                self.act_func = getattr(activation, act_name)()
else:
raise KeyError("{} does not exist in the current {}".format(
act, act_dict.keys()))
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
if self._act is not None:
return self.act_func(x)
else:
return x
class ASPPModule(nn.Layer):
"""
Atrous Spatial Pyramid Pooling.
Args:
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
use_sep_conv (bool, optional): If using separable conv in ASPP module. Default: False.
image_pooling (bool, optional): If augmented with image-level features. Default: False
"""
def __init__(self,
aspp_ratios,
in_channels,
out_channels,
align_corners,
use_sep_conv=False,
image_pooling=False):
super().__init__()
self.align_corners = align_corners
self.aspp_blocks = nn.LayerList()
for ratio in aspp_ratios:
if use_sep_conv and ratio > 1:
conv_func = SeparableConvBNReLU
else:
conv_func = ConvBNReLU
block = conv_func(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1 if ratio == 1 else 3,
dilation=ratio,
padding=0 if ratio == 1 else ratio)
self.aspp_blocks.append(block)
out_size = len(self.aspp_blocks)
if image_pooling:
self.global_avg_pool = nn.Sequential(
nn.AdaptiveAvgPool2D(output_size=(1, 1)),
ConvBNReLU(in_channels, out_channels, kernel_size=1, bias_attr=False))
out_size += 1
self.image_pooling = image_pooling
self.conv_bn_relu = ConvBNReLU(
in_channels=out_channels * out_size,
out_channels=out_channels,
kernel_size=1)
self.dropout = nn.Dropout(p=0.1) # drop rate
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
outputs = []
for block in self.aspp_blocks:
y = block(x)
y = F.interpolate(
y,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
outputs.append(y)
if self.image_pooling:
img_avg = self.global_avg_pool(x)
img_avg = F.interpolate(
img_avg,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
outputs.append(img_avg)
x = paddle.concat(outputs, axis=1)
x = self.conv_bn_relu(x)
x = self.dropout(x)
return x
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
import paddle
import numpy as np
import paddle.nn as nn
import paddle.nn.functional as F
from paddlehub.module.module import moduleinfo
import paddlehub.vision.segmentation_transforms as T
from paddlehub.module.cv_module import ImageSegmentationModule
import ocrnet_hrnetw18.layers as L
from ocrnet_hrnetw18.hrnet import HRNet_W18
@moduleinfo(
name="ocrnet_hrnetw18_voc",
type="CV/semantic_segmentation",
author="paddlepaddle",
author_email="",
summary="OCRNetHRNetW18 is a segmentation model pretrained by pascal voc.",
version="1.0.0",
meta=ImageSegmentationModule)
class OCRNetHRNetW18(nn.Layer):
"""
The OCRNet implementation based on PaddlePaddle.
The original article refers to
Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
(https://arxiv.org/pdf/1909.11065.pdf)
Args:
num_classes (int): The unique number of target classes.
backbone_indices (list): A list indicates the indices of output of backbone.
It can be either one or two values, if two values, the first index will be taken as
a deep-supervision feature in auxiliary layer; the second one will be taken as
input of pixel representation. If one value, it is taken by both above.
ocr_mid_channels (int, optional): The number of middle channels in OCRHead. Default: 512.
ocr_key_channels (int, optional): The number of key channels in ObjectAttentionBlock. Default: 256.
align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
is even, e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
pretrained (str, optional): The path or url of pretrained model. Default: None.
"""
def __init__(self,
num_classes: int = 21,
backbone_indices: List[int] = [0],
ocr_mid_channels: int = 512,
ocr_key_channels: int = 256,
align_corners: bool = False,
pretrained: str = None):
super(OCRNetHRNetW18, self).__init__()
self.backbone = HRNet_W18()
self.backbone_indices = backbone_indices
in_channels = [self.backbone.feat_channels[i] for i in backbone_indices]
self.head = OCRHead(
num_classes=num_classes,
in_channels=in_channels,
ocr_mid_channels=ocr_mid_channels,
ocr_key_channels=ocr_key_channels)
self.align_corners = align_corners
self.transforms = T.Compose([T.Padding(target_size=(512, 512)), T.Normalize()])
if pretrained is not None:
model_dict = paddle.load(pretrained)
self.set_dict(model_dict)
print("load custom parameters success")
else:
checkpoint = os.path.join(self.directory, 'ocrnet_hrnetw18.pdparams')
model_dict = paddle.load(checkpoint)
self.set_dict(model_dict)
print("load pretrained parameters success")
def transform(self, img: np.ndarray) -> np.ndarray:
return self.transforms(img)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
logit_list = self.head(feats)
logit_list = [
F.interpolate(
logit,
x.shape[2:],
mode='bilinear',
align_corners=self.align_corners) for logit in logit_list]
return logit_list
class OCRHead(nn.Layer):
"""
The Object contextual representation head.
Args:
num_classes(int): The unique number of target classes.
in_channels(tuple): The number of input channels.
ocr_mid_channels(int, optional): The number of middle channels in OCRHead. Default: 512.
ocr_key_channels(int, optional): The number of key channels in ObjectAttentionBlock. Default: 256.
"""
def __init__(self,
num_classes: int,
in_channels: int,
ocr_mid_channels: int = 512,
ocr_key_channels: int = 256):
super().__init__()
self.num_classes = num_classes
self.spatial_gather = SpatialGatherBlock()
self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
ocr_mid_channels)
self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]
self.conv3x3_ocr = L.ConvBNReLU(
in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
self.cls_head = nn.Conv2D(ocr_mid_channels, self.num_classes, 1)
self.aux_head = nn.Sequential(
L.ConvBNReLU(in_channels[self.indices[0]],
in_channels[self.indices[0]], 1),
nn.Conv2D(in_channels[self.indices[0]], self.num_classes, 1))
def forward(self, feat_list: List[paddle.Tensor]) -> paddle.Tensor:
feat_shallow, feat_deep = feat_list[self.indices[0]], feat_list[
self.indices[1]]
soft_regions = self.aux_head(feat_shallow)
pixels = self.conv3x3_ocr(feat_deep)
object_regions = self.spatial_gather(pixels, soft_regions)
ocr = self.spatial_ocr(pixels, object_regions)
logit = self.cls_head(ocr)
return [logit, soft_regions]
class SpatialGatherBlock(nn.Layer):
"""Aggregation layer to compute the pixel-region representation."""
def forward(self, pixels: paddle.Tensor, regions: paddle.Tensor) -> paddle.Tensor:
n, c, h, w = pixels.shape
_, k, _, _ = regions.shape
# pixels: from (n, c, h, w) to (n, h*w, c)
pixels = paddle.reshape(pixels, (n, c, h * w))
pixels = paddle.transpose(pixels, [0, 2, 1])
# regions: from (n, k, h, w) to (n, k, h*w)
regions = paddle.reshape(regions, (n, k, h * w))
regions = F.softmax(regions, axis=2)
# feats: from (n, k, c) to (n, c, k, 1)
feats = paddle.bmm(regions, pixels)
feats = paddle.transpose(feats, [0, 2, 1])
feats = paddle.unsqueeze(feats, axis=-1)
return feats
class SpatialOCRModule(nn.Layer):
"""Aggregate the global object representation to update the representation for each pixel."""
def __init__(self,
in_channels: int,
key_channels: int,
out_channels: int,
dropout_rate: float = 0.1):
super().__init__()
self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
self.conv1x1 = nn.Sequential(
L.ConvBNReLU(2 * in_channels, out_channels, 1),
nn.Dropout2D(dropout_rate))
def forward(self, pixels: paddle.Tensor, regions: paddle.Tensor) -> paddle.Tensor:
context = self.attention_block(pixels, regions)
feats = paddle.concat([context, pixels], axis=1)
feats = self.conv1x1(feats)
return feats
class ObjectAttentionBlock(nn.Layer):
"""A self-attention module."""
def __init__(self, in_channels: int, key_channels: int):
super().__init__()
self.in_channels = in_channels
self.key_channels = key_channels
self.f_pixel = nn.Sequential(
L.ConvBNReLU(in_channels, key_channels, 1),
L.ConvBNReLU(key_channels, key_channels, 1))
self.f_object = nn.Sequential(
L.ConvBNReLU(in_channels, key_channels, 1),
L.ConvBNReLU(key_channels, key_channels, 1))
self.f_down = L.ConvBNReLU(in_channels, key_channels, 1)
self.f_up = L.ConvBNReLU(key_channels, in_channels, 1)
def forward(self, x: paddle.Tensor, proxy: paddle.Tensor) -> paddle.Tensor:
n, _, h, w = x.shape
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
query = self.f_pixel(x)
query = paddle.reshape(query, (n, self.key_channels, -1))
query = paddle.transpose(query, [0, 2, 1])
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
key = self.f_object(proxy)
key = paddle.reshape(key, (n, self.key_channels, -1))
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
value = self.f_down(proxy)
value = paddle.reshape(value, (n, self.key_channels, -1))
value = paddle.transpose(value, [0, 2, 1])
# sim_map (n, h1*w1, h2*w2)
sim_map = paddle.bmm(query, key)
sim_map = (self.key_channels**-.5) * sim_map
sim_map = F.softmax(sim_map, axis=-1)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
context = paddle.bmm(sim_map, value)
context = paddle.transpose(context, [0, 2, 1])
context = paddle.reshape(context, (n, self.key_channels, h, w))
context = self.f_up(context)
return context
\ No newline at end of file
@@ -17,3 +17,5 @@ from paddlehub.datasets.flowers import Flowers
from paddlehub.datasets.minicoco import MiniCOCO
from paddlehub.datasets.chnsenticorp import ChnSentiCorp
from paddlehub.datasets.msra_ner import MSRA_NER
from paddlehub.datasets.base_seg_dataset import SegDataset
from paddlehub.datasets.opticdiscseg import OpticDiscSeg
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Tuple, Callable
import paddle
import numpy as np
from PIL import Image
class SegDataset(paddle.io.Dataset):
"""
Pass in a custom dataset that conforms to the format.
Args:
transforms (Callable): Transforms for image.
dataset_root (str): The dataset directory.
num_classes (int): Number of classes.
mode (str, optional): which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
train_path (str, optional): The train dataset file. When mode is 'train', train_path is necessary.
The contents of train_path file are as follow:
image1.jpg ground_truth1.png
image2.jpg ground_truth2.png
        val_path (str, optional): The evaluation dataset file. When mode is 'val', val_path is necessary.
            The contents are the same as train_path.
test_path (str, optional): The test dataset file. When mode is 'test', test_path is necessary.
The annotation file is not necessary in test_path file.
separator (str, optional): The separator of dataset list. Default: ' '.
edge (bool, optional): Whether to compute edge while training. Default: False
"""
def __init__(self,
transforms: Callable,
dataset_root: str,
num_classes: int,
mode: str = 'train',
train_path: str = None,
val_path: str = None,
test_path: str = None,
separator: str = ' ',
ignore_index: int = 255,
edge: bool = False):
self.dataset_root = dataset_root
self.transforms = transforms
self.file_list = list()
mode = mode.lower()
self.mode = mode
self.num_classes = num_classes
self.ignore_index = ignore_index
self.edge = edge
if mode.lower() not in ['train', 'val', 'test']:
raise ValueError(
"mode should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise ValueError("`transforms` is necessary, but it is None.")
self.dataset_root = dataset_root
if not os.path.exists(self.dataset_root):
raise FileNotFoundError('there is not `dataset_root`: {}.'.format(
self.dataset_root))
if mode == 'train':
if train_path is None:
raise ValueError(
'When `mode` is "train", `train_path` is necessary, but it is None.'
)
elif not os.path.exists(train_path):
raise FileNotFoundError(
'`train_path` is not found: {}'.format(train_path))
else:
file_path = train_path
elif mode == 'val':
if val_path is None:
raise ValueError(
'When `mode` is "val", `val_path` is necessary, but it is None.'
)
elif not os.path.exists(val_path):
raise FileNotFoundError(
'`val_path` is not found: {}'.format(val_path))
else:
file_path = val_path
else:
if test_path is None:
raise ValueError(
'When `mode` is "test", `test_path` is necessary, but it is None.'
)
elif not os.path.exists(test_path):
raise FileNotFoundError(
'`test_path` is not found: {}'.format(test_path))
else:
file_path = test_path
with open(file_path, 'r') as f:
for line in f:
items = line.strip().split(separator)
if len(items) != 2:
if mode == 'train' or mode == 'val':
raise ValueError(
"File list format incorrect! In training or evaluation task it should be"
" image_name{}label_name\\n".format(separator))
image_path = os.path.join(self.dataset_root, items[0])
label_path = None
else:
image_path = os.path.join(self.dataset_root, items[0])
label_path = os.path.join(self.dataset_root, items[1])
self.file_list.append([image_path, label_path])
def __getitem__(self, idx: int) -> Tuple[np.ndarray]:
image_path, label_path = self.file_list[idx]
if self.mode == 'test':
im, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
return im, image_path
elif self.mode == 'val':
im, _ = self.transforms(im=image_path)
label = np.asarray(Image.open(label_path))
label = label[np.newaxis, :, :]
return im, label
else:
im, label = self.transforms(im=image_path, label=label_path)
return im, label
def __len__(self) -> int:
return len(self.file_list)
# coding:utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable
import paddle
import numpy as np
from PIL import Image
import paddlehub.env as hubenv
from paddlehub.utils.download import download_data
from paddlehub.datasets.base_seg_dataset import SegDataset
@download_data(url='https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip')
class OpticDiscSeg(SegDataset):
"""
    OpticDiscSeg dataset is extracted from iChallenge-AMD
(https://ai.baidu.com/broad/subordinate?dataset=amd).
Args:
transforms (Callable): Transforms for image.
mode (str, optional): Which part of dataset to use. it is one of ('train', 'val', 'test'). Default: 'train'.
edge (bool, optional): Whether to compute edge while training. Default: False
"""
def __init__(self,
transforms: Callable = None,
mode: str = 'train'):
self.transforms = transforms
mode = mode.lower()
self.mode = mode
self.file_list = list()
self.num_classes = 2
self.ignore_index = 255
if mode not in ['train', 'val', 'test']:
raise ValueError(
"`mode` should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise ValueError("`transforms` is necessary, but it is None.")
if mode == 'train':
file_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', 'train_list.txt')
elif mode == 'test':
file_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', 'test_list.txt')
else:
file_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', 'val_list.txt')
with open(file_path, 'r') as f:
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'val':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', items[0])
grt_path = None
else:
image_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', items[0])
grt_path = os.path.join(hubenv.DATA_HOME, 'optic_disc_seg', items[1])
self.file_list.append([image_path, grt_path])
\ No newline at end of file
......@@ -17,7 +17,7 @@ import time
import os
import base64
import argparse
from typing import List, Union
from typing import List, Union, Tuple
from collections import OrderedDict
import cv2
@@ -629,4 +629,113 @@ class StyleTransferModule(RunModule, ImageServing):
self.arg_input_group.add_argument(
'--input_path', type=str, help="path to image.")
self.arg_input_group.add_argument(
'--style_path', type=str, help="path to style image.")
\ No newline at end of file
'--style_path', type=str, help="path to style image.")
class ImageSegmentationModule(ImageServing, RunModule):
def training_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
'''
One step for training, which should be called as forward computation.
Args:
        batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict): The model outputs, such as loss.
'''
return self.validation_step(batch, batch_idx)
def validation_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
"""
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
"""
label = batch[1].astype('int64')
criterionCE = nn.loss.CrossEntropyLoss()
logits = self(batch[0])
loss = 0
for i in range(len(logits)):
logit = logits[i]
if logit.shape[-2:] != label.shape[-2:]:
logit = F.resize_bilinear(logit, label.shape[-2:])
logit = logit.transpose([0,2,3,1])
loss_ce = criterionCE(logit, label)
loss += loss_ce / len(logits)
return {"loss": loss}
    def predict(self, images: List[Union[str, np.ndarray]], batch_size: int = 1, visualization: bool = True, save_path: str = 'seg_result') -> List[np.ndarray]:
'''
Obtain segmentation results.
Args:
images(list[str|np.array]): Content image path or BGR image.
            batch_size(int): Batch size for prediction.
visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images.
Returns:
output(list[np.ndarray]) : The segmentation mask.
'''
self.eval()
result=[]
total_num = len(images)
loop_num = int(np.ceil(total_num / batch_size))
for iter_id in range(loop_num):
batch_data = []
handle_id = iter_id * batch_size
for image_id in range(batch_size):
try:
image, _ = self.transform(images[handle_id + image_id])
batch_data.append(image)
                except IndexError:
                    # the last batch may be smaller than batch_size
                    pass
batch_image = np.array(batch_data).astype('float32')
pred = self(paddle.to_tensor(batch_image))
pred = paddle.argmax(pred[0], axis=1, keepdim=True, dtype='int32')
for num in range(pred.shape[0]):
if isinstance(images[handle_id+num], str):
image = cv2.imread(images[handle_id+num])
else:
image = images[handle_id+num]
h, w, c = image.shape
pred_final = utils.reverse_transform(pred[num: num+1], (h,w), self.transforms.transforms)
pred_final = paddle.squeeze(pred_final)
pred_final = pred_final.numpy().astype('uint8')
if visualization:
added_image = utils.visualize(images[handle_id+num], pred_final, weight=0.6)
pred_mask = utils.get_pseudo_color_map(pred_final)
pred_image_path = os.path.join(save_path, 'image', str(time.time()) + ".png")
pred_mask_path = os.path.join(save_path, 'mask', str(time.time()) + ".png")
if not os.path.exists(os.path.dirname(pred_image_path)):
os.makedirs(os.path.dirname(pred_image_path))
if not os.path.exists(os.path.dirname(pred_mask_path)):
os.makedirs(os.path.dirname(pred_mask_path))
cv2.imwrite(pred_image_path, added_image)
pred_mask.save(pred_mask_path)
result.append(pred_final)
return result
@serving
def serving_method(self, images: List[str], **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
visual = self.predict(images=images_decode, **kwargs)
final=[]
for mask in visual:
final.append(cv2_to_base64(mask))
return final
\ No newline at end of file
# coding: utf8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from typing import Callable, Union, List, Tuple
import cv2
import numpy as np
from PIL import Image
import paddlehub.vision.functional as F
class Compose:
"""
Do transformation on input data with corresponding pre-processing and augmentation operations.
The shape of input data to all operations is [height, width, channels].
Args:
transforms (list): A list contains data pre-processing or augmentation.
to_rgb (bool, optional): If converting image to RGB color space. Default: True.
Raises:
TypeError: When 'transforms' is not a list.
ValueError: when the length of 'transforms' is less than 1.
"""
def __init__(self, transforms: Callable, to_rgb: bool = True):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
if len(transforms) < 1:
raise ValueError('The length of transforms ' + \
'must be equal or larger than 1!')
self.transforms = transforms
self.to_rgb = to_rgb
def __call__(self, im: Union[np.ndarray, str], label: Union[np.ndarray, str] = None) -> Tuple:
"""
Args:
im (str|np.ndarray): It is either image path or image object.
label (str|np.ndarray): It is either label path or label ndarray.
Returns:
(tuple). A tuple including image, image info, and label after transformation.
"""
        if isinstance(im, str):
            im_path = im
            im = cv2.imread(im_path)
            # check for a failed read before calling astype, and report the path
            if im is None:
                raise ValueError("Cannot read the image file {}!".format(im_path))
            im = im.astype('float32')
        if isinstance(label, str):
            label = np.asarray(Image.open(label))
if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
for op in self.transforms:
outputs = op(im, label)
im = outputs[0]
if len(outputs) == 2:
label = outputs[1]
im = np.transpose(im, (2, 0, 1))
return (im, label)
class ColorMap:
"Calculate color map for mapping segmentation result."
def __init__(self, num_classes: int = 256):
self.num_classes = num_classes + 1
def __call__(self) -> np.ndarray:
color_map = self.num_classes * [0, 0, 0]
for i in range(0, self.num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
color_map = color_map[1:]
return color_map
class SegmentVisual:
"""Visualization the segmentation result.
Args:
weight(float): weight of original image in combining image, default is 0.6.
"""
def __init__(self, weight: float = 0.6):
self.weight = weight
self.get_color_map_list = ColorMap(256)
def __call__(self, image: str, result: np.ndarray, save_dir: str) -> np.ndarray:
color_map = self.get_color_map_list()
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(result, color_map[:, 0])
c2 = cv2.LUT(result, color_map[:, 1])
c3 = cv2.LUT(result, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3))
im = cv2.imread(image)
vis_result = cv2.addWeighted(im, self.weight, pseudo_img, 1 - self.weight, 0)
if save_dir is not None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
image_name = os.path.split(image)[-1]
out_path = os.path.join(save_dir, image_name)
cv2.imwrite(out_path, vis_result)
return vis_result
class Padding:
"""
Add bottom-right padding to a raw image or annotation image.
Args:
target_size (list|tuple): The target size after padding.
        im_padding_value (list|tuple, optional): The padding value of the raw image.
            Default: (128, 128, 128).
label_padding_value (int, optional): The padding value of annotation image. Default: 255.
Raises:
TypeError: When target_size is neither list nor tuple.
ValueError: When the length of target_size is not 2.
"""
def __init__(self,
target_size: Union[List[int], Tuple[int], int],
im_padding_value: Union[List[int], Tuple[int], int] = (128, 128, 128),
label_padding_value: int = 255):
if isinstance(target_size, list) or isinstance(target_size, tuple):
if len(target_size) != 2:
raise ValueError(
'`target_size` should include 2 elements, but it is {}'.
format(target_size))
else:
raise TypeError(
"Type of target_size is invalid. It should be list or tuple, now is {}"
.format(type(target_size)))
self.target_size = target_size
self.im_padding_value = im_padding_value
self.label_padding_value = label_padding_value
def __call__(self, im: np.ndarray , label: np.ndarray = None) -> Tuple:
"""
Args:
im (np.ndarray): The Image data.
label (np.ndarray, optional): The label data. Default: None.
Returns:
(tuple). When label is None, it returns (im, ), otherwise it returns (im, label).
"""
im_height, im_width = im.shape[0], im.shape[1]
if isinstance(self.target_size, int):
target_height = self.target_size
target_width = self.target_size
else:
target_height = self.target_size[1]
target_width = self.target_size[0]
pad_height = target_height - im_height
pad_width = target_width - im_width
        if pad_height < 0 or pad_width < 0:
            raise ValueError(
                'The size of the image should not be larger than `target_size`, but the image size ({}, {}) exceeds `target_size` ({}, {})'
                .format(im_width, im_height, target_width, target_height))
else:
im = cv2.copyMakeBorder(im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT,
value=self.im_padding_value)
if label is not None:
label = cv2.copyMakeBorder(label, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT,
value=self.label_padding_value)
if label is None:
return (im,)
else:
return (im, label)
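# Worked example: with Padding(target_size=(512, 512)), a 400x400 image receives
# 112 gray pixels (value 128) on the bottom and on the right, so the original
# content stays anchored at the top-left corner; a label is padded with 255
# instead, conventionally the ignored index in segmentation losses.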
class Normalize:
"""
Normalize an image.
Args:
mean (list|tuple): The mean value of a data set. Default: [0.5, 0.5, 0.5].
std (list|tuple): The standard deviation of a data set. Default: [0.5, 0.5, 0.5].
    Raises:
        ValueError: When mean/std is neither list nor tuple, or when any value in std is 0.
"""
def __init__(self, mean: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5),
std: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5)):
self.mean = mean
self.std = std
if not (isinstance(self.mean, (list, tuple))
and isinstance(self.std, (list, tuple))):
raise ValueError(
"{}: input type is invalid. It should be list or tuple".format(
self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def __call__(self, im: np.ndarray, label: np.ndarray = None) -> Tuple:
"""
Args:
im (np.ndarray): The Image data.
label (np.ndarray, optional): The label data. Default: None.
Returns:
(tuple). When label is None, it returns (im, ), otherwise it returns (im, label).
"""
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im = F.normalize(im, mean, std)
if label is None:
return (im,)
else:
return (im, label)
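# Worked example (assuming F.normalize scales pixels to [0, 1] before applying
# (x - mean) / std channel-wise, as in PaddleSeg): a raw pixel value of 255 with
# the default mean/std of 0.5 maps to (1.0 - 0.5) / 0.5 = 1.0, so the default
# settings squeeze images into the range [-1, 1].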
class Resize:
"""
Resize an image.
Args:
target_size (list|tuple, optional): The target size of image. Default: (512, 512).
interp (str, optional): The interpolation mode of resize is consistent with opencv.
['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']. Note that when it is
'RANDOM', a random interpolation mode would be specified. Default: "LINEAR".
Raises:
TypeError: When 'target_size' type is neither list nor tuple.
ValueError: When "interp" is out of pre-defined methods ('NEAREST', 'LINEAR', 'CUBIC',
'AREA', 'LANCZOS4', 'RANDOM').
"""
# The interpolation mode
interp_dict = {
'NEAREST': cv2.INTER_NEAREST,
'LINEAR': cv2.INTER_LINEAR,
'CUBIC': cv2.INTER_CUBIC,
'AREA': cv2.INTER_AREA,
'LANCZOS4': cv2.INTER_LANCZOS4
}
def __init__(self, target_size: Union[List[int], Tuple[int]] = (512, 512), interp: str = 'LINEAR'):
self.interp = interp
if not (interp == "RANDOM" or interp in self.interp_dict):
raise ValueError("`interp` should be one of {}".format(
self.interp_dict.keys()))
        if isinstance(target_size, (list, tuple)):
if len(target_size) != 2:
raise ValueError(
'`target_size` should include 2 elements, but it is {}'.
format(target_size))
else:
raise TypeError(
"Type of `target_size` is invalid. It should be list or tuple, but it is {}"
.format(type(target_size)))
self.target_size = target_size
def __call__(self, im: np.ndarray, label: np.ndarray = None) -> Tuple:
"""
Args:
im (np.ndarray): The Image data.
label (np.ndarray, optional): The label data. Default: None.
Returns:
            (tuple). When label is None, it returns (im, ), otherwise it returns (im, label).
        Raises:
            TypeError: When `im` is not a numpy array.
            ValueError: When `im` is not 3-dimensional.
"""
if not isinstance(im, np.ndarray):
raise TypeError("Resize: image type is not numpy.")
if len(im.shape) != 3:
raise ValueError('Resize: image is not 3-dimensional.')
if self.interp == "RANDOM":
interp = random.choice(list(self.interp_dict.keys()))
else:
interp = self.interp
im = F.resize(im, self.target_size, self.interp_dict[interp])
if label is not None:
label = F.resize(label, self.target_size,
cv2.INTER_NEAREST)
if label is None:
return (im,)
else:
return (im, label)
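# Usage sketch: Resize(target_size=(512, 512), interp='RANDOM') picks a different
# OpenCV interpolation on every call, which acts as a light augmentation during
# training; labels are always resized with INTER_NEAREST so class ids are never
# blended into non-existent classes.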
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +13,14 @@
# limitations under the License.
import os
from typing import Callable, Union, List, Tuple
import cv2
import paddle
import PIL
import numpy as np
import matplotlib.pyplot as plt
import paddle.nn.functional as F
def is_image_file(filename: str) -> bool:
@@ -26,7 +29,7 @@ def is_image_file(filename: str) -> bool:
return ext in ['.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff']
def get_img_file(dir_name: str) -> List[str]:
'''Get all image file paths in several directories which have the same parent directory.'''
images = []
for parent, _, filenames in os.walk(dir_name):
@@ -39,7 +42,7 @@ def get_img_file(dir_name: str) -> List[str]:
return images
def box_crop(boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: List[int], img_shape: List[int]) -> Tuple:
"""Crop the boxes ,labels, scores according to the given shape"""
x, y, w, h = map(float, crop)
@@ -99,7 +102,7 @@ def draw_boxes_on_image(image_path: str,
boxes: np.ndarray,
scores: np.ndarray,
labels: np.ndarray,
                        label_names: List[str],
score_thresh: float = 0.5,
save_path: str = 'result'):
"""Draw boxes on images."""
@@ -145,7 +148,7 @@ def draw_boxes_on_image(image_path: str,
plt.close('all')
def get_label_infos(file_list: str) -> str:
"""Get label names by corresponding category ids."""
from pycocotools.coco import COCO
map_label = COCO(file_list)
@@ -175,10 +178,115 @@ def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
return gram
def npmax(array: np.ndarray) -> Tuple[int, int]:
"""Get max value and index."""
arrayindex = array.argmax(1)
arrayvalue = array.max(1)
i = arrayvalue.argmax()
j = arrayindex[i]
return i, j
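# Worked example: for np.array([[1, 5], [9, 2]]), npmax returns (1, 0), i.e. the
# row and column of the global maximum 9.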
def visualize(image: Union[np.ndarray, str], result: np.ndarray, weight: float = 0.6) -> np.ndarray:
"""
    Convert the segmentation result to a pseudo-color image and blend it with the original image.
    Args:
        image (str|np.ndarray): The path of the original image, or a BGR image array.
        result (np.ndarray): The predicted segmentation result of the image.
        weight (float): The weight of the original image in the blend; the result takes (1 - weight). Default: 0.6.
    Returns:
        vis_result (np.ndarray): The visualized result.
"""
color_map = get_color_map_list(256)
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(result, color_map[:, 0])
c2 = cv2.LUT(result, color_map[:, 1])
c3 = cv2.LUT(result, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3))
if isinstance(image, str):
im = cv2.imread(image)
else:
im = image
vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
return vis_result
def get_pseudo_color_map(pred: np.ndarray) -> PIL.Image.Image:
    '''Visualize the segmentation mask as a pseudo-color PIL image.'''
pred_mask = PIL.Image.fromarray(pred.astype(np.uint8), mode='P')
color_map = get_color_map_list(256)
pred_mask.putpalette(color_map)
return pred_mask
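# Illustrative note: mode 'P' keeps the mask as a single-channel palettized image,
# so pred_mask.save('mask.png') stores the raw class indices while image viewers
# render the palette colors, which makes it a lossless way to store labels.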
def get_color_map_list(num_classes: int) -> List[int]:
"""
Returns the color map for visualizing the segmentation mask,
which can support arbitrary number of classes.
Args:
num_classes (int): Number of classes.
Returns:
(list). The color map.
"""
num_classes += 1
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = color_map[3:]
return color_map
def get_reverse_list(ori_shape: List[int], transforms: List[Callable]) -> List[tuple]:
"""
    Get the reverse list of transforms.
    Args:
        ori_shape (list): The origin shape of the image.
        transforms (list): The list of transforms.
    Returns:
        list: A list of tuples in two formats:
            ('resize', (h, w)): the image shape before resizing;
            ('padding', (h, w)): the image shape before padding.
"""
reverse_list = []
h, w = ori_shape[0], ori_shape[1]
for op in transforms:
if op.__class__.__name__ in ['Resize', 'ResizeByLong']:
reverse_list.append(('resize', (h, w)))
h, w = op.target_size[0], op.target_size[1]
if op.__class__.__name__ in ['Padding']:
reverse_list.append(('padding', (h, w)))
w, h = op.target_size[0], op.target_size[1]
return reverse_list
def reverse_transform(pred: paddle.Tensor, ori_shape: List[int], transforms: List[Callable]) -> paddle.Tensor:
    """Recover the prediction to its origin shape."""
reverse_list = get_reverse_list(ori_shape, transforms)
for item in reverse_list[::-1]:
if item[0] == 'resize':
h, w = item[1][0], item[1][1]
pred = F.interpolate(pred, (h, w), mode='nearest')
elif item[0] == 'padding':
h, w = item[1][0], item[1][1]
pred = pred[:, :, 0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(item[0]))
return pred
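# Worked example: for an input of height 600 and width 400 resized to 512x512
# before inference, get_reverse_list records ('resize', (600, 400));
# reverse_transform then walks the list backwards and nearest-interpolates the
# prediction back to 600x400, so evaluation happens at the original resolution.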