未验证 提交 41f762ea 编写于 作者: L LutaoChu 提交者: GitHub

Optimize transform methods and add documents for remote sensing (#233)

* add RemoteSensing

* optimize data clip and normalize method for remote sensing

* add README.md

* polish README.md, optimize main.py and transforms.py, add transforms.md

* add demo dataset

* add create_dataset_list tool, requirements.txt

* add channel-by-channel normalize and clip, remove RemoteSensing import
上级 8281cc6f
# 遥感分割(RemoteSensing)
遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。
PaddleSeg提供了针对遥感专题的语义分割库RemoteSensing,涵盖图像预处理、数据增强、模型训练、预测流程,帮助大家利用深度学习技术解决遥感影像分割问题。
针对遥感数据多通道、分布范围大、分布不均的特点,我们支持多通道训练预测,内置一系列多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
**Note:** 所有命令需要在`PaddleSeg/contrib/RemoteSensing/`目录下执行。
## 前置依赖
- Paddle 1.7.1+
由于图像分割模型计算开销大,推荐在GPU版本的PaddlePaddle下使用。
PaddlePaddle的安装, 请按照[官网指引](https://paddlepaddle.org.cn/install/quick)安装合适自己的版本。
- Python 3.5+
- 其他依赖安装
通过以下命令安装python包依赖,请确保至少执行过一次以下命令:
```
cd RemoteSensing
pip install -r requirements.txt
```
## 目录结构说明
```
RemoteSensing # 根目录
|-- dataset # 数据集
|-- docs # 文档
|-- models # 模型类定义模块
|-- nets # 组网模块
|-- readers # 数据读取模块
|-- tools # 工具集
|-- transforms # 数据增强模块
|-- utils # 公用模块
|-- train_demo.py # 训练demo脚本
|-- predict_demo.py # 预测demo脚本
|-- README.md # 使用手册
```
## 数据协议
数据集包含原图、标注图及相应的文件列表文件。
参考数据文件结构如下:
```
./dataset/ # 数据集根目录
|--images # 原图目录
| |--xxx1.npy
| |--...
| └--...
|
|--annotations # 标注图目录
| |--xxx1.png
| |--...
| └--...
|
|--train_list.txt # 训练文件列表文件
|
|--val_list.txt # 验证文件列表文件
|
└--labels.txt # 标签列表
```
其中,相应的文件名可根据需要自行定义。
由于遥感领域图像格式多种多样,不同传感器产生的数据格式可能不同。本分割库目前采用npy格式作为遥感数据的格式,采用png无损压缩格式作为标注图片格式。
标注图像为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增,
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)。
`train_list.txt``val_list.txt`文本以空格为分割符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示:
```
images/xxx1.npy annotations/xxx1.png
images/xxx2.npy annotations/xxx2.png
...
```
具体要求和如何生成文件列表可参考[文件列表规范](../../docs/data_prepare.md#文件列表)
`labels.txt`: 每一行为一个单独的类别,相应的行号即为类别对应的id(行号从0开始),如下所示:
```
labelA
labelB
...
```
## 快速上手
本章节在一个小数据集上展示了如何通过RemoteSensing进行训练预测。
### 1. 准备数据集
为了快速体验,我们准备了一个小型demo数据集,已位于`RemoteSensing/dataset/demo/`目录下.
对于您自己的数据集,您需要按照上述的数据协议进行格式转换,可分别使用numpy和pil库保存遥感数据和标注图片。其中numpy api示例如下:
```python
import numpy as np
# 保存遥感数据
# img类型:numpy.ndarray
np.save(save_path, img)
```
### 2. 训练代码开发
通过如下`train_demo.py`代码进行训练。
> 导入RemoteSensing api
```python
import transforms.transforms as T
from readers.reader import Reader
from models import UNet
```
> 定义训练和验证时的数据处理和增强流程, 在`train_transforms`中加入了`RandomVerticalFlip`,`RandomHorizontalFlip`等数据增强方式。
```python
train_transforms = T.Compose([
T.RandomVerticalFlip(0.5),
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(256),
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
eval_transforms = T.Compose([
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
```
> 定义数据读取器
```python
import os
import os.path as osp
train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt')
label_list = osp.join(data_dir, 'labels.txt')
train_reader = Reader(
data_dir=data_dir,
file_list=train_list,
label_list=label_list,
transforms=train_transforms,
num_workers=8,
buffer_size=16,
shuffle=True,
parallel_method='thread')
eval_reader = Reader(
data_dir=data_dir,
file_list=val_list,
label_list=label_list,
transforms=eval_transforms,
num_workers=8,
buffer_size=16,
shuffle=False,
parallel_method='thread')
```
> 模型构建
```python
model = UNet(
num_classes=2, input_channel=channel, use_bce_loss=True, use_dice_loss=True)
```
> 模型训练,并开启边训边评估
```python
model.train(
num_epochs=num_epochs,
train_reader=train_reader,
train_batch_size=train_batch_size,
eval_reader=eval_reader,
save_interval_epochs=5,
log_interval_steps=10,
save_dir=save_dir,
pretrain_weights=None,
optimizer=None,
learning_rate=lr,
)
```
### 3. 模型训练
> 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
> 在RemoteSensing目录下运行`train_demo.py`即可开始训练。
```shell script
python train_demo.py --data_dir dataset/demo/ --save_dir saved_model/unet/ --channel 3 --num_epochs 20
```
### 4. 模型预测代码开发
通过如下`predict_demo.py`代码进行预测。
> 导入RemoteSensing api
```python
from models import load_model
```
> 加载训练过程中最好的模型,设置预测结果保存路径。
```python
import os
import os.path as osp
model = load_model(osp.join(save_dir, 'best_model'))
pred_dir = osp.join(save_dir, 'pred')
if not osp.exists(pred_dir):
os.mkdir(pred_dir)
```
> 使用模型对验证集进行测试,并保存预测结果。
```python
import numpy as np
from PIL import Image as Image
val_list = osp.join(data_dir, 'val.txt')
color_map = [0, 0, 0, 255, 255, 255]
with open(val_list) as f:
lines = f.readlines()
for line in lines:
img_path = line.split(' ')[0]
print('Predicting {}'.format(img_path))
img_path_ = osp.join(data_dir, img_path)
pred = model.predict(img_path_)
# 以伪彩色png图片保存预测结果
pred_name = osp.basename(img_path).rstrip('npy') + 'png'
pred_path = osp.join(pred_dir, pred_name)
pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(pred_path)
```
### 5. 模型预测
> 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
> 在RemoteSensing目录下运行`predict_demo.py`即可开始训练。
```shell script
python predict_demo.py --data_dir dataset/demo/ --load_model_dir saved_model/unet/
```
## Api说明
您可以使用`RemoteSensing`目录下提供的api构建自己的分割代码。
- [数据处理-transforms](docs/transforms.md)
......@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import utils
from . import nets
from . import models
from . import transforms
from . import readers
from .utils.utils import get_environ_info
import utils
import nets
import models
import transforms
import readers
from utils.utils import get_environ_info
env_info = get_environ_info()
......
__background__
cloud
\ No newline at end of file
images/1001.npy annotations/1001.png
images/1002.npy annotations/1002.png
images/1005.npy annotations/1005.png
images/0.npy annotations/0.png
images/1003.npy annotations/1003.png
images/1000.npy annotations/1000.png
images/1004.npy annotations/1004.png
images/100.npy annotations/100.png
images/1.npy annotations/1.png
images/10.npy annotations/10.png
# transforms.transforms
对用于分割任务的数据进行操作。可以利用[Compose](#compose)类将图像预处理/增强操作进行组合。
## Compose类
```python
transforms.transforms.Compose(transforms)
```
根据数据预处理/数据增强列表对输入数据进行操作。
### 参数
* **transforms** (list): 数据预处理/数据增强列表。
## RandomHorizontalFlip类
```python
transforms.transforms.RandomHorizontalFlip(prob=0.5)
```
以一定的概率对图像进行水平翻转,模型训练时的数据增强操作。
### 参数
* **prob** (float): 随机水平翻转的概率。默认值为0.5。
## RandomVerticalFlip类
```python
transforms.transforms.RandomVerticalFlip(prob=0.1)
```
以一定的概率对图像进行垂直翻转,模型训练时的数据增强操作。
### 参数
* **prob** (float): 随机垂直翻转的概率。默认值为0.1。
## Resize类
```python
transforms.transforms.Resize(target_size, interp='LINEAR')
```
调整图像大小(resize)。
- 当目标大小(target_size)类型为int时,根据插值方式,
将图像resize为[target_size, target_size]。
- 当目标大小(target_size)类型为list或tuple时,根据插值方式,
将图像resize为target_size, target_size的输入应为[w, h]或(w, h)。
### 参数
* **target_size** (int|list|tuple): 目标大小
* **interp** (str): resize的插值方式,与opencv的插值方式对应,
可选的值为['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4'],默认为"LINEAR"。
## ResizeByLong类
```python
transforms.transforms.ResizeByLong(long_size)
```
对图像长边resize到固定值,短边按比例进行缩放。
### 参数
* **long_size** (int): resize后图像的长边大小。
## ResizeRangeScaling类
```python
transforms.transforms.ResizeRangeScaling(min_value=400, max_value=600)
```
对图像长边随机resize到指定范围内,短边按比例进行缩放,模型训练时的数据增强操作。
### 参数
* **min_value** (int): 图像长边resize后的最小值。默认值400。
* **max_value** (int): 图像长边resize后的最大值。默认值600。
## ResizeStepScaling类
```python
transforms.transforms.ResizeStepScaling(min_scale_factor=0.75, max_scale_factor=1.25, scale_step_size=0.25)
```
对图像按照某一个比例resize,这个比例以scale_step_size为步长,在[min_scale_factor, max_scale_factor]随机变动,模型训练时的数据增强操作。
### 参数
* **min_scale_factor**(float), resize最小尺度。默认值0.75。
* **max_scale_factor** (float), resize最大尺度。默认值1.25。
* **scale_step_size** (float), resize尺度范围间隔。默认值0.25。
## Clip类
```python
transforms.transforms.Clip(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])
```
对图像上超出一定范围的数据进行裁剪。
### 参数
* **min_var** (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值[0, 0, 0].
* **max_var** (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值[255.0, 255.0, 255.0]
## Normalize类
```python
transforms.transforms.Normalize(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0], mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
```
对图像进行标准化。
1.图像像素归一化到区间 [0.0, 1.0]。
2.对图像进行减均值除以标准差操作。
### 参数
* **min_val** (list): 图像数据集的最小值。默认值[0, 0, 0].
* **max_val** (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]
* **mean** (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
* **std** (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
## Padding类
```python
transforms.transforms.Padding(target_size, im_padding_value=127.5, label_padding_value=255)
```
对图像或标注图像进行padding,padding方向为右和下。根据提供的值对图像或标注图像进行padding操作。
### 参数
* **target_size** (int|list|tuple): padding后图像的大小。
* **im_padding_value** (list): 图像padding的值。默认为127.5
* **label_padding_value** (int): 标注图像padding的值。默认值为255(仅在训练时需要设定该参数)。
## RandomPaddingCrop类
```python
transforms.transforms.RandomPaddingCrop(crop_size=512, im_padding_value=127.5, label_padding_value=255)
```
对图像和标注图进行随机裁剪,当所需要的裁剪尺寸大于原图时,则进行padding操作,模型训练时的数据增强操作。
### 参数
* **crop_size**(int|list|tuple): 裁剪图像大小。默认为512。
* **im_padding_value** (list): 图像padding的值。默认为127.5。
* **label_padding_value** (int): 标注图像padding的值。默认值为255。
## RandomBlur类
```python
transforms.transforms.RandomBlur(prob=0.1)
```
以一定的概率对图像进行高斯模糊,模型训练时的数据增强操作。
### 参数
* **prob** (float): 图像模糊概率。默认为0.1。
## RandomScaleAspect类
```python
transforms.transforms.RandomScaleAspect(min_scale=0.5, aspect_ratio=0.33)
```
裁剪并resize回原始尺寸的图像和标注图像,模型训练时的数据增强操作。
按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
### 参数
* **min_scale** (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
* **aspect_ratio** (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
......@@ -21,13 +21,13 @@ import math
import yaml
import copy
import json
import functools
import RemoteSensing.utils.logging as logging
import RemoteSensing
import utils.logging as logging
from collections import OrderedDict
from os import path as osp
from paddle.fluid.framework import Program
from ..utils.pretrain_weights import get_pretrain_weights
from utils.pretrain_weights import get_pretrain_weights
import transforms.transforms as T
import utils
import __init__
def dict2str(dict_input):
......@@ -46,7 +46,7 @@ class BaseAPI:
# 现有的CV模型都有这个属性,而这个属且也需要在eval时用到
self.num_classes = None
self.labels = None
if RemoteSensing.env_info['place'] == 'cpu':
if __init__.env_info['place'] == 'cpu':
self.places = fluid.cpu_places()
else:
self.places = fluid.cuda_places()
......@@ -73,8 +73,8 @@ class BaseAPI:
else:
raise Exception("Please support correct batch_size, \
which can be divided by available cards({}) in {}".
format(RemoteSensing.env_info['num'],
RemoteSensing.env_info['place']))
format(__init__.env_info['num'],
__init__.env_info['place']))
def build_program(self):
# 构建训练网络
......@@ -93,12 +93,9 @@ class BaseAPI:
def arrange_transforms(self, transforms, mode='train'):
# 给transforms添加arrange操作
if transforms.transforms[-1].__class__.__name__.startswith('Arrange'):
transforms.transforms[
-1] = RemoteSensing.transforms.transforms.ArrangeSegmenter(
mode=mode)
transforms.transforms[-1] = T.ArrangeSegmenter(mode=mode)
else:
transforms.transforms.append(
RemoteSensing.transforms.transforms.ArrangeSegmenter(mode=mode))
transforms.transforms.append(T.ArrangeSegmenter(mode=mode))
def build_train_data_loader(self, reader, batch_size):
# 初始化data_loader
......@@ -134,8 +131,8 @@ class BaseAPI:
if pretrain_weights is not None:
logging.info(
"Load pretrain weights from {}.".format(pretrain_weights))
RemoteSensing.utils.utils.load_pretrain_weights(
self.exe, self.train_prog, pretrain_weights, fuse_bn)
utils.utils.load_pretrain_weights(self.exe, self.train_prog,
pretrain_weights, fuse_bn)
# 进行裁剪
if sensitivities_file is not None:
from .slim.prune_config import get_sensitivities
......@@ -211,46 +208,6 @@ class BaseAPI:
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model saved in {}.".format(save_dir))
def export_inference_model(self, save_dir):
test_input_names = [var.name for var in list(self.test_inputs.values())]
test_outputs = list(self.test_outputs.values())
if self.__class__.__name__ == 'MaskRCNN':
from RemoteSensing.utils.save import save_mask_inference_model
save_mask_inference_model(
dirname=save_dir,
executor=self.exe,
params_filename='__params__',
feeded_var_names=test_input_names,
target_vars=test_outputs,
main_program=self.test_prog)
else:
fluid.io.save_inference_model(
dirname=save_dir,
executor=self.exe,
params_filename='__params__',
feeded_var_names=test_input_names,
target_vars=test_outputs,
main_program=self.test_prog)
model_info = self.get_model_info()
model_info['status'] = 'Infer'
# 保存模型输出的变量描述
model_info['_ModelInputsOutputs'] = dict()
model_info['_ModelInputsOutputs']['test_inputs'] = [
[k, v.name] for k, v in self.test_inputs.items()
]
model_info['_ModelInputsOutputs']['test_outputs'] = [
[k, v.name] for k, v in self.test_outputs.items()
]
with open(
osp.join(save_dir, 'model.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(model_info, f)
# 模型保存成功的标志
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format(save_dir))
def train_loop(self,
num_epochs,
train_reader,
......@@ -287,8 +244,7 @@ class BaseAPI:
if self.parallel_train_prog is None:
build_strategy = fluid.compiler.BuildStrategy()
build_strategy.fuse_all_optimizer_ops = False
if RemoteSensing.env_info['place'] != 'cpu' and len(
self.places) > 1:
if __init__.env_info['place'] != 'cpu' and len(self.places) > 1:
build_strategy.sync_batch_norm = self.sync_bn
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
......
......@@ -19,8 +19,8 @@ import copy
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.framework import Parameter
from ..utils import logging
import RemoteSensing
from utils import logging
import models
def load_model(model_dir):
......@@ -30,12 +30,11 @@ def load_model(model_dir):
info = yaml.load(f.read(), Loader=yaml.Loader)
status = info['status']
if not hasattr(RemoteSensing.models, info['Model']):
raise Exception(
"There's no attribute {} in RemoteSensing.models".format(
info['Model']))
if not hasattr(models, info['Model']):
raise Exception("There's no attribute {} in models".format(
info['Model']))
model = getattr(RemoteSensing.models, info['Model'])(**info['_init_params'])
model = getattr(models, info['Model'])(**info['_init_params'])
if status == "Normal" or \
status == "Prune":
startup_prog = fluid.Program()
......@@ -82,7 +81,7 @@ def load_model(model_dir):
def build_transforms(transforms_info):
from ..transforms import transforms as T
from transforms import transforms as T
transforms = list()
for op_info in transforms_info:
op_name = list(op_info.keys())[0]
......
......@@ -18,11 +18,11 @@ import numpy as np
import math
import cv2
import paddle.fluid as fluid
import RemoteSensing
import RemoteSensing.utils.logging as logging
import utils.logging as logging
from collections import OrderedDict
from .base import BaseAPI
from ..utils.metrics import ConfusionMatrix
from utils.metrics import ConfusionMatrix
import nets
class UNet(BaseAPI):
......@@ -90,7 +90,7 @@ class UNet(BaseAPI):
self.trainable = True
def build_net(self, mode='train'):
model = RemoteSensing.nets.UNet(
model = nets.UNet(
self.num_classes,
mode=mode,
upsample_mode=self.upsample_mode,
......@@ -152,9 +152,9 @@ class UNet(BaseAPI):
Args:
num_epochs (int): 训练迭代轮数。
train_reader (RemoteSensing.readers): 训练数据读取器。
train_reader (readers): 训练数据读取器。
train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认2。
eval_reader (RemoteSensing.readers): 评估数据读取器。
eval_reader (readers): 评估数据读取器。
save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
save_dir (str): 模型保存路径。默认'output'。
......@@ -216,7 +216,7 @@ class UNet(BaseAPI):
"""评估。
Args:
eval_reader (RemoteSensing.readers): 评估数据读取器。
eval_reader (readers): 评估数据读取器。
batch_size (int): 评估时的batch大小。默认1。
verbose (bool): 是否打印日志。默认True。
epoch_id (int): 当前评估模型所在的训练轮数。
......@@ -241,6 +241,8 @@ class UNet(BaseAPI):
for step, data in enumerate(data_generator()):
images = np.array([d[0] for d in data])
images = images.astype(np.float32)
labels = np.array([d[1] for d in data])
num_samples = images.shape[0]
if num_samples < batch_size:
......@@ -283,7 +285,7 @@ class UNet(BaseAPI):
"""预测。
Args:
img_file(str): 预测图像路径。
transforms(RemoteSensing.transforms): 数据预处理操作。
transforms(transforms): 数据预处理操作。
Returns:
np.ndarray: 预测结果灰度图。
......@@ -297,6 +299,7 @@ class UNet(BaseAPI):
self.arrange_transforms(
transforms=self.test_transforms, mode='test')
im, im_info = self.test_transforms(im_file)
im = im.astype(np.float32)
im = np.expand_dims(im, axis=0)
result = self.exe.run(
self.test_prog,
......
import os
import os.path as osp
import numpy as np
from PIL import Image as Image
import argparse
from models import load_model
def parse_args():
parser = argparse.ArgumentParser(description='RemoteSensing predict')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='dataset directory',
default=None,
type=str)
parser.add_argument(
'--load_model_dir',
dest='load_model_dir',
help='model load directory',
default=None,
type=str)
return parser.parse_args()
args = parse_args()
data_dir = args.data_dir
load_model_dir = args.load_model_dir
# predict
model = load_model(osp.join(load_model_dir, 'best_model'))
pred_dir = osp.join(load_model_dir, 'pred')
if not osp.exists(pred_dir):
os.mkdir(pred_dir)
val_list = osp.join(data_dir, 'val.txt')
color_map = [0, 0, 0, 255, 255, 255]
with open(val_list) as f:
lines = f.readlines()
for line in lines:
img_path = line.split(' ')[0]
print('Predicting {}'.format(img_path))
img_path_ = osp.join(data_dir, img_path)
pred = model.predict(img_path_)
# 以伪彩色png图片保存预测结果
pred_name = osp.basename(img_path).rstrip('npy') + 'png'
pred_path = osp.join(pred_dir, pred_name)
pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(pred_path)
......@@ -22,7 +22,7 @@ import copy
import random
import platform
import chardet
from ..utils import logging
from utils import logging
class EndSignal():
......
......@@ -14,7 +14,7 @@
from __future__ import absolute_import
import os.path as osp
import random
from ..utils import logging
from utils import logging
from .base import BaseReader
from .base import get_encoding
from collections import OrderedDict
......
pre-commit
yapf == 0.26.0
flake8
pyyaml >= 5.1
Pillow
numpy
six
opencv-python
tqdm
\ No newline at end of file
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os.path
import argparse
import warnings
def parse_args():
parser = argparse.ArgumentParser(
description='PaddleSeg generate file list on your customized dataset.')
parser.add_argument('dataset_root', help='dataset root directory', type=str)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default=" ",
type=str)
parser.add_argument(
'--folder',
help='the folder names of images and labels',
type=str,
nargs=2,
default=['images', 'annotations'])
parser.add_argument(
'--second_folder',
help=
'the second-level folder names of train set, validation set, test set',
type=str,
nargs='*',
default=['train', 'val', 'test'])
parser.add_argument(
'--format',
help='data format of images and labels, default npy, png.',
type=str,
nargs=2,
default=['npy', 'png'])
parser.add_argument(
'--label_class',
help='label class names',
type=str,
nargs='*',
default=['__background__', '__foreground__'])
parser.add_argument(
'--postfix',
help='postfix of images or labels',
type=str,
nargs=2,
default=['', ''])
return parser.parse_args()
def get_files(image_or_label, dataset_split, args):
dataset_root = args.dataset_root
postfix = args.postfix
format = args.format
folder = args.folder
pattern = '*%s.%s' % (postfix[image_or_label], format[image_or_label])
search_files = os.path.join(dataset_root, folder[image_or_label],
dataset_split, pattern)
search_files2 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", pattern) # 包含子目录
search_files3 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", "*", pattern) # 包含三级目录
filenames = glob.glob(search_files)
filenames2 = glob.glob(search_files2)
filenames3 = glob.glob(search_files3)
filenames = filenames + filenames2 + filenames3
return sorted(filenames)
def generate_list(args):
dataset_root = args.dataset_root
separator = args.separator
file_list = os.path.join(dataset_root, 'labels.txt')
with open(file_list, "w") as f:
for label_class in args.label_class:
f.write(label_class + '\n')
for dataset_split in args.second_folder:
print("Creating {}.txt...".format(dataset_split))
image_files = get_files(0, dataset_split, args)
label_files = get_files(1, dataset_split, args)
if not image_files:
img_dir = os.path.join(dataset_root, args.folder[0], dataset_split)
warnings.warn("No images in {} !!!".format(img_dir))
num_images = len(image_files)
if not label_files:
label_dir = os.path.join(dataset_root, args.folder[1],
dataset_split)
warnings.warn("No labels in {} !!!".format(label_dir))
num_label = len(label_files)
if num_images != num_label and num_label > 0:
raise Exception(
"Number of images = {} number of labels = {} \n"
"Either number of images is equal to number of labels, "
"or number of labels is equal to 0.\n"
"Please check your dataset!".format(num_images, num_label))
file_list = os.path.join(dataset_root, dataset_split + '.txt')
with open(file_list, "w") as f:
for item in range(num_images):
left = image_files[item].replace(dataset_root, '')
if left[0] == os.path.sep:
left = left.lstrip(os.path.sep)
try:
right = label_files[item].replace(dataset_root, '')
if right[0] == os.path.sep:
right = right.lstrip(os.path.sep)
line = left + separator + right + '\n'
except:
line = left + '\n'
f.write(line)
print(line)
if __name__ == '__main__':
args = parse_args()
generate_list(args)
import sys
import os
import os.path as osp
import cv2
import numpy as np
from PIL import Image as Image
#================================setting========================
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
batch_size = 4
channel = 10
epochs = 1
save_dir = 'saved_model/snow2019_unet_all_channel_vertical'
data_dir = "../../../dataset/snow2019/all_channel_data/"
#=============================================================
sys.path.append(osp.join(os.getcwd(), '..'))
import RemoteSensing.transforms.transforms as T
from RemoteSensing.readers.reader import Reader
from RemoteSensing.models import UNet, load_model
if not os.path.exists(save_dir):
os.makedirs(save_dir)
train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt')
label_list = osp.join(data_dir, 'labels.txt')
os.system('cp ./{} {}'.format(__file__, osp.join(save_dir, __file__)))
import argparse
import transforms.transforms as T
from readers.reader import Reader
from models import UNet
def parse_args():
parser = argparse.ArgumentParser(description='RemoteSensing training')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='dataset directory',
default=None,
type=str)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='model save directory',
default=None,
type=str)
parser.add_argument(
'--channel',
dest='channel',
help='number of data channel',
default=3,
type=int)
parser.add_argument(
'--num_epochs',
dest='num_epochs',
help='number of traing epochs',
default=100,
type=int)
parser.add_argument(
'--train_batch_size',
dest='train_batch_size',
help='training batch size',
default=4,
type=int)
parser.add_argument(
'--lr', dest='lr', help='learning rate', default=0.01, type=float)
return parser.parse_args()
args = parse_args()
data_dir = args.data_dir
save_dir = args.save_dir
channel = args.channel
num_epochs = args.num_epochs
train_batch_size = args.train_batch_size
lr = args.lr
# 定义训练和验证时的transforms
train_transforms = T.Compose([
T.RandomVerticalFlip(0.5),
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(769),
T.RandomPaddingCrop(256),
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
eval_transforms = T.Compose([
T.Padding([1049, 1049]),
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
test_transforms = T.Compose([
T.Padding([1049, 1049]),
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt')
label_list = osp.join(data_dir, 'labels.txt')
# 定义数据读取器
train_reader = Reader(
data_dir=data_dir,
file_list=train_list,
......@@ -72,37 +93,14 @@ model = UNet(
num_classes=2, input_channel=channel, use_bce_loss=True, use_dice_loss=True)
model.train(
num_epochs=epochs,
num_epochs=num_epochs,
train_reader=train_reader,
train_batch_size=batch_size,
train_batch_size=train_batch_size,
eval_reader=eval_reader,
save_interval_epochs=5,
log_interval_steps=10,
save_dir=save_dir,
pretrain_weights=None,
optimizer=None,
learning_rate=0.01,
learning_rate=lr,
)
# predict
model = load_model(osp.join(save_dir, 'best_model'))
pred_dir = osp.join(save_dir, 'pred')
if not osp.exists(pred_dir):
os.mkdir(pred_dir)
color_map = [0, 0, 0, 255, 255, 255]
with open(val_list) as f:
lines = f.readlines()
for line in lines:
img_path = line.split(' ')[0]
print('Predicting {}'.format(img_path))
img_path_ = osp.join(data_dir, img_path)
pred = model.predict(img_path_)
pred_name = osp.basename(img_path).rstrip('npy') + 'png'
pred_path = osp.join(pred_dir, pred_name)
pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(pred_path)
......@@ -18,8 +18,12 @@ import numpy as np
from PIL import Image, ImageEnhance
def normalize(im, mean, std):
im = im.astype(np.float32, copy=False) / 255.0
def normalize(im, min_value, max_value, mean, std):
# Rescaling (min-max normalization)
range_value = [max_value[i] - min_value[i] for i in range(len(max_value))]
im = (im.astype(np.float32, copy=False) - min_value) / range_value
# Standardization (Z-score Normalization)
im -= mean
im /= std
return im
......
......@@ -170,7 +170,7 @@ class Resize:
def __init__(self, target_size, interp='LINEAR'):
self.interp = interp
assert interp in self.interp_dict, "interp should be one of {}".format(
interp_dict.keys())
self.interp_dict.keys())
if isinstance(target_size, list) or isinstance(target_size, tuple):
if len(target_size) != 2:
raise ValueError(
......@@ -271,17 +271,6 @@ class ResizeByLong:
-shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
"""
if im_info is None:
im = np.pad(
im,
pad_width=((0, pad_height), (0, pad_width), (0, 0)),
mode='constant',
constant_values=(self.im_padding_value, self.im_padding_value))
label = np.pad(
label,
pad_width=((0, pad_height), (0, pad_width)),
mode='constant',
constant_values=(self.label_padding_value,
self.label_padding_value))
im_info = OrderedDict()
im_info['shape_before_resize'] = im.shape[:2]
......@@ -420,20 +409,58 @@ class ResizeStepScaling:
return (im, im_info, label)
class Clip:
"""
对图像上超出一定范围的数据进行裁剪。
Args:
min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值[0, 0, 0].
max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值[255.0, 255.0, 255.0]
"""
def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
self.min_val = min_val
self.max_val = max_val
def __call__(self, im, im_info=None, label=None):
if isinstance(self.min_val, list) and isinstance(self.max_val, list):
for k in range(im.shape[2]):
np.clip(
im[:, :, k],
self.min_val[k],
self.max_val[k],
out=im[:, :, k])
else:
raise TypeError('min_val and max_val must be list')
if label is None:
return (im, im_info)
else:
return (im, im_info, label)
class Normalize:
"""对图像进行标准化。
1.尺度缩放到 [0,1]。
1.图像像素归一化到区间 [0.0, 1.0]。
2.对图像进行减均值除以标准差操作。
Args:
mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
min_val (list): 图像数据集的最小值。默认值[0, 0, 0].
max_val (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]
mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5].
std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5].
Raises:
ValueError: mean或std不是list对象。std包含0。
"""
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
def __init__(self,
min_val=[0, 0, 0],
max_val=[255.0, 255.0, 255.0],
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]):
self.min_val = min_val
self.max_val = max_val
self.mean = mean
self.std = std
if not (isinstance(self.mean, list) and isinstance(self.std, list)):
......@@ -457,7 +484,8 @@ class Normalize:
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im = normalize(im, mean, std)
im = normalize(im, self.min_val, self.max_val, mean, std)
if label is None:
return (im, im_info)
......@@ -471,7 +499,7 @@ class Padding:
Args:
target_size (int/list/tuple): padding后图像的大小。
im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]
im_padding_value (list): 图像padding的值。默认为127.5
label_padding_value (int): 标注图像padding的值。默认值为255。
Raises:
......@@ -554,7 +582,7 @@ class RandomPaddingCrop:
Args:
crop_size(int or list or tuple): 裁剪图像大小。默认为512。
im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
im_padding_value (list): 图像padding的值。默认为127.5
label_padding_value (int): 标注图像padding的值。默认值为255。
Raises:
......@@ -684,75 +712,6 @@ class RandomBlur:
return (im, im_info, label)
class RandomRotation:
"""对图像进行随机旋转。
在不超过最大旋转角度的情况下,图像进行随机旋转,当存在标注图像时,同步进行,
并对旋转后的图像和标注图像进行相应的padding。
Args:
max_rotation (float): 最大旋转角度。默认为15度。
im_padding_value (list): 图像padding的值。默认为[127.5, 127.5, 127.5]。
label_padding_value (int): 标注图像padding的值。默认为255。
"""
def __init__(self,
max_rotation=15,
im_padding_value=[127.5, 127.5, 127.5],
label_padding_value=255):
self.max_rotation = max_rotation
self.im_padding_value = im_padding_value
self.label_padding_value = label_padding_value
def __call__(self, im, im_info=None, label=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
im_info (dict): 存储与图像相关的信息。
label (np.ndarray): 标注图像np.ndarray数据。
Returns:
tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
存储与图像相关信息的字典和标注图像np.ndarray数据。
"""
if self.max_rotation > 0:
(h, w) = im.shape[:2]
do_rotation = np.random.uniform(-self.max_rotation,
self.max_rotation)
pc = (w // 2, h // 2)
r = cv2.getRotationMatrix2D(pc, do_rotation, 1.0)
cos = np.abs(r[0, 0])
sin = np.abs(r[0, 1])
nw = int((h * sin) + (w * cos))
nh = int((h * cos) + (w * sin))
(cx, cy) = pc
r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy
dsize = (nw, nh)
im = cv2.warpAffine(
im,
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value)
label = cv2.warpAffine(
label,
r,
dsize=dsize,
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.label_padding_value)
if label is None:
return (im, im_info)
else:
return (im, im_info, label)
class RandomScaleAspect:
"""裁剪并resize回原始尺寸的图像和标注图像。
按照一定的面积比和宽高比对图像进行裁剪,并reszie回原始图像的图像,当存在标注图时,同步进行。
......@@ -813,116 +772,6 @@ class RandomScaleAspect:
return (im, im_info, label)
class RandomDistort:
"""对图像进行随机失真。
1. 确定随机失真操作[变换明亮度、变换对比度、变换饱和度、变换色彩]的执行顺序。
2. 以一定的概率执行每个随机扰动操作。
Args:
brightness_range (float): 明亮度因子的范围。默认为0.5。
brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
contrast_range (float): 对比度因子的范围。默认为0.5。
contrast_prob (float): 随机调整对比度的概率。默认为0.5。
saturation_range (float): 饱和度因子的范围。默认为0.5。
saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
hue_range (int): 色调因子的范围。默认为18。
hue_prob (float): 随机调整色调的概率。默认为0.5。
is_order (bool): 是否按照固定顺序
[变换明亮度、变换对比度、变换饱和度、变换色彩]
执行像素内容变换操作。默认为False。
"""
def __init__(self,
brightness_range=0.5,
brightness_prob=0.5,
contrast_range=0.5,
contrast_prob=0.5,
saturation_range=0.5,
saturation_prob=0.5,
hue_range=18,
hue_prob=0.5,
is_order=False):
self.brightness_range = brightness_range
self.brightness_prob = brightness_prob
self.contrast_range = contrast_range
self.contrast_prob = contrast_prob
self.saturation_range = saturation_range
self.saturation_prob = saturation_prob
self.hue_range = hue_range
self.hue_prob = hue_prob
self.is_order = is_order
def __call__(self, im, im_info=None, label_info=None):
"""
Args:
im (np.ndarray): 图像np.ndarray数据。
im_info (dict): 存储与图像相关的信息。
label (np.ndarray): 标注图像np.ndarray数据。
Returns:
tuple: 当label为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
当label不为空时,返回的tuple为(im, im_info, label),分别对应图像np.ndarray数据、
存储与图像相关信息的字典和标注图像np.ndarray数据。
"""
brightness_lower = 1 - self.brightness_range
brightness_upper = 1 + self.brightness_range
contrast_lower = 1 - self.contrast_range
contrast_upper = 1 + self.contrast_range
saturation_lower = 1 - self.saturation_range
saturation_upper = 1 + self.saturation_range
hue_lower = -self.hue_range
hue_upper = self.hue_range
ops = [brightness, contrast, saturation, hue]
if self.is_order:
prob = np.random.uniform(0, 1)
if prob < 0.5:
ops = [
brightness,
saturation,
hue,
contrast,
]
else:
random.shuffle(ops)
params_dict = {
'brightness': {
'brightness_lower': brightness_lower,
'brightness_upper': brightness_upper
},
'contrast': {
'contrast_lower': contrast_lower,
'contrast_upper': contrast_upper
},
'saturation': {
'saturation_lower': saturation_lower,
'saturation_upper': saturation_upper
},
'hue': {
'hue_lower': hue_lower,
'hue_upper': hue_upper
}
}
prob_dict = {
'brightness': self.brightness_prob,
'contrast': self.contrast_prob,
'saturation': self.saturation_prob,
'hue': self.hue_prob
}
im = Image.fromarray(im)
for id in range(4):
params = params_dict[ops[id].__name__]
prob = prob_dict[ops[id].__name__]
params['im'] = im
if np.random.uniform(0, 1) < prob:
im = ops[id](**params)
im = np.asarray(im)
if label is None:
return (im, im_info)
else:
return (im, im_info, label)
class ArrangeSegmenter:
"""获取训练/验证/预测所需的信息。
......
......@@ -15,7 +15,7 @@
import time
import os
import sys
import RemoteSensing
import __init__
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
......@@ -24,7 +24,7 @@ def log(level=2, message=""):
current_time = time.time()
time_array = time.localtime(current_time)
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
if RemoteSensing.log_level >= level:
if __init__.log_level >= level:
print("{} [{}]\t{}".format(current_time, levels[level],
message).encode("utf-8").decode("latin1"))
sys.stdout.flush()
......
import RemoteSensing
import os
import os.path as osp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册