diff --git a/docs/apis/models/semantic_segmentation.md b/docs/apis/models/semantic_segmentation.md
index ab62bd57fa25f4ed2bc35551996688de0cedcacc..b46a6273c660a017700c00598515891904fb9dde 100755
--- a/docs/apis/models/semantic_segmentation.md
+++ b/docs/apis/models/semantic_segmentation.md
@@ -110,6 +110,34 @@ batch_predict(self, img_file_list, transforms=None):
> > - **list**: Each element is a dict representing the prediction for one image, containing the keys 'label_map' and 'score_map'. 'label_map' stores the predicted grayscale label map whose pixel values denote the class; 'score_map' stores the per-class probabilities with shape=(h, w, num_classes).
+### overlap_tile_predict
+
+```
+overlap_tile_predict(self, img_file, tile_size=[512, 512], pad_size=[64, 64], batch_size=32, transforms=None)
+```
+
+> Sliding-window prediction interface of the DeepLabv3p model, supporting both overlapping and non-overlapping modes.
+
+> **Sliding-window prediction without overlap**: Slide a fixed-size window over the input image, predict the image under each window separately, and stitch the per-window predictions into the prediction for the whole input. **To use this mode, set the parameter `pad_size` to `[0, 0]`**.
+
+> **Sliding-window prediction with overlap**: The U-Net paper proposes an overlapping sliding-window strategy (overlap-tile strategy) to remove the visible seams where tiles are stitched together. When predicting each window, the window is expanded outward by a margin on all four sides and prediction runs on the expanded window, e.g. the blue region in the figure below; at stitching time, only the central part of each window's prediction is kept, e.g. the yellow region in the figure below. For windows at the image border, the pixels in the expanded margin are obtained by mirror-padding the border pixels.
+
+![](../../../examples/remote_sensing/images/overlap_tile.png)
+
+> Note that the image preprocessing pipeline used at prediction time is saved into `DeepLabv3p.test_transforms` and `DeepLabv3p.eval_transforms` only if an eval_dataset was defined during training. If no eval_dataset was defined at training time, you must define test_transforms yourself and pass it to the `overlap_tile_predict` interface when calling it.
+
+> **Parameters**
+> >
+> > - **img_file** (str|np.ndarray): Path of the image to predict, or a decoded numpy array in HWC layout and BGR format.
+> > - **tile_size** (list|tuple): Size of the sliding window; the region inside it is used to stitch the prediction results. Format is (W, H). Defaults to [512, 512].
+> > - **pad_size** (list|tuple): Margin by which the sliding window is expanded on all four sides; the expanded region is not used when stitching the prediction results. Format is (W, H). Defaults to [64, 64].
+> > - **batch_size** (int): Batch size used when predicting the windows in batches. Defaults to 32.
+> > - **transforms** (paddlex.seg.transforms): Data preprocessing operations.
+
+> **Returns**
+> >
+> > - **dict**: Contains the keys 'label_map' and 'score_map'. 'label_map' stores the predicted grayscale label map whose pixel values denote the class; 'score_map' stores the per-class probabilities with shape=(h, w, num_classes).
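+
+> A minimal usage sketch (the model path and image file below are hypothetical; setting `pad_size=[0, 0]` switches to the non-overlapping mode):
+
+```
+import paddlex as pdx
+
+model = pdx.load_model('output/deeplabv3p/best_model')  # hypothetical path
+# Overlapping sliding-window prediction: each window is expanded by 64 pixels
+# on every side, and only its central region is kept when stitching.
+result = model.overlap_tile_predict(
+    img_file='large_image.png', tile_size=[512, 512], pad_size=[64, 64])
+label_map = result['label_map']  # (h, w) array of predicted class indices
+```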
+
## paddlex.seg.UNet
@@ -133,6 +161,7 @@ paddlex.seg.UNet(num_classes=2, upsample_mode='bilinear', use_bce_loss=False, us
> - evaluate: the evaluation interface is the same as the [DeepLabv3p model evaluate interface](#evaluate)
> - predict: the prediction interface is the same as the [DeepLabv3p model predict interface](#predict)
> - batch_predict: the batch prediction interface is the same as the [DeepLabv3p model batch_predict interface](#batch-predict)
+> - overlap_tile_predict: the sliding-window prediction interface is the same as the [DeepLabv3p model overlap_tile_predict interface](#overlap-tile-predict)
## paddlex.seg.HRNet
@@ -156,6 +185,7 @@ paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=Fal
> - evaluate: the evaluation interface is the same as the [DeepLabv3p model evaluate interface](#evaluate)
> - predict: the prediction interface is the same as the [DeepLabv3p model predict interface](#predict)
> - batch_predict: the batch prediction interface is the same as the [DeepLabv3p model batch_predict interface](#batch-predict)
+> - overlap_tile_predict: the sliding-window prediction interface is the same as the [DeepLabv3p model overlap_tile_predict interface](#overlap-tile-predict)
## paddlex.seg.FastSCNN
@@ -179,3 +209,4 @@ paddlex.seg.FastSCNN(num_classes=2, use_bce_loss=False, use_dice_loss=False, cla
> - evaluate: the evaluation interface is the same as the [DeepLabv3p model evaluate interface](#evaluate)
> - predict: the prediction interface is the same as the [DeepLabv3p model predict interface](#predict)
> - batch_predict: the batch prediction interface is the same as the [DeepLabv3p model batch_predict interface](#batch-predict)
+> - overlap_tile_predict: the sliding-window prediction interface is the same as the [DeepLabv3p model overlap_tile_predict interface](#overlap-tile-predict)
diff --git a/docs/examples/index.rst b/docs/examples/index.rst
index 8d7e68ee64a2f8cb47e9c489d018664054a568d6..f232951fdc7ee1301043140ad792bf977fee5853 100755
--- a/docs/examples/index.rst
+++ b/docs/examples/index.rst
@@ -13,3 +13,4 @@ PaddleX curates mature model architectures from the PaddlePaddle vision development kits, proven in industrial practice,
meter_reader.md
human_segmentation.md
multi-channel_remote_sensing/README.md
+ remote_sensing.md
diff --git a/docs/examples/multi-channel_remote_sensing/README.md b/docs/examples/multi-channel_remote_sensing/README.md
index 1a46a6133e1cf75803c0d8a646840d096c21ee24..cfa91aa608483bea9d0b786a89f34d5feec029c2 100644
--- a/docs/examples/multi-channel_remote_sensing/README.md
+++ b/docs/examples/multi-channel_remote_sensing/README.md
@@ -7,7 +7,7 @@
## Prerequisites
* PaddlePaddle >= 1.8.4
* Python >= 3.5
-* PaddleX >= 1.1.0
+* PaddleX >= 1.1.4
For installation issues, see [PaddleX Installation](../../install.md)
diff --git a/docs/examples/remote_sensing.md b/docs/examples/remote_sensing.md
new file mode 100644
index 0000000000000000000000000000000000000000..26c13b005fdf8642f232132266f943855e9f50e9
--- /dev/null
+++ b/docs/examples/remote_sensing.md
@@ -0,0 +1,82 @@
+# RGB Remote Sensing Image Segmentation
+
+This case study implements remote sensing image segmentation with PaddleX. It provides a sliding-window prediction mode to avoid running out of GPU memory when predicting on large images directly. The overlap between sliding windows is also configurable, which removes the visible seams where windows are stitched together in the final prediction.
+
+## Prerequisites
+
+* PaddlePaddle >= 1.8.4
+* Python >= 3.5
+* PaddleX >= 1.1.4
+
+For installation issues, see [PaddleX Installation](../install.md)
+
+Download the PaddleX source code:
+
+```
+git clone https://github.com/PaddlePaddle/PaddleX
+```
+
+All scripts for this case study live in `PaddleX/examples/remote_sensing/`; change into that directory:
+
+```
+cd PaddleX/examples/remote_sensing/
+```
+
+## Data Preparation
+
+This case study uses the high-resolution remote sensing images provided by the 2015 CCF Big Data Competition: 5 annotated RGB images, ranging in size from 4011 × 2470 up to 7969 × 7939. The dataset is annotated with 5 classes: background (labeled 0), vegetation (labeled 1), building (labeled 2), water (labeled 3), and road (labeled 4).
+
+The first 4 images go into the training set and the 5th image serves as the validation set. To allow a larger training batch size, the first 4 images are split with a (1024, 1024) sliding window at a (512, 512) stride; together with the 4 original full-size images, the training set contains 688 images. Validating directly on the large image during training would exhaust GPU memory, so the 5th image is split with a (769, 769) sliding window at a (769, 769) stride, yielding 40 sub-images for the validation set.
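+
+The essence of the split performed by the script below is a plain sliding-window crop; a simplified sketch (file names are illustrative):
+
+```
+import cv2
+
+image = cv2.imread('ccf_remote_dataset/5.png')
+H, W, _ = image.shape
+tile_size, stride = (769, 769), (769, 769)
+for h in range(0, H, stride[1]):
+    for w in range(0, W, stride[0]):
+        # Crop one window, clipped at the image border
+        tile = image[h:min(h + tile_size[1], H), w:min(w + tile_size[0], W), :]
+        cv2.imwrite('dataset/JPEGImages/5_{}_{}.png'.format(h, w), tile)
+```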
+
+Run the following script to download the raw dataset and split it:
+
+```
+python3 prepare_data.py
+```
+
+## Model Training
+
+The segmentation model is DeepLabv3p with a MobileNetv3_large_ssld backbone, which offers both high speed and high accuracy. Run the following script to train the model:
+```
+python3 train.py
+```
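+
+For reference, the core model definition inside `train.py` (as added in this PR) is:
+
+```
+num_classes = len(train_dataset.labels)
+model = pdx.seg.DeepLabv3p(
+    num_classes=num_classes,
+    backbone='MobileNetV3_large_x1_0_ssld',
+    pooling_crop_size=(769, 769))
+```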
+
+Alternatively, skip training and download the pretrained model for the prediction and evaluation steps below:
+```
+wget https://bj.bcebos.com/paddlex/examples/remote_sensing/models/ccf_remote_model.tar.gz
+tar -xvf ccf_remote_model.tar.gz
+```
+
+## Model Prediction
+
+Predicting directly on a large image would exhaust GPU memory. To avoid this, this case study provides a sliding-window prediction interface that supports both overlapping and non-overlapping modes.
+
+* Sliding-window prediction without overlap
+
+Slide a fixed-size window over the input image, predict the image under each window separately, and stitch the per-window predictions into the prediction for the whole input. Because predictions near each window's edges are worse than those in the center, visible seams may appear where windows are stitched together.
+
+The API for this mode is documented at [overlap_tile_predict](https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict); **to use it, set the parameter `pad_size` to `[0, 0]`**.
+
+* Sliding-window prediction with overlap
+
+The U-Net paper proposes an overlapping sliding-window strategy (overlap-tile strategy) to remove the visible seams where tiles are stitched together. When predicting each window, the window is expanded outward by a margin on all four sides and prediction runs on the expanded window, e.g. the blue region in the figure below; at stitching time, only the central part of each window's prediction is kept, e.g. the yellow region in the figure below. For windows at the image border, the pixels in the expanded margin are obtained by mirror-padding the border pixels.
+
+The API for this mode is documented at [overlap_tile_predict](https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict).
+
+![](../../examples/remote_sensing/images/overlap_tile.png)
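+
+A sketch of how the two modes are invoked (`model` and `img_file` as used in `predict.py`; the `[0, 0]` variant is an assumption based on the API note above):
+
+```
+# With overlap: 64-pixel margins are predicted, then discarded at stitching time
+pred = model.overlap_tile_predict(
+    img_file=img_file, tile_size=(769, 769), pad_size=[64, 64])
+# Without overlap: set pad_size to [0, 0]
+pred = model.overlap_tile_predict(
+    img_file=img_file, tile_size=(769, 769), pad_size=[0, 0])
+```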
+
+Compared with non-overlapping sliding-window prediction, the overlapping strategy raises this case study's mIoU from 80.58% to 81.52% and visibly removes the seams in the visualized predictions; the figure below compares the two modes.
+
+![](../../examples/remote_sensing/images/visualize_compare.jpg)
+
+Run the following script to predict with overlapping sliding windows:
+```
+python3 predict.py
+```
+
+## Model Evaluation
+
+During training, the model's accuracy on the validation set is evaluated every 10 epochs. Since the original large image was split into tiles beforehand, this is equivalent to non-overlapping sliding-window prediction; the best model reaches an mIoU of 80.58%. Run the following script to re-evaluate the model on the original full-size image using overlapping sliding-window prediction, which raises the mIoU to 81.52%.
+```
+python3 eval.py
+```
diff --git a/docs/examples/remote_sensing/index.rst b/docs/examples/remote_sensing/index.rst
deleted file mode 100755
index dc375659be121c4bd04843fd281416a4d00ad865..0000000000000000000000000000000000000000
--- a/docs/examples/remote_sensing/index.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-Remote Sensing Segmentation Case Study
-=======================================
-
-
-Write the remote sensing segmentation case study here; it can be split into multiple documents as needed.
diff --git a/examples/multi-channel_remote_sensing/README.md b/examples/multi-channel_remote_sensing/README.md
index 8554e3d858ad7101d125d066ad1df19095eb2525..63ec786dee429652f16c472389708919ee33f4a7 100644
--- a/examples/multi-channel_remote_sensing/README.md
+++ b/examples/multi-channel_remote_sensing/README.md
@@ -14,7 +14,7 @@
* PaddlePaddle >= 1.8.4
* Python >= 3.5
-* PaddleX >= 1.1.0
+* PaddleX >= 1.1.4
For installation issues, see [PaddleX Installation](../../docs/install.md)
diff --git a/examples/remote_sensing/README.md b/examples/remote_sensing/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2663754d1995bc810f1fafb838c4045ab4761faa
--- /dev/null
+++ b/examples/remote_sensing/README.md
@@ -0,0 +1,88 @@
+# RGB Remote Sensing Image Segmentation
+
+This case study implements remote sensing image segmentation with PaddleX. It provides a sliding-window prediction mode to avoid running out of GPU memory when predicting on large images directly. The overlap between sliding windows is also configurable, which removes the visible seams where windows are stitched together in the final prediction.
+
+## Contents
+* [Data Preparation](#1)
+* [Model Training](#2)
+* [Model Prediction](#3)
+* [Model Evaluation](#4)
+
+## Prerequisites
+
+* PaddlePaddle >= 1.8.4
+* Python >= 3.5
+* PaddleX >= 1.1.4
+
+For installation issues, see [PaddleX Installation](../../docs/install.md)
+
+Download the PaddleX source code:
+
+```
+git clone https://github.com/PaddlePaddle/PaddleX
+```
+
+All scripts for this case study live in `PaddleX/examples/remote_sensing/`; change into that directory:
+
+```
+cd PaddleX/examples/remote_sensing/
+```
+
+## Data Preparation
+
+This case study uses the high-resolution remote sensing images provided by the 2015 CCF Big Data Competition: 5 annotated RGB images, ranging in size from 4011 × 2470 up to 7969 × 7939. The dataset is annotated with 5 classes: background (labeled 0), vegetation (labeled 1), building (labeled 2), water (labeled 3), and road (labeled 4).
+
+The first 4 images go into the training set and the 5th image serves as the validation set. To allow a larger training batch size, the first 4 images are split with a (1024, 1024) sliding window at a (512, 512) stride; together with the 4 original full-size images, the training set contains 688 images. Validating directly on the large image during training would exhaust GPU memory, so the 5th image is split with a (769, 769) sliding window at a (769, 769) stride, yielding 40 sub-images for the validation set.
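+
+The essence of the split performed by the script below is a plain sliding-window crop; a simplified sketch (file names are illustrative):
+
+```
+import cv2
+
+image = cv2.imread('ccf_remote_dataset/5.png')
+H, W, _ = image.shape
+tile_size, stride = (769, 769), (769, 769)
+for h in range(0, H, stride[1]):
+    for w in range(0, W, stride[0]):
+        # Crop one window, clipped at the image border
+        tile = image[h:min(h + tile_size[1], H), w:min(w + tile_size[0], W), :]
+        cv2.imwrite('dataset/JPEGImages/5_{}_{}.png'.format(h, w), tile)
+```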
+
+Run the following script to download the raw dataset and split it:
+
+```
+python3 prepare_data.py
+```
+
+## Model Training
+
+The segmentation model is DeepLabv3p with a MobileNetv3_large_ssld backbone, which offers both high speed and high accuracy. Run the following script to train the model:
+```
+python3 train.py
+```
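+
+For reference, the core model definition inside `train.py` (as added in this PR) is:
+
+```
+num_classes = len(train_dataset.labels)
+model = pdx.seg.DeepLabv3p(
+    num_classes=num_classes,
+    backbone='MobileNetV3_large_x1_0_ssld',
+    pooling_crop_size=(769, 769))
+```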
+
+Alternatively, skip training and download the pretrained model for the prediction and evaluation steps below:
+```
+wget https://bj.bcebos.com/paddlex/examples/remote_sensing/models/ccf_remote_model.tar.gz
+tar -xvf ccf_remote_model.tar.gz
+```
+
+## Model Prediction
+
+Predicting directly on a large image would exhaust GPU memory. To avoid this, this case study provides a sliding-window prediction interface that supports both overlapping and non-overlapping modes.
+
+* Sliding-window prediction without overlap
+
+Slide a fixed-size window over the input image, predict the image under each window separately, and stitch the per-window predictions into the prediction for the whole input. Because predictions near each window's edges are worse than those in the center, visible seams may appear where windows are stitched together.
+
+The API for this mode is documented at [overlap_tile_predict](https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict); **to use it, set the parameter `pad_size` to `[0, 0]`**.
+
+* Sliding-window prediction with overlap
+
+The U-Net paper proposes an overlapping sliding-window strategy (overlap-tile strategy) to remove the visible seams where tiles are stitched together. When predicting each window, the window is expanded outward by a margin on all four sides and prediction runs on the expanded window, e.g. the blue region in the figure below; at stitching time, only the central part of each window's prediction is kept, e.g. the yellow region in the figure below. For windows at the image border, the pixels in the expanded margin are obtained by mirror-padding the border pixels.
+
+The API for this mode is documented at [overlap_tile_predict](https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict).
+
+![](images/overlap_tile.png)
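+
+A sketch of how the two modes are invoked (`model` and `img_file` as used in `predict.py`; the `[0, 0]` variant is an assumption based on the API note above):
+
+```
+# With overlap: 64-pixel margins are predicted, then discarded at stitching time
+pred = model.overlap_tile_predict(
+    img_file=img_file, tile_size=(769, 769), pad_size=[64, 64])
+# Without overlap: set pad_size to [0, 0]
+pred = model.overlap_tile_predict(
+    img_file=img_file, tile_size=(769, 769), pad_size=[0, 0])
+```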
+
+Compared with non-overlapping sliding-window prediction, the overlapping strategy raises this case study's mIoU from 80.58% to 81.52% and visibly removes the seams in the visualized predictions; the figure below compares the two modes.
+
+![](images/visualize_compare.jpg)
+
+Run the following script to predict with overlapping sliding windows:
+```
+python3 predict.py
+```
+
+## Model Evaluation
+
+During training, the model's accuracy on the validation set is evaluated every 10 epochs. Since the original large image was split into tiles beforehand, this is equivalent to non-overlapping sliding-window prediction; the best model reaches an mIoU of 80.58%. Run the following script to re-evaluate the model on the original full-size image using overlapping sliding-window prediction, which raises the mIoU to 81.52%.
+```
+python3 eval.py
+```
diff --git a/examples/remote_sensing/eval.py b/examples/remote_sensing/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..540494501102fdc0ee0e0dd166e8a4ae77863589
--- /dev/null
+++ b/examples/remote_sensing/eval.py
@@ -0,0 +1,44 @@
+# Environment variable configuration, used to control whether the GPU is used
+# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import numpy as np
+import cv2
+from PIL import Image
+from collections import OrderedDict
+
+import paddlex as pdx
+import paddlex.utils.logging as logging
+from paddlex.cv.models.utils.seg_eval import ConfusionMatrix
+
+
+def update_confusion_matrix(confusion_matrix, prediction, label):
+    pred = prediction["label_map"]
+ pred = pred[np.newaxis, :, :, np.newaxis]
+ pred = pred.astype(np.int64)
+ label = label[np.newaxis, np.newaxis, :, :]
+ mask = label != model.ignore_index
+ confusion_matrix.calculate(pred=pred, label=label, ignore=mask)
+
+
+model_dir = 'output/deeplabv3p_mobilenetv3_large_ssld/best_model'
+img_file = "dataset/JPEGImages/5.png"
+label_file = "dataset/Annotations/5_class.png"
+
+model = pdx.load_model(model_dir)
+
+conf_mat = ConfusionMatrix(model.num_classes, streaming=True)
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict
+pred_result = model.overlap_tile_predict(
+    img_file=img_file, tile_size=(769, 769), pad_size=[64, 64], batch_size=32)
+
+label = np.asarray(Image.open(label_file))
+update_confusion_matrix(conf_mat, pred_result, label)
+
+category_iou, miou = conf_mat.mean_iou()
+category_acc, macc = conf_mat.accuracy()
+logging.info(
+ "miou={:.6f} category_iou={} macc={:.6f} category_acc={} kappa={:.6f}".
+ format(miou, category_iou, macc, category_acc, conf_mat.kappa()))
diff --git a/examples/remote_sensing/images/overlap_tile.png b/examples/remote_sensing/images/overlap_tile.png
new file mode 100644
index 0000000000000000000000000000000000000000..60347caeb41b0807ad1cec84fac690e7318d20e1
Binary files /dev/null and b/examples/remote_sensing/images/overlap_tile.png differ
diff --git a/examples/remote_sensing/images/visualize_compare.jpg b/examples/remote_sensing/images/visualize_compare.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1833f143ede0d493c3e89533e3cf2caa567ca417
Binary files /dev/null and b/examples/remote_sensing/images/visualize_compare.jpg differ
diff --git a/examples/remote_sensing/predict.py b/examples/remote_sensing/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..c22eeef3b71518c61efbddab19b809ec8650696e
--- /dev/null
+++ b/examples/remote_sensing/predict.py
@@ -0,0 +1,18 @@
+# Environment variable configuration, used to control whether the GPU is used
+# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+
+model_dir = 'output/deeplabv3p_mobilenetv3_large_ssld/best_model'
+img_file = "dataset/JPEGImages/5.png"
+save_dir = 'output/deeplabv3p_mobilenetv3_large_ssld/'
+
+model = pdx.load_model(model_dir)
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#overlap-tile-predict
+pred = model.overlap_tile_predict(
+ img_file=img_file, tile_size=(769, 769), pad_size=[64, 64], batch_size=32)
+
+pdx.seg.visualize(img_file, pred, weight=0., save_dir=save_dir)
diff --git a/examples/remote_sensing/prepare_data.py b/examples/remote_sensing/prepare_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..6951bf8eb510564e06701c7ddb1c59fb4fc1b25b
--- /dev/null
+++ b/examples/remote_sensing/prepare_data.py
@@ -0,0 +1,95 @@
+import os
+import os.path as osp
+import numpy as np
+import cv2
+import shutil
+from PIL import Image
+import paddlex as pdx
+
+# Sliding-window size and stride, in (W, H) format, for splitting the training set
+train_tile_size = (1024, 1024)
+train_stride = (512, 512)
+# Sliding-window size and stride, in (W, H) format, for splitting the validation set
+val_tile_size = (769, 769)
+val_stride = (769, 769)
+
+# Download and extract the high-resolution remote sensing images from the 2015 CCF Big Data Competition
+ccf_remote_dataset = 'https://bj.bcebos.com/paddlex/examples/remote_sensing/datasets/ccf_remote_dataset.tar.gz'
+pdx.utils.download_and_decompress(ccf_remote_dataset, path='./')
+
+if not osp.exists('./dataset/JPEGImages'):
+ os.makedirs('./dataset/JPEGImages')
+if not osp.exists('./dataset/Annotations'):
+ os.makedirs('./dataset/Annotations')
+
+# Put the first 4 images into the training set (their tiled crops are added below)
+# and generate train_list.txt
+for train_id in range(1, 5):
+ shutil.copyfile("ccf_remote_dataset/{}.png".format(train_id),
+ "./dataset/JPEGImages/{}.png".format(train_id))
+ shutil.copyfile("ccf_remote_dataset/{}_class.png".format(train_id),
+ "./dataset/Annotations/{}_class.png".format(train_id))
+ mode = 'w' if train_id == 1 else 'a'
+ with open('./dataset/train_list.txt', mode) as f:
+ f.write("JPEGImages/{}.png Annotations/{}_class.png\n".format(
+ train_id, train_id))
+
+for train_id in range(1, 5):
+ image = cv2.imread('ccf_remote_dataset/{}.png'.format(train_id))
+ label = Image.open('ccf_remote_dataset/{}_class.png'.format(train_id))
+ H, W, C = image.shape
+ train_tile_id = 1
+ for h in range(0, H, train_stride[1]):
+ for w in range(0, W, train_stride[0]):
+ left = w
+ upper = h
+ right = min(w + train_tile_size[0] * 2, W)
+ lower = min(h + train_tile_size[1] * 2, H)
+ tile_image = image[upper:lower, left:right, :]
+ cv2.imwrite("./dataset/JPEGImages/{}_{}.png".format(
+ train_id, train_tile_id), tile_image)
+ cut_label = label.crop((left, upper, right, lower))
+ cut_label.save("./dataset/Annotations/{}_class_{}.png".format(
+ train_id, train_tile_id))
+ with open('./dataset/train_list.txt', 'a') as f:
+ f.write("JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
+ format(train_id, train_tile_id, train_id,
+ train_tile_id))
+ train_tile_id += 1
+
+# Split the 5th image into tiles and add them to the validation set
+val_id = 5
+val_tile_id = 1
+shutil.copyfile("ccf_remote_dataset/{}.png".format(val_id),
+ "./dataset/JPEGImages/{}.png".format(val_id))
+shutil.copyfile("ccf_remote_dataset/{}_class.png".format(val_id),
+ "./dataset/Annotations/{}_class.png".format(val_id))
+image = cv2.imread('ccf_remote_dataset/{}.png'.format(val_id))
+label = Image.open('ccf_remote_dataset/{}_class.png'.format(val_id))
+H, W, C = image.shape
+for h in range(0, H, val_stride[1]):
+ for w in range(0, W, val_stride[0]):
+ left = w
+ upper = h
+ right = min(w + val_tile_size[0], W)
+ lower = min(h + val_tile_size[1], H)
+ cut_image = image[upper:lower, left:right, :]
+ cv2.imwrite("./dataset/JPEGImages/{}_{}.png".format(
+ val_id, val_tile_id), cut_image)
+ cut_label = label.crop((left, upper, right, lower))
+ cut_label.save("./dataset/Annotations/{}_class_{}.png".format(
+ val_id, val_tile_id))
+ mode = 'w' if val_tile_id == 1 else 'a'
+ with open('./dataset/val_list.txt', mode) as f:
+ f.write("JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
+ format(val_id, val_tile_id, val_id, val_tile_id))
+ val_tile_id += 1
+
+# Generate labels.txt
+label_list = ['background', 'vegetation', 'building', 'water', 'road']
+for i, label in enumerate(label_list):
+ mode = 'w' if i == 0 else 'a'
+    with open('./dataset/labels.txt', mode) as f:
+ name = "{}\n".format(label) if i < len(
+ label_list) - 1 else "{}".format(label)
+ f.write(name)
diff --git a/examples/remote_sensing/train.py b/examples/remote_sensing/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a7ff121a9e54648ef8aa754d77360cc14e871f8
--- /dev/null
+++ b/examples/remote_sensing/train.py
@@ -0,0 +1,55 @@
+# Environment variable configuration, used to control whether the GPU is used
+# Docs: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+from paddlex.seg import transforms
+
+# Define the transforms for training and validation
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
+train_transforms = transforms.Compose([
+ transforms.RandomPaddingCrop(crop_size=769),
+ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(),
+ transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose(
+ [transforms.Padding(target_size=769), transforms.Normalize()])
+
+# Define the datasets used for training and validation
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
+train_dataset = pdx.datasets.SegDataset(
+ data_dir='dataset',
+ file_list='dataset/train_list.txt',
+ label_list='dataset/labels.txt',
+ transforms=train_transforms,
+ shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+ data_dir='dataset',
+ file_list='dataset/val_list.txt',
+ label_list='dataset/labels.txt',
+ transforms=eval_transforms)
+
+# Initialize the model and start training
+# Training metrics can be inspected with VisualDL; see https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
+num_classes = len(train_dataset.labels)
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
+model = pdx.seg.DeepLabv3p(
+ num_classes=num_classes,
+ backbone='MobileNetV3_large_x1_0_ssld',
+ pooling_crop_size=(769, 769))
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
+# Parameter descriptions and tuning tips: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
+model.train(
+ num_epochs=400,
+ train_dataset=train_dataset,
+ train_batch_size=16,
+ eval_dataset=eval_dataset,
+ learning_rate=0.01,
+ save_interval_epochs=10,
+ pretrain_weights='CITYSCAPES',
+ save_dir='output/deeplabv3p_mobilenetv3_large_ssld',
+ use_vdl=True)
diff --git a/paddlex/cv/datasets/dataset.py b/paddlex/cv/datasets/dataset.py
index 82a29f5443c56c9caab2ad725e72493e0bc4bd51..bedbc5ab63ce8f78fd1a40a5c9c89990e190f21f 100644
--- a/paddlex/cv/datasets/dataset.py
+++ b/paddlex/cv/datasets/dataset.py
@@ -239,9 +239,8 @@ def generate_minibatch(batch_data, label_padding_value=255, mapper=None):
_, label_h, label_w = data[1].shape
padding_label[:, :label_h, :label_w] = data[1]
padding_batch.append((padding_im, padding_label))
- elif len(data[1]) == 0 or isinstance(
- data[1][0],
- tuple) and data[1][0][0] in ['resize', 'padding']:
+ elif len(data[1]) == 0 or isinstance(data[1][0], tuple) and data[
+ 1][0][0] in ['origin_shape', 'resize', 'padding']:
# padding the image and insert 'padding' into `im_info`
# of segmentation during the inference phase
if len(data[1]) == 0 or 'padding' not in [
diff --git a/paddlex/cv/models/deeplabv3p.py b/paddlex/cv/models/deeplabv3p.py
index 49a6a1d33e31ccc871df7c02301f40ba606a51dc..8f7341c7107f6103a6c49ca4fc615c41f2231af6 100644
--- a/paddlex/cv/models/deeplabv3p.py
+++ b/paddlex/cv/models/deeplabv3p.py
@@ -24,6 +24,7 @@ import paddlex.utils.logging as logging
import paddlex
from paddlex.cv.transforms import arrange_transforms
from paddlex.cv.datasets import generate_minibatch
+from paddlex.cv.transforms.seg_transforms import Compose
from collections import OrderedDict
from .base import BaseAPI
from .utils.seg_eval import ConfusionMatrix
@@ -448,7 +449,11 @@ class DeepLabv3p(BaseAPI):
return metrics
@staticmethod
- def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
+ def _preprocess(images,
+ transforms,
+ model_type,
+ class_name,
+ thread_pool=None):
arrange_transforms(
model_type=model_type,
class_name=class_name,
@@ -554,3 +559,102 @@ class DeepLabv3p(BaseAPI):
preds = DeepLabv3p._postprocess(result, im_info)
return preds
+
+ def overlap_tile_predict(self,
+ img_file,
+ tile_size=[512, 512],
+ pad_size=[64, 64],
+ batch_size=32,
+ transforms=None):
+ """有重叠的大图切小图预测。
+ Args:
+ img_file(str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
+ tile_size(list|tuple): 滑动窗口的大小,该区域内用于拼接预测结果,格式为(W,H)。默认值为[512, 512]。
+ pad_size(list|tuple): 滑动窗口向四周扩展的大小,扩展区域内不用于拼接预测结果,格式为(W,H)。默认值为[64,64]。
+ batch_size(int):对窗口进行批量预测时的批量大小。默认值为32
+ transforms(paddlex.cv.transforms): 数据预处理操作。
+
+
+ Returns:
+ dict: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,
+ 像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)
+ """
+
+ if transforms is None and not hasattr(self, 'test_transforms'):
+ raise Exception("transforms need to be defined, now is None.")
+
+ if isinstance(img_file, str):
+ image, _ = Compose.decode_image(img_file, None)
+ elif isinstance(img_file, np.ndarray):
+ image = img_file.copy()
+ else:
+ raise Exception("im_file must be list/tuple")
+
+ height, width, channel = image.shape
+ image_tile_list = list()
+
+ # Padding along the left and right sides
+ if pad_size[0] > 0:
+ left_pad = cv2.flip(image[0:height, 0:pad_size[0], :], 1)
+ right_pad = cv2.flip(image[0:height, -pad_size[0]:width, :], 1)
+ padding_image = cv2.hconcat([left_pad, image])
+ padding_image = cv2.hconcat([padding_image, right_pad])
+        else:
+            padding_image = image.copy()
+
+ # Padding along the upper and lower sides
+ padding_height, padding_width, _ = padding_image.shape
+ if pad_size[1] > 0:
+ upper_pad = cv2.flip(
+ padding_image[0:pad_size[1], 0:padding_width, :], 0)
+ lower_pad = cv2.flip(
+ padding_image[-pad_size[1]:padding_height, 0:padding_width, :],
+ 0)
+ padding_image = cv2.vconcat([upper_pad, padding_image])
+ padding_image = cv2.vconcat([padding_image, lower_pad])
+
+ # crop the padding image into tile pieces
+ padding_height, padding_width, _ = padding_image.shape
+
+ for h_id in range(0, height // tile_size[1] + 1):
+ for w_id in range(0, width // tile_size[0] + 1):
+ left = w_id * tile_size[0]
+ upper = h_id * tile_size[1]
+ right = min(left + tile_size[0] + pad_size[0] * 2,
+ padding_width)
+ lower = min(upper + tile_size[1] + pad_size[1] * 2,
+ padding_height)
+ image_tile = padding_image[upper:lower, left:right, :]
+ image_tile_list.append(image_tile)
+
+ # predict
+ label_map = np.zeros((height, width), dtype=np.uint8)
+ score_map = np.zeros(
+ (height, width, self.num_classes), dtype=np.float32)
+ num_tiles = len(image_tile_list)
+ for i in range(0, num_tiles, batch_size):
+ begin = i
+ end = min(i + batch_size, num_tiles)
+ res = self.batch_predict(
+ img_file_list=image_tile_list[begin:end],
+ transforms=transforms)
+ for j in range(begin, end):
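+                # Map the flat tile index j back to its (row, col) position in
+                # the tile grid, then keep only the central (unpadded) region
+                # of each tile's prediction when stitching the full-size output.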
+ h_id = j // (width // tile_size[0] + 1)
+ w_id = j % (width // tile_size[0] + 1)
+ left = w_id * tile_size[0]
+ upper = h_id * tile_size[1]
+ right = min((w_id + 1) * tile_size[0], width)
+ lower = min((h_id + 1) * tile_size[1], height)
+ tile_label_map = res[j - begin]["label_map"]
+ tile_score_map = res[j - begin]["score_map"]
+ tile_upper = pad_size[1]
+ tile_lower = tile_label_map.shape[0] - pad_size[1]
+ tile_left = pad_size[0]
+ tile_right = tile_label_map.shape[1] - pad_size[0]
+ label_map[upper:lower, left:right] = \
+ tile_label_map[tile_upper:tile_lower, tile_left:tile_right]
+ score_map[upper:lower, left:right, :] = \
+ tile_score_map[tile_upper:tile_lower, tile_left:tile_right, :]
+ result = {"label_map": label_map, "score_map": score_map}
+ return result
diff --git a/paddlex/cv/transforms/seg_transforms.py b/paddlex/cv/transforms/seg_transforms.py
index c482930ca18a39a2e684c17d470b931dfc6e5823..a59e405f627e9fa7ab00f12c6eb38668ee0734e5 100644
--- a/paddlex/cv/transforms/seg_transforms.py
+++ b/paddlex/cv/transforms/seg_transforms.py
@@ -723,28 +723,25 @@ class Padding(SegTransform):
target_width = self.target_size[0]
pad_height = target_height - im_height
pad_width = target_width - im_width
- if pad_height < 0 or pad_width < 0:
- raise ValueError(
- 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
- .format(im_width, im_height, target_width, target_height))
- else:
- im = cv2.copyMakeBorder(
- im,
+ pad_height = max(pad_height, 0)
+ pad_width = max(pad_width, 0)
+ im = cv2.copyMakeBorder(
+ im,
+ 0,
+ pad_height,
+ 0,
+ pad_width,
+ cv2.BORDER_CONSTANT,
+ value=self.im_padding_value)
+ if label is not None:
+ label = cv2.copyMakeBorder(
+ label,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
- value=self.im_padding_value)
- if label is not None:
- label = cv2.copyMakeBorder(
- label,
- 0,
- pad_height,
- 0,
- pad_width,
- cv2.BORDER_CONSTANT,
- value=self.label_padding_value)
+ value=self.label_padding_value)
if label is None:
return (im, im_info)
else:
diff --git a/paddlex/deploy.py b/paddlex/deploy.py
index e7a9264240ff52007ad3480ed794064cc171320f..747cf16454e16d0daa7d5e415a16faee55448ce5 100644
--- a/paddlex/deploy.py
+++ b/paddlex/deploy.py
@@ -94,7 +94,7 @@ class Predictor:
use_gpu=True,
gpu_id=0,
use_mkl=False,
- mkl_thread_num=psutil.cpu_count(),
+ mkl_thread_num=mp.cpu_count(),
use_trt=False,
use_glog=False,
memory_optimize=True):
diff --git a/tutorials/train/object_detection/ppyolo.py b/tutorials/train/object_detection/ppyolo.py
index 63b47a95671692e89761251e9a1059cac9b542eb..7f1d4e32dd055851babd6eed5f823d4ea9c637e1 100644
--- a/tutorials/train/object_detection/ppyolo.py
+++ b/tutorials/train/object_detection/ppyolo.py
@@ -52,7 +52,7 @@ model.train(
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
- learning_rate=0.000125,
+ learning_rate=0.0005,
lr_decay_epochs=[210, 240],
save_dir='output/ppyolo',
use_vdl=True)