提交 32d7baa2 编写于 作者: F FlyingQianMM

add remote_sensing

上级 00b64384
...@@ -112,6 +112,46 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2): ...@@ -112,6 +112,46 @@ batch_predict(self, img_file_list, transforms=None, thread_num=2):
### tile_predict
> DeepLabv3p模型无重叠的大图切小图预测接口。将大图像切分成互不重叠多个小块,分别对每个小块进行预测,最后将小块预测结果拼接成大图预测结果。由于每个小块边缘部分的预测效果会比中间部分的差,因此每个小块拼接处可能会有明显的裂痕感。
> 需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`DeepLabv3p.test_transforms`和`DeepLabv3p.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`tile_predict`接口时,用户需要再重新定义test_transforms传入给`tile_predict`接口。
> **参数**
> >
> > - **img_file** (str|np.ndarray): 预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **tile_size** (list|tuple): 切分小块的大小,格式为(W,H)。默认值为[512, 512]。
> > - **batch_size** (int):对小块进行批量预测时的批量大小。默认值为32。
> > - **thread_num** (int): 并发执行各小块预处理时的线程数。默认值为8。
> > - **transforms** (paddlex.seg.transforms): 数据预处理操作。
> **返回值**
> >
> > - **dict**: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)。
### overlap_tile_predict
> DeepLabv3p模型有重叠的大图切小图预测接口。Unet论文作者提出一种有重叠的大图切小图策略(Overlap-tile strategy)来消除拼接处的裂痕感。每次划分小块时向四周扩展面积,例如下图中的蓝色部分区域,到拼接大图时取小块中间部分的预测结果,例如下图中的黄色部分区域,对于处于原始图像边缘处的小块,其扩展面积下的像素则通过将边缘部分像素镜像填补得到。
![](../../../examples/remote_sensing/images/overlap_tile.png)
> 需要注意的是,只有在训练过程中定义了eval_dataset,模型在保存时才会将预测时的图像处理流程保存在`DeepLabv3p.test_transforms`和`DeepLabv3p.eval_transforms`中。如未在训练时定义eval_dataset,那在调用预测`overlap_tile_predict`接口时,用户需要再重新定义test_transforms传入给`overlap_tile_predict`接口。
> **参数**
> >
> > - **img_file** (str|np.ndarray): 预测图像路径或numpy数组(HWC排列,BGR格式)。
> > - **tile_size** (list|tuple): 切分小块中间部分用于拼接预测结果的大小,格式为(W,H)。默认值为[512, 512]。
> > - **pad_size** (list|tuple): 切分小块向四周扩展的大小,格式为(W,H)。
> > - **batch_size** (int):对小块进行批量预测时的批量大小。默认值为32。
> > - **thread_num** (int): 并发执行各小块预处理时的线程数。默认值为8。
> > - **transforms** (paddlex.seg.transforms): 数据预处理操作。
> **返回值**
> >
> > - **dict**: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)。
## paddlex.seg.UNet ## paddlex.seg.UNet
```python ```python
...@@ -133,6 +173,8 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us ...@@ -133,6 +173,8 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
> - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate) > - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate)
> - predict 预测接口说明同 [DeepLabv3p模型predict接口](#predict) > - predict 预测接口说明同 [DeepLabv3p模型predict接口](#predict)
> - batch_predict 批量预测接口说明同 [DeepLabv3p模型predict接口](#batch-predict) > - batch_predict 批量预测接口说明同 [DeepLabv3p模型predict接口](#batch-predict)
> - tile_predict 无重叠的大图切小图预测接口同 [DeepLabv3p模型tile_predict接口](#tile-predict)
> - overlap_tile_predict 有重叠的大图切小图预测接口同 [DeepLabv3p模型poverlap_tile_predict接口](#overlap-tile-predict)
## paddlex.seg.HRNet ## paddlex.seg.HRNet
...@@ -155,6 +197,8 @@ paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=Fal ...@@ -155,6 +197,8 @@ paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=Fal
> - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate) > - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate)
> - predict 预测接口说明同 [DeepLabv3p模型predict接口](#predict) > - predict 预测接口说明同 [DeepLabv3p模型predict接口](#predict)
> - batch_predict 批量预测接口说明同 [DeepLabv3p模型predict接口](#batch-predict) > - batch_predict 批量预测接口说明同 [DeepLabv3p模型predict接口](#batch-predict)
> - tile_predict 无重叠的大图切小图预测接口同 [DeepLabv3p模型tile_predict接口](#tile-predict)
> - overlap_tile_predict 有重叠的大图切小图预测接口同 [DeepLabv3p模型poverlap_tile_predict接口](#overlap-tile-predict)
## paddlex.seg.FastSCNN ## paddlex.seg.FastSCNN
...@@ -177,3 +221,5 @@ paddlex.seg.FastSCNN(num_classes=2, use_bce_loss=False, use_dice_loss=False, cla ...@@ -177,3 +221,5 @@ paddlex.seg.FastSCNN(num_classes=2, use_bce_loss=False, use_dice_loss=False, cla
> - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate) > - evaluate 评估接口说明同 [DeepLabv3p模型evaluate接口](#evaluate)
> - predict 预测接口说明同 [DeepLabv3p模型predict接口](#predict) > - predict 预测接口说明同 [DeepLabv3p模型predict接口](#predict)
> - batch_predict 批量预测接口说明同 [DeepLabv3p模型predict接口](#batch-predict) > - batch_predict 批量预测接口说明同 [DeepLabv3p模型predict接口](#batch-predict)
> - tile_predict 无重叠的大图切小图预测接口同 [DeepLabv3p模型tile_predict接口](#tile-predict)
> - overlap_tile_predict 有重叠的大图切小图预测接口同 [DeepLabv3p模型poverlap_tile_predict接口](#overlap-tile-predict)
...@@ -12,3 +12,4 @@ PaddleX精选飞桨视觉开发套件在产业实践中的成熟模型结构, ...@@ -12,3 +12,4 @@ PaddleX精选飞桨视觉开发套件在产业实践中的成熟模型结构,
solutions.md solutions.md
meter_reader.md meter_reader.md
human_segmentation.md human_segmentation.md
remote_sensing.md
# 遥感影像分割
本案例基于PaddleX实现遥感影像分割,提供无重叠的大图切小图以及有重叠的大图切小图两种预测方式。
## 前置依赖
* Paddle paddle >= 1.8.4
* Python >= 3.5
* PaddleX >= 1.1.0
安装的相关问题参考[PaddleX安装](../install.md)
下载PaddleX源码:
```
git clone https://github.com/PaddlePaddle/PaddleX
```
该案例所有脚本均位于`PaddleX/examples/remote_sensing/`,进入该目录:
```
cd PaddleX/examples/remote_sensing/
```
## 数据准备
本案例使用2015 CCF大数据比赛提供的高清遥感影像,包含5张带标注的RGB图像,图像尺寸最大有7969 × 7939、最小有4011 × 2470。该数据集共标注了5类物体,分别是背景(标记为0)、植被(标记为1)、建筑(标记为2)、水体(标记为3)、道路 (标记为4)。
本案例将前4张图片划分入训练集,第5张图片作为验证集。为增加训练时的批量大小,以滑动窗口为(1024,1024)、步长为(512, 512)对前4张图片进行切分,加上原本的4张大尺寸图片,训练集一共有688张图片。直接对大图片进行验证会导致显存不足,为避免此类问题的出现,针对验证集,以滑动窗口为(769, 769)、步长为(769,769)对第5张图片进行切分,得到40张子图片。
运行以下脚本,下载原始数据集,并完成数据集的切分:
```
python3 prepare_data.py
```
## 模型训练
分割模型选择Backbone为MobileNetv3_large_ssld的Deeplabv3模型,该模型兼备高性能高精度的优点。运行以下脚本,进行模型训练:
```
python3 train.py
```
## 模型预测
直接对大图片进行预测会导致显存不足,为避免此类问题的出现,本案例提供了两种预测方式:无重叠的大图切小图和有重叠的大图切小图。
* 无重叠的大图切小图
将大图像切分成互不重叠多个小块,分别对每个小块进行预测,最后将小块预测结果拼接成大图预测结果。由于每个小块边缘部分的预测效果会比中间部分的差,因此每个小块拼接处可能会有明显的裂痕感。
该预测方式的API接口详见[tile_predict](../apis/models/semantic_segmentation.html#tile-predict)
* 有重叠的大图切小图
Unet论文作者提出一种有重叠的大图切小图策略(Overlap-tile strategy)来消除拼接处的裂痕感。每次划分小块时向四周扩展面积,例如下图中的蓝色部分区域,到拼接大图时取小块中间部分的预测结果,例如下图中的黄色部分区域,对于处于原始图像边缘处的小块,其扩展面积下的像素则通过将边缘部分像素镜像填补得到。
该预测方式的API接口说明详见[overlap_tile_predict](../apis/models/semantic_segmentation.html#overlap-tile-predict)
![](../../examples/remote_sensing/images/overlap_tile.png)
相比无重叠的大图切小图,有重叠的大图切小图策略将本案例的模型精度miou从80.58%提升至81.52%,并且将预测可视化结果中裂痕感显著消除,可见下图中两种预测方式的效果对比。
![](../../examples/remote_sensing/images/visualize_compare.png)
运行以下脚本使用有重叠的大图切小图预测方式进行预测。如需使用无重叠的大图切小图的预测方式,参考以下脚本中的注释修改模型预测接口:
```
python3 predict.py
```
## 模型评估
在训练过程中,每隔10个迭代轮数会评估一次模型在验证集的精度。由于已事先将原始大尺寸图片切分成小块,此时相当于使用无重叠的大图切小图预测方式,最优模型精度miou为80.58%。运行以下脚本,将采用有重叠的大图切小图的预测方式,重新评估原始大尺寸图片的模型精度,此时miou为81.52%。
```
python3 eval.py
```
遥感分割案例
=======================================
这里面写遥感分割案例,可根据需求拆分为多个文档
# 遥感影像分割
本案例基于PaddleX实现遥感影像分割,提供无重叠的大图切小图以及有重叠的大图切小图两种预测方式。
## 目录
* [数据准备](#1)
* [模型训练](#2)
* [模型预测](#3)
* [模型评估](#4)
#### 前置依赖
* Paddle paddle >= 1.8.4
* Python >= 3.5
* PaddleX >= 1.1.0
安装的相关问题参考[PaddleX安装](../install.md)
下载PaddleX源码:
```
git clone https://github.com/PaddlePaddle/PaddleX
```
该案例所有脚本均位于`PaddleX/examples/remote_sensing/`,进入该目录:
```
cd PaddleX/examples/remote_sensing/
```
## <h2 id="1">数据准备</h2>
本案例使用2015 CCF大数据比赛提供的高清遥感影像,包含5张带标注的RGB图像,图像尺寸最大有7969 × 7939、最小有4011 × 2470。该数据集共标注了5类物体,分别是背景(标记为0)、植被(标记为1)、建筑(标记为2)、水体(标记为3)、道路 (标记为4)。
本案例将前4张图片划分入训练集,第5张图片作为验证集。为增加训练时的批量大小,以滑动窗口为(1024,1024)、步长为(512, 512)对前4张图片进行切分,加上原本的4张大尺寸图片,训练集一共有688张图片。直接对大图片进行验证会导致显存不足,为避免此类问题的出现,针对验证集,以滑动窗口为(769, 769)、步长为(769,769)对第5张图片进行切分,得到40张子图片。
运行以下脚本,下载原始数据集,并完成数据集的切分:
```
python3 prepare_data.py
```
## <h2 id="2">模型训练</h2>
分割模型选择Backbone为MobileNetv3_large_ssld的Deeplabv3模型,该模型兼备高性能高精度的优点。运行以下脚本,进行模型训练:
```
python3 train.py
```
## <h2 id="2">模型预测</h2>
直接对大图片进行预测会导致显存不足,为避免此类问题的出现,本案例提供了两种预测方式:无重叠的大图切小图和有重叠的大图切小图。
* 无重叠的大图切小图
将大图像切分成互不重叠多个小块,分别对每个小块进行预测,最后将小块预测结果拼接成大图预测结果。由于每个小块边缘部分的预测效果会比中间部分的差,因此每个小块拼接处可能会有明显的裂痕感。
该预测方式的API接口详见[tile_predict](https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#tile-predict)
* 有重叠的大图切小图
Unet论文作者提出一种有重叠的大图切小图策略(Overlap-tile strategy)来消除拼接处的裂痕感。每次划分小块时向四周扩展面积,例如下图中的蓝色部分区域,到拼接大图时取小块中间部分的预测结果,例如下图中的黄色部分区域,对于处于原始图像边缘处的小块,其扩展面积下的像素则通过将边缘部分像素镜像填补得到。
该预测方式的API接口说明详见[overlap_tile_predict](https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict)
![](images/overlap_tile.png)
相比无重叠的大图切小图,有重叠的大图切小图策略将本案例的模型精度miou从80.58%提升至81.52%,并且将预测可视化结果中裂痕感显著消除,可见下图中两种预测方式的效果对比。
![](images/visualize_compare.png)
运行以下脚本使用有重叠的大图切小图预测方式进行预测。如需使用无重叠的大图切小图的预测方式,参考以下脚本中的注释修改模型预测接口:
```
python3 predict.py
```
## <h2 id="2">模型评估</h2>
在训练过程中,每隔10个迭代轮数会评估一次模型在验证集的精度。由于已事先将原始大尺寸图片切分成小块,此时相当于使用无重叠的大图切小图预测方式,最优模型精度miou为80.58%。运行以下脚本,将采用有重叠的大图切小图的预测方式,重新评估原始大尺寸图片的模型精度,此时miou为81.52%。
```
python3 eval.py
```
# 环境变量配置,用于控制是否使用GPU
# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import cv2
from PIL import Image
from collections import OrderedDict
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.cv.models.utils.seg_eval import ConfusionMatrix
# 导入模型参数
model = pdx.load_model('output/deeplabv3p_mobilenetv3_large_ssld/best_model')
# 指定待评估图像路径及其标注文件路径
img_file = "dataset/JPEGImages/5.png"
label_file = "dataset/Annotations/5_class.png"
# 定义用于计算miou、iou、macc、acc、kapp指标的混淆矩阵类
conf_mat = ConfusionMatrix(model.num_classes, streaming=True)
# 使用"无重叠的大图切小图"方式进行预测:将大图像切分成互不重叠多个小块,分别对每个小块进行预测
# 最后将小块预测结果拼接成大图预测结果
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#tile-predict
# tile_predict = model.tile_predict(img_file=img_file, tile_size=(769, 769))
# pred = tile_predict["label_map"]
# 使用"有重叠的大图切小图"策略进行预测:将大图像切分成相互重叠的多个小块,
# 分别对每个小块进行预测,将小块预测结果的中间部分拼接成大图预测结果
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict
overlap_tile_predict = model.overlap_tile_predict(
img_file=img_file, tile_size=(769, 769))
pred = overlap_tile_predict["label_map"]
# 更新混淆矩阵
pred = pred[np.newaxis, :, :, np.newaxis]
pred = pred.astype(np.int64)
label = np.asarray(Image.open("dataset/Annotations/5_class.png"))
label = label[np.newaxis, np.newaxis, :, :]
mask = label != model.ignore_index
conf_mat.calculate(pred=pred, label=label, ignore=mask)
# 计算miou、iou、macc、acc、kapp
category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy()
logging.info(
"miou={:.6f} category_iou={} macc={:.6f} category_acc={} kappa={:.6f}".
format(miou, category_iou, macc, category_acc, conf_mat.kappa()))
# 环境变量配置,用于控制是否使用GPU
# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
# 导入模型参数
model = pdx.load_model('output/deeplabv3p_mobilenetv3_large_ssld/best_model')
# 指定待预测图像路径
img_file = "dataset/JPEGImages/5.png"
# 使用"无重叠的大图切小图"方式进行预测:将大图像切分成互不重叠多个小块,分别对每个小块进行预测
# 最后将小块预测结果拼接成大图预测结果
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#tile-predict
# pred = model.tile_predict(img_file=img_file, tile_size=(769, 769))
# 使用"有重叠的大图切小图"策略进行预测:将大图像切分成相互重叠的多个小块,
# 分别对每个小块进行预测,将小块预测结果的中间部分拼接成大图预测结果
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict
pred = model.overlap_tile_predict(img_file=img_file, tile_size=(769, 769))
# 可视化预测结果
# API说明:
pdx.seg.visualize(
img_file,
pred,
weight=0.,
save_dir='output/deeplabv3p_mobilenetv3_large_ssld/')
import os
import os.path as osp
import numpy as np
import cv2
import shutil
from PIL import Image
import paddlex as pdx
# 定义训练集切分时的滑动窗口大小和步长,格式为(W, H)
train_tile_size = (1024, 1024)
train_stride = (512, 512)
# 定义验证集切分时的滑动窗口大小和步长,格式(W, H)
val_tile_size = (769, 769)
val_stride = (769, 769)
# 下载并解压2015 CCF大数据比赛提供的高清遥感影像
ccf_remote_dataset = 'https://bj.bcebos.com/paddlex/examples/remote_sensing/datasets/ccf_remote_dataset.tar.gz'
pdx.utils.download_and_decompress(ccf_remote_dataset, path='./')
if not osp.exists('./dataset/JPEGImages'):
os.makedirs('./dataset/JPEGImages')
if not osp.exists('./dataset/Annotations'):
os.makedirs('./dataset/Annotations')
# 将前4张图片划分入训练集,并切分成小块之后加入到训练集中
# 并生成train_list.txt
for train_id in range(1, 5):
shutil.copyfile("ccf_remote_dataset/{}.png".format(train_id),
"./dataset/JPEGImages/{}.png".format(train_id))
shutil.copyfile("ccf_remote_dataset/{}_class.png".format(train_id),
"./dataset/Annotations/{}_class.png".format(train_id))
mode = 'w' if train_id == 1 else 'a'
with open('./dataset/train_list.txt', mode) as f:
f.write("JPEGImages/{}.png Annotations/{}_class.png\n".format(
train_id, train_id))
for train_id in range(1, 5):
image = cv2.imread('ccf_remote_dataset/{}.png'.format(train_id))
label = Image.open('ccf_remote_dataset/{}_class.png'.format(train_id))
H, W, C = image.shape
train_tile_id = 1
for h in range(0, H, train_stride[1]):
for w in range(0, W, train_stride[0]):
left = w
upper = h
right = min(w + train_tile_size[0] * 2, W)
lower = min(h + train_tile_size[1] * 2, H)
tile_image = image[upper:lower, left:right, :]
cv2.imwrite("./dataset/JPEGImages/{}_{}.png".format(
train_id, train_tile_id), tile_image)
cut_label = label.crop((left, upper, right, lower))
cut_label.save("./dataset/Annotations/{}_class_{}.png".format(
train_id, train_tile_id))
with open('./dataset/train_list.txt', 'a') as f:
f.write("JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
format(train_id, train_tile_id, train_id,
train_tile_id))
train_tile_id += 1
# 将第5张图片切分成小块之后加入到验证集中
val_id = 5
val_tile_id = 1
shutil.copyfile("ccf_remote_dataset/{}.png".format(val_id),
"./dataset/JPEGImages/{}.png".format(val_id))
shutil.copyfile("ccf_remote_dataset/{}_class.png".format(val_id),
"./dataset/Annotations/{}_class.png".format(val_id))
image = cv2.imread('ccf_remote_dataset/{}.png'.format(val_id))
label = Image.open('ccf_remote_dataset/{}_class.png'.format(val_id))
H, W, C = image.shape
for h in range(0, H, val_stride[1]):
for w in range(0, W, val_stride[0]):
left = w
upper = h
right = min(w + val_tile_size[0], W)
lower = min(h + val_tile_size[1], H)
cut_image = image[upper:lower, left:right, :]
cv2.imwrite("./dataset/JPEGImages/{}_{}.png".format(
val_id, val_tile_id), cut_image)
cut_label = label.crop((left, upper, right, lower))
cut_label.save("./dataset/Annotations/{}_class_{}.png".format(
val_id, val_tile_id))
mode = 'w' if val_tile_id == 1 else 'a'
with open('./dataset/val_list.txt', mode) as f:
f.write("JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
format(val_id, val_tile_id, val_id, val_tile_id))
val_tile_id += 1
# 生成labels.txt
label_list = ['background', 'vegetation', 'building', 'water', 'road']
for i, label in enumerate(label_list):
mode = 'w' if i == 0 else 'a'
with open('./dataset/labels.txt', 'a') as f:
name = "{}\n".format(label) if i < len(
label_list) - 1 else "{}".format(label)
f.write(name)
# 环境变量配置,用于控制是否使用GPU
# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
from paddlex.seg import transforms
# 定义训练和验证时的transforms
# API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
train_transforms = transforms.Compose([
transforms.RandomPaddingCrop(crop_size=769),
transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(),
transforms.Normalize()
])
eval_transforms = transforms.Compose(
[transforms.Padding(target_size=769), transforms.Normalize()])
# 定义训练和验证所用的数据集
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='dataset',
file_list='dataset/train_list.txt',
label_list='dataset/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
data_dir='dataset',
file_list='dataset/val_list.txt',
label_list='dataset/labels.txt',
transforms=eval_transforms)
## 初始化模型,并进行训练
## 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
model = pdx.seg.DeepLabv3p(
num_classes=num_classes,
backbone='MobileNetV3_large_x1_0_ssld',
pooling_crop_size=(769, 769))
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
num_epochs=400,
train_dataset=train_dataset,
train_batch_size=16,
eval_dataset=eval_dataset,
learning_rate=0.01,
save_interval_epochs=10,
pretrain_weights='CITYSCAPES',
save_dir='output/deeplabv3p_mobilenetv3_large_ssld',
use_vdl=True)
...@@ -553,8 +553,31 @@ class DeepLabv3p(BaseAPI): ...@@ -553,8 +553,31 @@ class DeepLabv3p(BaseAPI):
img_file, img_file,
tile_size=[512, 512], tile_size=[512, 512],
batch_size=32, batch_size=32,
thread_num=8): thread_num=8,
transforms=None):
"""无重叠的大图切小图预测。
Args:
img_file(str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
tile_size(list|tuple): 切分小块的大小,格式为(W,H)。默认值为[512, 512]。
batch_size(int):对小块进行批量预测时的批量大小。默认值为32。
thread_num (int): 并发执行各小块预处理时的线程数。默认值为8。
transforms(paddlex.cv.transforms): 数据预处理操作。
Returns:
dict: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,
像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
if isinstance(img_file, str):
image = cv2.imread(img_file) image = cv2.imread(img_file)
elif isinstance(img_file, np.ndarray):
image = img_file.copy()
else:
raise Exception("im_file must be list/tuple")
height, width, channel = image.shape height, width, channel = image.shape
image_tile_list = list() image_tile_list = list()
# crop the image into tile pieces # crop the image into tile pieces
...@@ -577,7 +600,8 @@ class DeepLabv3p(BaseAPI): ...@@ -577,7 +600,8 @@ class DeepLabv3p(BaseAPI):
end = min(i + batch_size, num_tiles) end = min(i + batch_size, num_tiles)
res = self.batch_predict( res = self.batch_predict(
img_file_list=image_tile_list[begin:end], img_file_list=image_tile_list[begin:end],
thread_num=thread_num) thread_num=thread_num,
transforms=transforms)
for j in range(begin, end): for j in range(begin, end):
h_id = j // (width // tile_size[0] + 1) h_id = j // (width // tile_size[0] + 1)
w_id = j % (width // tile_size[0] + 1) w_id = j % (width // tile_size[0] + 1)
...@@ -598,7 +622,31 @@ class DeepLabv3p(BaseAPI): ...@@ -598,7 +622,31 @@ class DeepLabv3p(BaseAPI):
pad_size=[64, 64], pad_size=[64, 64],
batch_size=32, batch_size=32,
thread_num=8): thread_num=8):
"""有重叠的大图切小图预测。
Args:
img_file(str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
tile_size(list|tuple): 切分小块中间部分用于拼接预测结果的大小,格式为(W,H)。默认值为[512, 512]。
pad_size(list|tuple): 切分小块向四周扩展的大小,格式为(W,H)。默认值为[64,64]。
batch_size(int):对小块进行批量预测时的批量大小。默认值为32
thread_num (int): 并发执行各小块预处理时的线程数。默认值为8。
transforms(paddlex.cv.transforms): 数据预处理操作。
Returns:
dict: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,
像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)
"""
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
if isinstance(img_file, str):
image = cv2.imread(img_file) image = cv2.imread(img_file)
elif isinstance(img_file, np.ndarray):
image = img_file.copy()
else:
raise Exception("im_file must be list/tuple")
height, width, channel = image.shape height, width, channel = image.shape
image_tile_list = list() image_tile_list = list()
...@@ -638,7 +686,8 @@ class DeepLabv3p(BaseAPI): ...@@ -638,7 +686,8 @@ class DeepLabv3p(BaseAPI):
end = min(i + batch_size, num_tiles) end = min(i + batch_size, num_tiles)
res = self.batch_predict( res = self.batch_predict(
img_file_list=image_tile_list[begin:end], img_file_list=image_tile_list[begin:end],
thread_num=thread_num) thread_num=thread_num,
transforms=transforms)
for j in range(begin, end): for j in range(begin, end):
h_id = j // (width // tile_size[0] + 1) h_id = j // (width // tile_size[0] + 1)
w_id = j % (width // tile_size[0] + 1) w_id = j % (width // tile_size[0] + 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册