diff --git a/contrib/RemoteSensing/README.md b/contrib/RemoteSensing/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3b84b5a274a882f892be4efb4a44f72b50ae8c2c
--- /dev/null
+++ b/contrib/RemoteSensing/README.md
@@ -0,0 +1,244 @@
+# Remote Sensing Segmentation (RemoteSensing)
+Remote sensing image segmentation is an important application of image segmentation, widely used in land surveying, environmental monitoring, urban construction, and other fields. Its targets are diverse: ground objects such as snow, crops, roads, buildings, and water bodies, as well as aerial targets such as clouds.
+
+PaddleSeg provides RemoteSensing, a semantic segmentation library dedicated to remote sensing. It covers image preprocessing, data augmentation, model training, and prediction, helping you apply deep learning to remote sensing segmentation problems.
+
+Remote sensing data is typically multi-channel, covers a wide value range, and is unevenly distributed. We therefore support multi-channel training and prediction, with a built-in set of multi-channel preprocessing and data augmentation strategies that can be combined to fit your business scenario, improving model generalization and robustness.
+
+**Note:** all commands must be run from the `PaddleSeg/contrib/RemoteSensing/` directory.
+
+## Prerequisites
+- Paddle 1.7.1+
+Since image segmentation models are computationally expensive, the GPU build of PaddlePaddle is recommended.
+Follow the [official guide](https://paddlepaddle.org.cn/install/quick) to install a version suitable for your environment.
+
+- Python 3.5+
+
+- Other dependencies
+Install the Python package dependencies with the following commands; make sure they have been run at least once:
+```
+cd RemoteSensing
+pip install -r requirements.txt
+```
+
+## Directory structure
+ ```
+RemoteSensing             # root directory
+ |-- dataset              # datasets
+ |-- docs                 # documentation
+ |-- models               # model class definitions
+ |-- nets                 # network definitions
+ |-- readers              # data reading module
+ |-- tools                # tools
+ |-- transforms           # data augmentation module
+ |-- utils                # common utilities
+ |-- train_demo.py        # training demo script
+ |-- predict_demo.py      # prediction demo script
+ |-- README.md            # user guide
+
+ ```
+## Data format
+A dataset consists of the original images, the annotation images, and the corresponding file lists.
+
+A typical layout looks like this:
+```
+./dataset/         # dataset root
+|--images          # original images
+|  |--xxx1.npy
+|  |--...
+|  └--...
+|
+|--annotations     # annotation images
+|  |--xxx1.png
+|  |--...
+|  └--...
+|
+|--train_list.txt  # training file list
+|
+|--val_list.txt    # validation file list
+|
+└--labels.txt      # label list
+
+```
+The file names themselves can be chosen freely.
+
+Because remote sensing imagery comes in many formats, depending on the sensor, this library currently uses the npy format for remote sensing data and the lossless png format for annotation images.
+
+Annotation images are single-channel images whose pixel values are the class ids. Class ids must start at 0 and increase consecutively,
+e.g. 0, 1, 2, 3 for four classes, up to a maximum of 256 classes. A specific pixel value (255 by default) can be reserved to mark pixels that do not take part in training and evaluation.
+
+`train_list.txt` and `val_list.txt` contain two space-separated columns: the first is the image path relative to the dataset directory, the second is the annotation path relative to the dataset directory. For example:
+```
+images/xxx1.npy annotations/xxx1.png
+images/xxx2.npy annotations/xxx2.png
+...
+```
+
+For the detailed requirements and how to generate file lists, see the [file list specification](../../docs/data_prepare.md#文件列表).
+
+`labels.txt`: one class per line; the (0-based) line number is the class id. For example:
+```
+labelA
+labelB
+...
+```
+
+
+
+## Quick start
+
+This section shows how to train and predict with RemoteSensing on a small dataset.
+
+### 1. Prepare the dataset
+For a quick start, a small demo dataset is provided under `RemoteSensing/dataset/demo/`.
+
+For your own dataset, convert it to the format described above. You can use numpy to save the remote sensing data and PIL to save the annotation images; the numpy API is shown first, and a PIL sketch for annotations follows below.
+```python
+import numpy as np
+
+# save the remote sensing data
+# img type: numpy.ndarray
+np.save(save_path, img)
+```
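+
+For the annotation side, a minimal PIL sketch might look as follows; `label` and `save_path` are placeholder names, and the annotation is assumed to already be a single-channel array of class ids:
+```python
+import numpy as np
+from PIL import Image
+
+# label type: numpy.ndarray, single channel, values are class ids
+# save_path should end with .png (lossless)
+Image.fromarray(label.astype(np.uint8)).save(save_path)
+```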
+
+### 2. Training code
+Training is driven by the `train_demo.py` code below.
+
+> Import the RemoteSensing API
+```python
+import transforms.transforms as T
+from readers.reader import Reader
+from models import UNet
+```
+
+> Define the data processing and augmentation pipelines for training and validation. `train_transforms` includes augmentations such as `RandomVerticalFlip` and `RandomHorizontalFlip`.
+```python
+train_transforms = T.Compose([
+    T.RandomVerticalFlip(0.5),
+    T.RandomHorizontalFlip(0.5),
+    T.ResizeStepScaling(0.5, 2.0, 0.25),
+    T.RandomPaddingCrop(256),
+    T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
+])
+
+eval_transforms = T.Compose([
+    T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
+])
+```
+
+> Define the data readers
+```python
+import os
+import os.path as osp
+
+train_list = osp.join(data_dir, 'train.txt')
+val_list = osp.join(data_dir, 'val.txt')
+label_list = osp.join(data_dir, 'labels.txt')
+
+train_reader = Reader(
+    data_dir=data_dir,
+    file_list=train_list,
+    label_list=label_list,
+    transforms=train_transforms,
+    num_workers=8,
+    buffer_size=16,
+    shuffle=True,
+    parallel_method='thread')
+
+eval_reader = Reader(
+    data_dir=data_dir,
+    file_list=val_list,
+    label_list=label_list,
+    transforms=eval_transforms,
+    num_workers=8,
+    buffer_size=16,
+    shuffle=False,
+    parallel_method='thread')
+```
+> Build the model
+```python
+model = UNet(
+    num_classes=2, input_channel=channel, use_bce_loss=True, use_dice_loss=True)
+```
+> Train the model, with evaluation during training enabled
+```python
+model.train(
+    num_epochs=num_epochs,
+    train_reader=train_reader,
+    train_batch_size=train_batch_size,
+    eval_reader=eval_reader,
+    save_interval_epochs=5,
+    log_interval_steps=10,
+    save_dir=save_dir,
+    pretrain_weights=None,
+    optimizer=None,
+    learning_rate=lr,
+)
+```
+
+
+### 3. Model training
+> Select the GPU card
+```shell script
+export CUDA_VISIBLE_DEVICES=0
+```
+> Run `train_demo.py` from the RemoteSensing directory to start training.
+```shell script
+python train_demo.py --data_dir dataset/demo/ --save_dir saved_model/unet/ --channel 3 --num_epochs 20
+```
+### 4. Prediction code
+Prediction is driven by the `predict_demo.py` code below.
+
+> Import the RemoteSensing API
+```python
+from models import load_model
+```
+> Load the best model from training and set the directory where predictions are saved.
+```python
+import os
+import os.path as osp
+model = load_model(osp.join(save_dir, 'best_model'))
+pred_dir = osp.join(save_dir, 'pred')
+if not osp.exists(pred_dir):
+    os.mkdir(pred_dir)
+```
+
+> Run the model on the validation set and save the predictions.
+```python
+import numpy as np
+from PIL import Image as Image
+val_list = osp.join(data_dir, 'val.txt')
+color_map = [0, 0, 0, 255, 255, 255]
+with open(val_list) as f:
+    lines = f.readlines()
+    for line in lines:
+        img_path = line.split(' ')[0]
+        print('Predicting {}'.format(img_path))
+        img_path_ = osp.join(data_dir, img_path)
+
+        pred = model.predict(img_path_)
+
+        # save the prediction as a pseudo-color png
+        pred_name = osp.basename(img_path).rstrip('npy') + 'png'
+        pred_path = osp.join(pred_dir, pred_name)
+        pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P')
+        pred_mask.putpalette(color_map)
+        pred_mask.save(pred_path)
+```
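+
+> Optionally, sanity-check a saved mask by reading it back as class ids; a short sketch, assuming the demo paths used above:
+```python
+import numpy as np
+from PIL import Image
+
+# a 'P'-mode png stores class ids directly as pixel values
+mask = np.asarray(Image.open('saved_model/unet/pred/100.png'))
+print(mask.shape, np.unique(mask))  # expect values from {0, 1} for the 2-class demo
+```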
+
+### 5. Model prediction
+> Select the GPU card
+```shell script
+export CUDA_VISIBLE_DEVICES=0
+```
+> Run `predict_demo.py` from the RemoteSensing directory to start prediction.
+```shell script
+python predict_demo.py --data_dir dataset/demo/ --load_model_dir saved_model/unet/
+```
+
+
+## API reference
+
+You can build your own segmentation code with the APIs provided under the `RemoteSensing` directory.
+
+- [Data transforms](docs/transforms.md)
diff --git a/contrib/RemoteSensing/__init__.py b/contrib/RemoteSensing/__init__.py
index 236eb52217024227624afb662a4a421b7ee16c1c..ea9751219a9eda9e50a80e9dff2a8b3d7cba0066 100644
--- a/contrib/RemoteSensing/__init__.py
+++ b/contrib/RemoteSensing/__init__.py
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import utils
-from . import nets
-from . import models
-from . import transforms
-from . import readers
-from .utils.utils import get_environ_info
+import utils
+import nets
+import models
+import transforms
+import readers
+from utils.utils import get_environ_info
 
 env_info = get_environ_info()
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/0.png b/contrib/RemoteSensing/dataset/demo/annotations/0.png
new file mode 100644
index 0000000000000000000000000000000000000000..cf1b91544aac136d78f25c6818ae3aaf8aca23bb
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/0.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/1.png b/contrib/RemoteSensing/dataset/demo/annotations/1.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9f0d5df904ff9cd9df1bffdf456d48e3fae38f9
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/1.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/10.png b/contrib/RemoteSensing/dataset/demo/annotations/10.png
new file mode 100644
index 0000000000000000000000000000000000000000..59950c118bd981bcfaa805e27a3d4929daa7b213
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/10.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/100.png b/contrib/RemoteSensing/dataset/demo/annotations/100.png
new file mode 100644
index 0000000000000000000000000000000000000000..6fef4400ce8e4e0f13937bca398ba50a4aab729a
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/100.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/1000.png b/contrib/RemoteSensing/dataset/demo/annotations/1000.png
new file mode 100644
index 0000000000000000000000000000000000000000..891dfdcaa591640a9dab4b046bdf34f8606d282b
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/1000.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/1001.png b/contrib/RemoteSensing/dataset/demo/annotations/1001.png
new file mode 100644
index 0000000000000000000000000000000000000000..891dfdcaa591640a9dab4b046bdf34f8606d282b
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/1001.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/1002.png b/contrib/RemoteSensing/dataset/demo/annotations/1002.png
new file mode 100644
index 0000000000000000000000000000000000000000..e247cb90c9044a8044e0e595a9917e14d69d42de
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/1002.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/1003.png b/contrib/RemoteSensing/dataset/demo/annotations/1003.png
new file mode 100644
index 0000000000000000000000000000000000000000..f98df538a2c1c027deeb3e6530d04e8900ef6e07
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/1003.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/1004.png b/contrib/RemoteSensing/dataset/demo/annotations/1004.png
new file mode 100644
index 0000000000000000000000000000000000000000..1da4b7b5bcdb9ff3f3b70438409753e0e4e28fe4
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/1004.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/annotations/1005.png b/contrib/RemoteSensing/dataset/demo/annotations/1005.png
new file mode 100644
index 0000000000000000000000000000000000000000..09173b87cbc02dbe92fcfd92b0ea376fb8d8a91d
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/annotations/1005.png differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/0.npy b/contrib/RemoteSensing/dataset/demo/images/0.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4cbb1c56d8d902629585fb20e026f35773a5f7a4
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/0.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/1.npy b/contrib/RemoteSensing/dataset/demo/images/1.npy
new file mode 100644
index 0000000000000000000000000000000000000000..11b6433300481381a2877da6453e04a7f116c4aa
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/1.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/10.npy b/contrib/RemoteSensing/dataset/demo/images/10.npy
new file mode 100644
index 0000000000000000000000000000000000000000..cfbf1ab896203d4962ccb254ad046487648af8ce
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/10.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/100.npy b/contrib/RemoteSensing/dataset/demo/images/100.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7162a79fc2ce958e86b1f728f97a4266b5b4f6cd
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/100.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/1000.npy b/contrib/RemoteSensing/dataset/demo/images/1000.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7ddf3cb11b906a0776a0e407090a0ddefe5980f9
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/1000.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/1001.npy b/contrib/RemoteSensing/dataset/demo/images/1001.npy
new file mode 100644
index 0000000000000000000000000000000000000000..cbf6b692692cb57f0d66f6f6908361e1315e0b89
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/1001.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/1002.npy b/contrib/RemoteSensing/dataset/demo/images/1002.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d5d4a4775248299347f430575c4716511f24a808
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/1002.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/1003.npy b/contrib/RemoteSensing/dataset/demo/images/1003.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9b4c94db3368ded7f615f20e2943dbd8b9a75372
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/1003.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/1004.npy b/contrib/RemoteSensing/dataset/demo/images/1004.npy
new file mode 100644
index 0000000000000000000000000000000000000000..6b2f51dfc0893da79208cb6602baa403bd1a35ea
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/1004.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/images/1005.npy b/contrib/RemoteSensing/dataset/demo/images/1005.npy
new file mode 100644
index 0000000000000000000000000000000000000000..21198e2cbe958e96d4fbeab81f1c88026b9d2fab
Binary files /dev/null and b/contrib/RemoteSensing/dataset/demo/images/1005.npy differ
diff --git a/contrib/RemoteSensing/dataset/demo/labels.txt b/contrib/RemoteSensing/dataset/demo/labels.txt
new file mode 100644
index 0000000000000000000000000000000000000000..69548aabb6c89d4535c6567b7e1160c3ba2874ca
--- /dev/null
+++ b/contrib/RemoteSensing/dataset/demo/labels.txt
@@ -0,0 +1,2 @@
+__background__
+cloud
\ No newline at end of file
diff --git a/contrib/RemoteSensing/dataset/demo/train.txt b/contrib/RemoteSensing/dataset/demo/train.txt
new file mode 100644
index 0000000000000000000000000000000000000000..babb17608b22ecda5c38db00e11e6c4579722784
--- /dev/null
+++ b/contrib/RemoteSensing/dataset/demo/train.txt
@@ -0,0 +1,7 @@
+images/1001.npy annotations/1001.png
+images/1002.npy annotations/1002.png
+images/1005.npy annotations/1005.png
+images/0.npy annotations/0.png
+images/1003.npy annotations/1003.png
+images/1000.npy annotations/1000.png
+images/1004.npy annotations/1004.png
diff --git a/contrib/RemoteSensing/dataset/demo/val.txt b/contrib/RemoteSensing/dataset/demo/val.txt
new file mode 100644
index 0000000000000000000000000000000000000000..073dbf76d4309dfeea0b242e6eace3bc6024ba61
--- /dev/null
+++ b/contrib/RemoteSensing/dataset/demo/val.txt
@@ -0,0 +1,3 @@
+images/100.npy annotations/100.png
+images/1.npy annotations/1.png
+images/10.npy annotations/10.png
diff --git a/contrib/RemoteSensing/docs/transforms.md b/contrib/RemoteSensing/docs/transforms.md
new file mode 100644
index 0000000000000000000000000000000000000000..a35e6cd1bdcf03dc84687a6bb7a4e13c274dc572
--- /dev/null
+++ b/contrib/RemoteSensing/docs/transforms.md
@@ -0,0 +1,145 @@
+# transforms.transforms
+
+Operations applied to data used in segmentation tasks. The [Compose](#compose) class combines image preprocessing/augmentation operations into a pipeline.
+
+
+## Compose
+```python
+transforms.transforms.Compose(transforms)
+```
+Applies the given list of preprocessing/augmentation operations to the input data, in order.
+### Parameters
+* **transforms** (list): list of preprocessing/augmentation operations.
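+
+A minimal usage sketch (the channel count and the concrete ops here are illustrative, not prescribed):
+```python
+import transforms.transforms as T
+
+train_transforms = T.Compose([
+    T.RandomHorizontalFlip(0.5),
+    T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
+])
+# pass the composed pipeline to a Reader via its transforms argument
+```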
+
+
+## RandomHorizontalFlip
+```python
+transforms.transforms.RandomHorizontalFlip(prob=0.5)
+```
+Horizontally flips the image with a given probability; a training-time augmentation.
+### Parameters
+* **prob** (float): probability of a horizontal flip. Default: 0.5.
+
+
+## RandomVerticalFlip
+```python
+transforms.transforms.RandomVerticalFlip(prob=0.1)
+```
+Vertically flips the image with a given probability; a training-time augmentation.
+### Parameters
+* **prob** (float): probability of a vertical flip. Default: 0.1.
+
+
+## Resize
+```python
+transforms.transforms.Resize(target_size, interp='LINEAR')
+```
+Resizes the image.
+
+- When target_size is an int, the image is resized to [target_size, target_size]
+  with the chosen interpolation.
+- When target_size is a list or tuple, the image is resized to target_size,
+  which should be given as [w, h] or (w, h).
+### Parameters
+* **target_size** (int|list|tuple): target size.
+* **interp** (str): interpolation method, matching OpenCV's;
+one of ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4']. Default: 'LINEAR'.
+
+
+## ResizeByLong
+```python
+transforms.transforms.ResizeByLong(long_size)
+```
+Resizes the long side of the image to a fixed value and scales the short side proportionally.
+### Parameters
+* **long_size** (int): size of the long side after resizing.
+
+
+## ResizeRangeScaling
+```python
+transforms.transforms.ResizeRangeScaling(min_value=400, max_value=600)
+```
+Randomly resizes the long side of the image within a given range and scales the short side proportionally; a training-time augmentation.
+### Parameters
+* **min_value** (int): minimum size of the long side after resizing. Default: 400.
+* **max_value** (int): maximum size of the long side after resizing. Default: 600.
+
+
+## ResizeStepScaling
+```python
+transforms.transforms.ResizeStepScaling(min_scale_factor=0.75, max_scale_factor=1.25, scale_step_size=0.25)
+```
+Resizes the image by a scale factor drawn randomly from [min_scale_factor, max_scale_factor] in steps of scale_step_size; a training-time augmentation.
+### Parameters
+* **min_scale_factor** (float): minimum scale factor. Default: 0.75.
+* **max_scale_factor** (float): maximum scale factor. Default: 1.25.
+* **scale_step_size** (float): step between candidate scale factors. Default: 0.25.
+
+
+## Clip
+```python
+transforms.transforms.Clip(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])
+```
+Clips image values that fall outside the given range.
+
+### Parameters
+* **min_val** (list): lower bound; values below min_val are set to min_val. Default: [0, 0, 0].
+* **max_val** (list): upper bound; values above max_val are set to max_val. Default: [255.0, 255.0, 255.0].
+
+
+## Normalize
+```python
+transforms.transforms.Normalize(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0], mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+```
+Standardizes the image:
+
+1. Rescales pixel values to the range [0.0, 1.0] (min-max normalization).
+2. Subtracts the mean and divides by the standard deviation.
+### Parameters
+* **min_val** (list): per-channel minimum of the dataset. Default: [0, 0, 0].
+* **max_val** (list): per-channel maximum of the dataset. Default: [255.0, 255.0, 255.0].
+* **mean** (list): per-channel mean of the dataset. Default: [0.5, 0.5, 0.5].
+* **std** (list): per-channel standard deviation of the dataset. Default: [0.5, 0.5, 0.5].
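+
+For intuition, the equivalent arithmetic in numpy might look like this standalone sketch (illustrative only, not the library's code path):
+```python
+import numpy as np
+
+im = np.random.uniform(0, 255, (256, 256, 3)).astype(np.float32)  # dummy 3-channel image
+min_val = np.array([0.0] * 3)
+max_val = np.array([255.0] * 3)
+mean = np.array([0.5] * 3)
+std = np.array([0.5] * 3)
+
+im = (im - min_val) / (max_val - min_val)  # rescale to [0.0, 1.0]
+im = (im - mean) / std                     # z-score standardization
+```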
+
+
+## Padding
+```python
+transforms.transforms.Padding(target_size, im_padding_value=127.5, label_padding_value=255)
+```
+Pads the image and, if present, the annotation image on the right and bottom to the given size, using the given padding values.
+### Parameters
+* **target_size** (int|list|tuple): image size after padding.
+* **im_padding_value** (list): padding value for images. Default: 127.5.
+* **label_padding_value** (int): padding value for annotation images. Default: 255 (only needed during training).
+
+
+## RandomPaddingCrop
+```python
+transforms.transforms.RandomPaddingCrop(crop_size=512, im_padding_value=127.5, label_padding_value=255)
+```
+Randomly crops the image and annotation image; if the requested crop is larger than the original image, padding is applied first. A training-time augmentation.
+### Parameters
+* **crop_size** (int|list|tuple): crop size. Default: 512.
+* **im_padding_value** (list): padding value for images. Default: 127.5.
+* **label_padding_value** (int): padding value for annotation images. Default: 255.
+
+
+## RandomBlur
+```python
+transforms.transforms.RandomBlur(prob=0.1)
+```
+Applies Gaussian blur to the image with a given probability; a training-time augmentation.
+### Parameters
+* **prob** (float): blur probability. Default: 0.1.
+
+
+## RandomScaleAspect
+```python
+transforms.transforms.RandomScaleAspect(min_scale=0.5, aspect_ratio=0.33)
+```
+Crops the image and resizes it back to the original size; a training-time augmentation.
+
+The crop is taken according to an area ratio and an aspect ratio, then resized back to the original image size; the annotation image, when present, is processed in sync.
+### Parameters
+* **min_scale** (float): area ratio of the crop to the original image, in [0, 1]; 0 returns the original image. Default: 0.5.
+* **aspect_ratio** (float): aspect ratio range of the crop, non-negative; 0 returns the original image. Default: 0.33.
diff --git a/contrib/RemoteSensing/main.py b/contrib/RemoteSensing/main.py
deleted file mode 100644
index 346a595a931928478a19c4c69a6a9901a1587b73..0000000000000000000000000000000000000000
--- a/contrib/RemoteSensing/main.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import sys
-import os
-import os.path as osp
-import cv2
-import numpy as np
-from PIL import Image as Image
-
-#================================setting========================
-os.environ['CUDA_VISIBLE_DEVICES'] = "1"
-
-batch_size = 4
-channel = 10
-epochs = 1
-
-save_dir = 'saved_model/snow2019_unet_all_channel_vertical'
-data_dir = "../../../dataset/snow2019/all_channel_data/"
-#=============================================================
-
-sys.path.append(osp.join(os.getcwd(), '..'))
-import RemoteSensing.transforms.transforms as T
-from RemoteSensing.readers.reader import Reader
-from RemoteSensing.models import UNet, load_model
-
-if not os.path.exists(save_dir):
-    os.makedirs(save_dir)
-
-train_list = osp.join(data_dir, 'train.txt')
-val_list = osp.join(data_dir, 'val.txt')
-label_list = osp.join(data_dir, 'labels.txt')
-
-os.system('cp ./{} {}'.format(__file__, osp.join(save_dir, __file__)))
-
-train_transforms = T.Compose([
-    T.RandomVerticalFlip(0.5),
-    T.RandomHorizontalFlip(0.5),
-    T.ResizeStepScaling(0.5, 2.0, 0.25),
-    T.RandomPaddingCrop(769),
- T.Normalize(mean=[0.5] * channel, std=[0.5] * channel), -]) - -eval_transforms = T.Compose([ - T.Padding([1049, 1049]), - T.Normalize(mean=[0.5] * channel, std=[0.5] * channel), -]) - -test_transforms = T.Compose([ - T.Padding([1049, 1049]), - T.Normalize(mean=[0.5] * channel, std=[0.5] * channel), -]) - -train_reader = Reader( - data_dir=data_dir, - file_list=train_list, - label_list=label_list, - transforms=train_transforms, - num_workers=8, - buffer_size=16, - shuffle=True, - parallel_method='thread') - -eval_reader = Reader( - data_dir=data_dir, - file_list=val_list, - label_list=label_list, - transforms=eval_transforms, - num_workers=8, - buffer_size=16, - shuffle=False, - parallel_method='thread') - -model = UNet( - num_classes=2, input_channel=channel, use_bce_loss=True, use_dice_loss=True) - -model.train( - num_epochs=epochs, - train_reader=train_reader, - train_batch_size=batch_size, - eval_reader=eval_reader, - save_interval_epochs=5, - log_interval_steps=10, - save_dir=save_dir, - pretrain_weights=None, - optimizer=None, - learning_rate=0.01, -) - -# predict -model = load_model(osp.join(save_dir, 'best_model')) -pred_dir = osp.join(save_dir, 'pred') -if not osp.exists(pred_dir): - os.mkdir(pred_dir) - -color_map = [0, 0, 0, 255, 255, 255] - -with open(val_list) as f: - lines = f.readlines() - for line in lines: - img_path = line.split(' ')[0] - print('Predicting {}'.format(img_path)) - img_path_ = osp.join(data_dir, img_path) - - pred = model.predict(img_path_) - - pred_name = osp.basename(img_path).rstrip('npy') + 'png' - pred_path = osp.join(pred_dir, pred_name) - pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P') - pred_mask.putpalette(color_map) - pred_mask.save(pred_path) diff --git a/contrib/RemoteSensing/models/base.py b/contrib/RemoteSensing/models/base.py index de1dca4c1ca79adf9ee7a9528dc000ff9d8f75a3..6c3fada7bf9eeb18b132d11c24b403785c03f743 100644 --- a/contrib/RemoteSensing/models/base.py +++ b/contrib/RemoteSensing/models/base.py @@ -21,13 +21,13 @@ import math import yaml import copy import json -import functools -import RemoteSensing.utils.logging as logging -import RemoteSensing +import utils.logging as logging from collections import OrderedDict from os import path as osp -from paddle.fluid.framework import Program -from ..utils.pretrain_weights import get_pretrain_weights +from utils.pretrain_weights import get_pretrain_weights +import transforms.transforms as T +import utils +import __init__ def dict2str(dict_input): @@ -46,7 +46,7 @@ class BaseAPI: # 现有的CV模型都有这个属性,而这个属且也需要在eval时用到 self.num_classes = None self.labels = None - if RemoteSensing.env_info['place'] == 'cpu': + if __init__.env_info['place'] == 'cpu': self.places = fluid.cpu_places() else: self.places = fluid.cuda_places() @@ -73,8 +73,8 @@ class BaseAPI: else: raise Exception("Please support correct batch_size, \ which can be divided by available cards({}) in {}". 
- format(RemoteSensing.env_info['num'], - RemoteSensing.env_info['place'])) + format(__init__.env_info['num'], + __init__.env_info['place'])) def build_program(self): # 构建训练网络 @@ -93,12 +93,9 @@ class BaseAPI: def arrange_transforms(self, transforms, mode='train'): # 给transforms添加arrange操作 if transforms.transforms[-1].__class__.__name__.startswith('Arrange'): - transforms.transforms[ - -1] = RemoteSensing.transforms.transforms.ArrangeSegmenter( - mode=mode) + transforms.transforms[-1] = T.ArrangeSegmenter(mode=mode) else: - transforms.transforms.append( - RemoteSensing.transforms.transforms.ArrangeSegmenter(mode=mode)) + transforms.transforms.append(T.ArrangeSegmenter(mode=mode)) def build_train_data_loader(self, reader, batch_size): # 初始化data_loader @@ -134,8 +131,8 @@ class BaseAPI: if pretrain_weights is not None: logging.info( "Load pretrain weights from {}.".format(pretrain_weights)) - RemoteSensing.utils.utils.load_pretrain_weights( - self.exe, self.train_prog, pretrain_weights, fuse_bn) + utils.utils.load_pretrain_weights(self.exe, self.train_prog, + pretrain_weights, fuse_bn) # 进行裁剪 if sensitivities_file is not None: from .slim.prune_config import get_sensitivities @@ -211,46 +208,6 @@ class BaseAPI: open(osp.join(save_dir, '.success'), 'w').close() logging.info("Model saved in {}.".format(save_dir)) - def export_inference_model(self, save_dir): - test_input_names = [var.name for var in list(self.test_inputs.values())] - test_outputs = list(self.test_outputs.values()) - if self.__class__.__name__ == 'MaskRCNN': - from RemoteSensing.utils.save import save_mask_inference_model - save_mask_inference_model( - dirname=save_dir, - executor=self.exe, - params_filename='__params__', - feeded_var_names=test_input_names, - target_vars=test_outputs, - main_program=self.test_prog) - else: - fluid.io.save_inference_model( - dirname=save_dir, - executor=self.exe, - params_filename='__params__', - feeded_var_names=test_input_names, - target_vars=test_outputs, - main_program=self.test_prog) - model_info = self.get_model_info() - model_info['status'] = 'Infer' - - # 保存模型输出的变量描述 - model_info['_ModelInputsOutputs'] = dict() - model_info['_ModelInputsOutputs']['test_inputs'] = [ - [k, v.name] for k, v in self.test_inputs.items() - ] - model_info['_ModelInputsOutputs']['test_outputs'] = [ - [k, v.name] for k, v in self.test_outputs.items() - ] - - with open( - osp.join(save_dir, 'model.yml'), encoding='utf-8', - mode='w') as f: - yaml.dump(model_info, f) - # 模型保存成功的标志 - open(osp.join(save_dir, '.success'), 'w').close() - logging.info("Model for inference deploy saved in {}.".format(save_dir)) - def train_loop(self, num_epochs, train_reader, @@ -287,8 +244,7 @@ class BaseAPI: if self.parallel_train_prog is None: build_strategy = fluid.compiler.BuildStrategy() build_strategy.fuse_all_optimizer_ops = False - if RemoteSensing.env_info['place'] != 'cpu' and len( - self.places) > 1: + if __init__.env_info['place'] != 'cpu' and len(self.places) > 1: build_strategy.sync_batch_norm = self.sync_bn exec_strategy = fluid.ExecutionStrategy() exec_strategy.num_iteration_per_drop_scope = 1 diff --git a/contrib/RemoteSensing/models/load_model.py b/contrib/RemoteSensing/models/load_model.py index c572b77d890796c00791ba16e504b6696e889cc1..fb55c13125c7ad194196082be00fb5df7c037dd8 100644 --- a/contrib/RemoteSensing/models/load_model.py +++ b/contrib/RemoteSensing/models/load_model.py @@ -19,8 +19,8 @@ import copy from collections import OrderedDict import paddle.fluid as fluid from paddle.fluid.framework import Parameter 
-from ..utils import logging -import RemoteSensing +from utils import logging +import models def load_model(model_dir): @@ -30,12 +30,11 @@ def load_model(model_dir): info = yaml.load(f.read(), Loader=yaml.Loader) status = info['status'] - if not hasattr(RemoteSensing.models, info['Model']): - raise Exception( - "There's no attribute {} in RemoteSensing.models".format( - info['Model'])) + if not hasattr(models, info['Model']): + raise Exception("There's no attribute {} in models".format( + info['Model'])) - model = getattr(RemoteSensing.models, info['Model'])(**info['_init_params']) + model = getattr(models, info['Model'])(**info['_init_params']) if status == "Normal" or \ status == "Prune": startup_prog = fluid.Program() @@ -82,7 +81,7 @@ def load_model(model_dir): def build_transforms(transforms_info): - from ..transforms import transforms as T + from transforms import transforms as T transforms = list() for op_info in transforms_info: op_name = list(op_info.keys())[0] diff --git a/contrib/RemoteSensing/models/unet.py b/contrib/RemoteSensing/models/unet.py index 571cca3a296dbf54a566c798346daa1d8dd398af..ff57b2b7f3fc704a4dff227917fc9c35e0f6670f 100644 --- a/contrib/RemoteSensing/models/unet.py +++ b/contrib/RemoteSensing/models/unet.py @@ -18,11 +18,11 @@ import numpy as np import math import cv2 import paddle.fluid as fluid -import RemoteSensing -import RemoteSensing.utils.logging as logging +import utils.logging as logging from collections import OrderedDict from .base import BaseAPI -from ..utils.metrics import ConfusionMatrix +from utils.metrics import ConfusionMatrix +import nets class UNet(BaseAPI): @@ -90,7 +90,7 @@ class UNet(BaseAPI): self.trainable = True def build_net(self, mode='train'): - model = RemoteSensing.nets.UNet( + model = nets.UNet( self.num_classes, mode=mode, upsample_mode=self.upsample_mode, @@ -152,9 +152,9 @@ class UNet(BaseAPI): Args: num_epochs (int): 训练迭代轮数。 - train_reader (RemoteSensing.readers): 训练数据读取器。 + train_reader (readers): 训练数据读取器。 train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认2。 - eval_reader (RemoteSensing.readers): 评估数据读取器。 + eval_reader (readers): 评估数据读取器。 save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。 log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为2。 save_dir (str): 模型保存路径。默认'output'。 @@ -216,7 +216,7 @@ class UNet(BaseAPI): """评估。 Args: - eval_reader (RemoteSensing.readers): 评估数据读取器。 + eval_reader (readers): 评估数据读取器。 batch_size (int): 评估时的batch大小。默认1。 verbose (bool): 是否打印日志。默认True。 epoch_id (int): 当前评估模型所在的训练轮数。 @@ -241,6 +241,8 @@ class UNet(BaseAPI): for step, data in enumerate(data_generator()): images = np.array([d[0] for d in data]) + images = images.astype(np.float32) + labels = np.array([d[1] for d in data]) num_samples = images.shape[0] if num_samples < batch_size: @@ -283,7 +285,7 @@ class UNet(BaseAPI): """预测。 Args: img_file(str): 预测图像路径。 - transforms(RemoteSensing.transforms): 数据预处理操作。 + transforms(transforms): 数据预处理操作。 Returns: np.ndarray: 预测结果灰度图。 @@ -297,6 +299,7 @@ class UNet(BaseAPI): self.arrange_transforms( transforms=self.test_transforms, mode='test') im, im_info = self.test_transforms(im_file) + im = im.astype(np.float32) im = np.expand_dims(im, axis=0) result = self.exe.run( self.test_prog, diff --git a/contrib/RemoteSensing/predict_demo.py b/contrib/RemoteSensing/predict_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..444887b05e9178ecb93394e5833650f4cf18d99d --- /dev/null +++ b/contrib/RemoteSensing/predict_demo.py @@ -0,0 +1,53 @@ +import os +import os.path as osp +import numpy as 
np +from PIL import Image as Image +import argparse +from models import load_model + + +def parse_args(): + parser = argparse.ArgumentParser(description='RemoteSensing predict') + parser.add_argument( + '--data_dir', + dest='data_dir', + help='dataset directory', + default=None, + type=str) + parser.add_argument( + '--load_model_dir', + dest='load_model_dir', + help='model load directory', + default=None, + type=str) + return parser.parse_args() + + +args = parse_args() + +data_dir = args.data_dir +load_model_dir = args.load_model_dir + +# predict +model = load_model(osp.join(load_model_dir, 'best_model')) +pred_dir = osp.join(load_model_dir, 'pred') +if not osp.exists(pred_dir): + os.mkdir(pred_dir) + +val_list = osp.join(data_dir, 'val.txt') +color_map = [0, 0, 0, 255, 255, 255] +with open(val_list) as f: + lines = f.readlines() + for line in lines: + img_path = line.split(' ')[0] + print('Predicting {}'.format(img_path)) + img_path_ = osp.join(data_dir, img_path) + + pred = model.predict(img_path_) + + # 以伪彩色png图片保存预测结果 + pred_name = osp.basename(img_path).rstrip('npy') + 'png' + pred_path = osp.join(pred_dir, pred_name) + pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P') + pred_mask.putpalette(color_map) + pred_mask.save(pred_path) diff --git a/contrib/RemoteSensing/readers/base.py b/contrib/RemoteSensing/readers/base.py index e0517613b6bc9614a36d471aa81a35f095c7a881..1427bd60ad4637a3f13c8a08f59291f15fe5ac82 100644 --- a/contrib/RemoteSensing/readers/base.py +++ b/contrib/RemoteSensing/readers/base.py @@ -22,7 +22,7 @@ import copy import random import platform import chardet -from ..utils import logging +from utils import logging class EndSignal(): diff --git a/contrib/RemoteSensing/readers/reader.py b/contrib/RemoteSensing/readers/reader.py index d4d32a238828925e116a73931d2e6fc178ba926f..343d25b15034e1905a1e55ae926fbdfa62916cf1 100644 --- a/contrib/RemoteSensing/readers/reader.py +++ b/contrib/RemoteSensing/readers/reader.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import os.path as osp import random -from ..utils import logging +from utils import logging from .base import BaseReader from .base import get_encoding from collections import OrderedDict diff --git a/contrib/RemoteSensing/requirements.txt b/contrib/RemoteSensing/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..23b5cfdf626667c7c042e6a18eaee9154e95766b --- /dev/null +++ b/contrib/RemoteSensing/requirements.txt @@ -0,0 +1,9 @@ +pre-commit +yapf == 0.26.0 +flake8 +pyyaml >= 5.1 +Pillow +numpy +six +opencv-python +tqdm \ No newline at end of file diff --git a/contrib/RemoteSensing/tools/create_dataset_list.py b/contrib/RemoteSensing/tools/create_dataset_list.py new file mode 100644 index 0000000000000000000000000000000000000000..8fec77d55455a242f5cf20f435c9cbf3b04e50de --- /dev/null +++ b/contrib/RemoteSensing/tools/create_dataset_list.py @@ -0,0 +1,143 @@ +# coding: utf8 +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os.path +import argparse +import warnings + + +def parse_args(): + parser = argparse.ArgumentParser( + description='PaddleSeg generate file list on your customized dataset.') + parser.add_argument('dataset_root', help='dataset root directory', type=str) + parser.add_argument( + '--separator', + dest='separator', + help='file list separator', + default=" ", + type=str) + parser.add_argument( + '--folder', + help='the folder names of images and labels', + type=str, + nargs=2, + default=['images', 'annotations']) + parser.add_argument( + '--second_folder', + help= + 'the second-level folder names of train set, validation set, test set', + type=str, + nargs='*', + default=['train', 'val', 'test']) + parser.add_argument( + '--format', + help='data format of images and labels, default npy, png.', + type=str, + nargs=2, + default=['npy', 'png']) + parser.add_argument( + '--label_class', + help='label class names', + type=str, + nargs='*', + default=['__background__', '__foreground__']) + parser.add_argument( + '--postfix', + help='postfix of images or labels', + type=str, + nargs=2, + default=['', '']) + + return parser.parse_args() + + +def get_files(image_or_label, dataset_split, args): + dataset_root = args.dataset_root + postfix = args.postfix + format = args.format + folder = args.folder + + pattern = '*%s.%s' % (postfix[image_or_label], format[image_or_label]) + + search_files = os.path.join(dataset_root, folder[image_or_label], + dataset_split, pattern) + search_files2 = os.path.join(dataset_root, folder[image_or_label], + dataset_split, "*", pattern) # 包含子目录 + search_files3 = os.path.join(dataset_root, folder[image_or_label], + dataset_split, "*", "*", pattern) # 包含三级目录 + + filenames = glob.glob(search_files) + filenames2 = glob.glob(search_files2) + filenames3 = glob.glob(search_files3) + + filenames = filenames + filenames2 + filenames3 + + return sorted(filenames) + + +def generate_list(args): + dataset_root = args.dataset_root + separator = args.separator + + file_list = os.path.join(dataset_root, 'labels.txt') + with open(file_list, "w") as f: + for label_class in args.label_class: + f.write(label_class + '\n') + + for dataset_split in args.second_folder: + print("Creating {}.txt...".format(dataset_split)) + image_files = get_files(0, dataset_split, args) + label_files = get_files(1, dataset_split, args) + if not image_files: + img_dir = os.path.join(dataset_root, args.folder[0], dataset_split) + warnings.warn("No images in {} !!!".format(img_dir)) + num_images = len(image_files) + + if not label_files: + label_dir = os.path.join(dataset_root, args.folder[1], + dataset_split) + warnings.warn("No labels in {} !!!".format(label_dir)) + num_label = len(label_files) + + if num_images != num_label and num_label > 0: + raise Exception( + "Number of images = {} number of labels = {} \n" + "Either number of images is equal to number of labels, " + "or number of labels is equal to 0.\n" + "Please check your dataset!".format(num_images, num_label)) + + file_list = os.path.join(dataset_root, dataset_split + '.txt') + with open(file_list, "w") as f: + for item in range(num_images): + left = image_files[item].replace(dataset_root, '') + if left[0] == os.path.sep: + left = left.lstrip(os.path.sep) + + try: + right = label_files[item].replace(dataset_root, '') + if right[0] == os.path.sep: + right = right.lstrip(os.path.sep) + line = left + separator + right + '\n' + 
except IndexError:  # label list may be empty
+                line = left + '\n'
+
+            f.write(line)
+            print(line)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    generate_list(args)
diff --git a/contrib/RemoteSensing/train_demo.py b/contrib/RemoteSensing/train_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b70c1343f5a326b1aefb69966e6f23b64381ac7
--- /dev/null
+++ b/contrib/RemoteSensing/train_demo.py
@@ -0,0 +1,106 @@
+import os.path as osp
+import argparse
+import transforms.transforms as T
+from readers.reader import Reader
+from models import UNet
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='RemoteSensing training')
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='dataset directory',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='model save directory',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--channel',
+        dest='channel',
+        help='number of data channel',
+        default=3,
+        type=int)
+    parser.add_argument(
+        '--num_epochs',
+        dest='num_epochs',
+        help='number of training epochs',
+        default=100,
+        type=int)
+    parser.add_argument(
+        '--train_batch_size',
+        dest='train_batch_size',
+        help='training batch size',
+        default=4,
+        type=int)
+    parser.add_argument(
+        '--lr', dest='lr', help='learning rate', default=0.01, type=float)
+    return parser.parse_args()
+
+
+args = parse_args()
+
+data_dir = args.data_dir
+save_dir = args.save_dir
+channel = args.channel
+num_epochs = args.num_epochs
+train_batch_size = args.train_batch_size
+lr = args.lr
+
+# transforms for training and validation
+train_transforms = T.Compose([
+    T.RandomVerticalFlip(0.5),
+    T.RandomHorizontalFlip(0.5),
+    T.ResizeStepScaling(0.5, 2.0, 0.25),
+    T.RandomPaddingCrop(256),
+    T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
+])
+
+eval_transforms = T.Compose([
+    T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
+])
+
+train_list = osp.join(data_dir, 'train.txt')
+val_list = osp.join(data_dir, 'val.txt')
+label_list = osp.join(data_dir, 'labels.txt')
+
+# data readers
+train_reader = Reader(
+    data_dir=data_dir,
+    file_list=train_list,
+    label_list=label_list,
+    transforms=train_transforms,
+    num_workers=8,
+    buffer_size=16,
+    shuffle=True,
+    parallel_method='thread')
+
+eval_reader = Reader(
+    data_dir=data_dir,
+    file_list=val_list,
+    label_list=label_list,
+    transforms=eval_transforms,
+    num_workers=8,
+    buffer_size=16,
+    shuffle=False,
+    parallel_method='thread')
+
+model = UNet(
+    num_classes=2, input_channel=channel, use_bce_loss=True, use_dice_loss=True)
+
+model.train(
+    num_epochs=num_epochs,
+    train_reader=train_reader,
+    train_batch_size=train_batch_size,
+    eval_reader=eval_reader,
+    save_interval_epochs=5,
+    log_interval_steps=10,
+    save_dir=save_dir,
+    pretrain_weights=None,
+    optimizer=None,
+    learning_rate=lr,
+)
diff --git a/contrib/RemoteSensing/transforms/ops.py b/contrib/RemoteSensing/transforms/ops.py
index e81c414380765bd3c7fe1de364e049273aefcc9a..e04e695410e5f1e089de838526889c02cadd7da1 100644
--- a/contrib/RemoteSensing/transforms/ops.py
+++ b/contrib/RemoteSensing/transforms/ops.py
@@ -18,8 +18,12 @@ import numpy as np
 from PIL import Image, ImageEnhance
 
 
-def normalize(im, mean, std):
-    im = im.astype(np.float32, copy=False) / 255.0
+def normalize(im, min_value, max_value, mean, std):
+    # Rescaling (min-max normalization)
+    range_value = [max_value[i] - min_value[i] for i in range(len(max_value))]
+    im = (im.astype(np.float32, copy=False) - min_value) / range_value
+
+    # Standardization (Z-score Normalization)
     im -= mean
     im /= std
return im diff --git a/contrib/RemoteSensing/transforms/transforms.py b/contrib/RemoteSensing/transforms/transforms.py index 81ba9a507bfc2eb06ec553f88853ae83b4846d37..abac1746e09e8e95d4149e8243d6ea4258f347ef 100644 --- a/contrib/RemoteSensing/transforms/transforms.py +++ b/contrib/RemoteSensing/transforms/transforms.py @@ -170,7 +170,7 @@ class Resize: def __init__(self, target_size, interp='LINEAR'): self.interp = interp assert interp in self.interp_dict, "interp should be one of {}".format( - interp_dict.keys()) + self.interp_dict.keys()) if isinstance(target_size, list) or isinstance(target_size, tuple): if len(target_size) != 2: raise ValueError( @@ -271,17 +271,6 @@ class ResizeByLong: -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 """ if im_info is None: - im = np.pad( - im, - pad_width=((0, pad_height), (0, pad_width), (0, 0)), - mode='constant', - constant_values=(self.im_padding_value, self.im_padding_value)) - label = np.pad( - label, - pad_width=((0, pad_height), (0, pad_width)), - mode='constant', - constant_values=(self.label_padding_value, - self.label_padding_value)) im_info = OrderedDict() im_info['shape_before_resize'] = im.shape[:2] @@ -420,20 +409,58 @@ class ResizeStepScaling: return (im, im_info, label) +class Clip: + """ + 对图像上超出一定范围的数据进行裁剪。 + + Args: + min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值[0, 0, 0]. + max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值[255.0, 255.0, 255.0] + """ + + def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]): + self.min_val = min_val + self.max_val = max_val + + def __call__(self, im, im_info=None, label=None): + if isinstance(self.min_val, list) and isinstance(self.max_val, list): + for k in range(im.shape[2]): + np.clip( + im[:, :, k], + self.min_val[k], + self.max_val[k], + out=im[:, :, k]) + else: + raise TypeError('min_val and max_val must be list') + + if label is None: + return (im, im_info) + else: + return (im, im_info, label) + + class Normalize: """对图像进行标准化。 - 1.尺度缩放到 [0,1]。 + 1.图像像素归一化到区间 [0.0, 1.0]。 2.对图像进行减均值除以标准差操作。 Args: - mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。 - std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。 + min_val (list): 图像数据集的最小值。默认值[0, 0, 0]. + max_val (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0] + mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]. + std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]. 
Raises: ValueError: mean或std不是list对象。std包含0。 """ - def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + def __init__(self, + min_val=[0, 0, 0], + max_val=[255.0, 255.0, 255.0], + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]): + self.min_val = min_val + self.max_val = max_val self.mean = mean self.std = std if not (isinstance(self.mean, list) and isinstance(self.std, list)): @@ -457,7 +484,8 @@ class Normalize: mean = np.array(self.mean)[np.newaxis, np.newaxis, :] std = np.array(self.std)[np.newaxis, np.newaxis, :] - im = normalize(im, mean, std) + + im = normalize(im, self.min_val, self.max_val, mean, std) if label is None: return (im, im_info) @@ -471,7 +499,7 @@ class Padding: Args: target_size (int/list/tuple): padding后图像的大小。 - im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + im_padding_value (list): 图像padding的值。默认为127.5。 label_padding_value (int): 标注图像padding的值。默认值为255。 Raises: @@ -554,7 +582,7 @@ class RandomPaddingCrop: Args: crop_size(int or list or tuple): 裁剪图像大小。默认为512。 - im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 + im_padding_value (list): 图像padding的值。默认为127.5 label_padding_value (int): 标注图像padding的值。默认值为255。 Raises: @@ -684,75 +712,6 @@ class RandomBlur: return (im, im_info, label) -class RandomRotation: - """对图像进行随机旋转。 - 在不超过最大旋转角度的情况下,图像进行随机旋转,当存在标注图像时,同步进行, - 并对旋转后的图像和标注图像进行相应的padding。 - - Args: - max_rotation (float): 最大旋转角度。默认为15度。 - im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。 - label_padding_value (int): 标注图像padding的值。默认为255。 - - """ - - def __init__(self, - max_rotation=15, - im_padding_value=[127.5, 127.5, 127.5], - label_padding_value=255): - self.max_rotation = max_rotation - self.im_padding_value = im_padding_value - self.label_padding_value = label_padding_value - - def __call__(self, im, im_info=None, label=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ - if self.max_rotation > 0: - (h, w) = im.shape[:2] - do_rotation = np.random.uniform(-self.max_rotation, - self.max_rotation) - pc = (w // 2, h // 2) - r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0) - cos = np.abs(r[0, 0]) - sin = np.abs(r[0, 1]) - - nw = int((h * sin) + (w * cos)) - nh = int((h * cos) + (w * sin)) - - (cx, cy) = pc - r[0, 2] += (nw / 2) - cx - r[1, 2] += (nh / 2) - cy - dsize = (nw, nh) - im = cv2.warpAffine( - im, - r, - dsize=dsize, - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, - borderValue=self.im_padding_value) - label = cv2.warpAffine( - label, - r, - dsize=dsize, - flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, - borderValue=self.label_padding_value) - - if label is None: - return (im, im_info) - else: - return (im, im_info, label) - - class RandomScaleAspect: """裁剪并resize回原始尺寸的图像和标注图像。 按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。 @@ -813,116 +772,6 @@ class RandomScaleAspect: return (im, im_info, label) -class RandomDistort: - """对图像进行随机失真。 - - 1. 确定随机失真操作[变换明亮度、变换对比度、变换饱和度、变换色彩]的执行顺序。 - 2. 
以一定的概率执行每个随机扰动操作。 - - Args: - brightness_range (float): 明亮度因子的范围。默认为0.5。 - brightness_prob (float): 随机调整明亮度的概率。默认为0.5。 - contrast_range (float): 对比度因子的范围。默认为0.5。 - contrast_prob (float): 随机调整对比度的概率。默认为0.5。 - saturation_range (float): 饱和度因子的范围。默认为0.5。 - saturation_prob (float): 随机调整饱和度的概率。默认为0.5。 - hue_range (int): 色调因子的范围。默认为18。 - hue_prob (float): 随机调整色调的概率。默认为0.5。 - is_order (bool): 是否按照固定顺序 - [变换明亮度、变换对比度、变换饱和度、变换色彩] - 执行像素内容变换操作。默认为False。 - """ - - def __init__(self, - brightness_range=0.5, - brightness_prob=0.5, - contrast_range=0.5, - contrast_prob=0.5, - saturation_range=0.5, - saturation_prob=0.5, - hue_range=18, - hue_prob=0.5, - is_order=False): - self.brightness_range = brightness_range - self.brightness_prob = brightness_prob - self.contrast_range = contrast_range - self.contrast_prob = contrast_prob - self.saturation_range = saturation_range - self.saturation_prob = saturation_prob - self.hue_range = hue_range - self.hue_prob = hue_prob - self.is_order = is_order - - def __call__(self, im, im_info=None, label_info=None): - """ - Args: - im (np.ndarray): 图像np.ndarray数据。 - im_info (dict): 存储与图像相关的信息。 - label (np.ndarray): 标注图像np.ndarray数据。 - - Returns: - tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; - 当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、 - 存储与图像相关信息的字典和标注图像np.ndarray数据。 - """ - brightness_lower = 1 - self.brightness_range - brightness_upper = 1 + self.brightness_range - contrast_lower = 1 - self.contrast_range - contrast_upper = 1 + self.contrast_range - saturation_lower = 1 - self.saturation_range - saturation_upper = 1 + self.saturation_range - hue_lower = -self.hue_range - hue_upper = self.hue_range - ops = [brightness, contrast, saturation, hue] - if self.is_order: - prob = np.random.uniform(0, 1) - if prob < 0.5: - ops = [ - brightness, - saturation, - hue, - contrast, - ] - else: - random.shuffle(ops) - params_dict = { - 'brightness': { - 'brightness_lower': brightness_lower, - 'brightness_upper': brightness_upper - }, - 'contrast': { - 'contrast_lower': contrast_lower, - 'contrast_upper': contrast_upper - }, - 'saturation': { - 'saturation_lower': saturation_lower, - 'saturation_upper': saturation_upper - }, - 'hue': { - 'hue_lower': hue_lower, - 'hue_upper': hue_upper - } - } - prob_dict = { - 'brightness': self.brightness_prob, - 'contrast': self.contrast_prob, - 'saturation': self.saturation_prob, - 'hue': self.hue_prob - } - im = Image.fromarray(im) - for id in range(4): - params = params_dict[ops[id].__name__] - prob = prob_dict[ops[id].__name__] - params['im'] = im - if np.random.uniform(0, 1) < prob: - im = ops[id](**params) - im = np.asarray(im) - if label is None: - return (im, im_info) - else: - return (im, im_info, label) - - class ArrangeSegmenter: """获取训练/验证/预测所需的信息。 diff --git a/contrib/RemoteSensing/utils/logging.py b/contrib/RemoteSensing/utils/logging.py index 46f8e38b99f4832f4ac5596230eb32e8f45878b8..6d14b1a5df23827c2ddea2a0959801fab6e70552 100644 --- a/contrib/RemoteSensing/utils/logging.py +++ b/contrib/RemoteSensing/utils/logging.py @@ -15,7 +15,7 @@ import time import os import sys -import RemoteSensing +import __init__ levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} @@ -24,7 +24,7 @@ def log(level=2, message=""): current_time = time.time() time_array = time.localtime(current_time) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) - if RemoteSensing.log_level >= level: + if __init__.log_level >= level: print("{} [{}]\t{}".format(current_time, levels[level], 
message).encode("utf-8").decode("latin1")) sys.stdout.flush() diff --git a/contrib/RemoteSensing/utils/pretrain_weights.py b/contrib/RemoteSensing/utils/pretrain_weights.py index df104fc0a49f090162d8a03cccecf5088c05d668..e23686406897bc84e705640640bd7ee17d9d95ec 100644 --- a/contrib/RemoteSensing/utils/pretrain_weights.py +++ b/contrib/RemoteSensing/utils/pretrain_weights.py @@ -1,5 +1,3 @@ -import RemoteSensing -import os import os.path as osp