提交 570dea4f 编写于 作者: C chenguowei01

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg into develop

......@@ -10,7 +10,7 @@
PaddleSeg是基于[PaddlePaddle](https://www.paddlepaddle.org.cn)开发的端到端图像分割开发套件,覆盖了DeepLabv3+, U-Net, ICNet, PSPNet, HRNet, Fast-SCNN等主流分割网络。通过模块化的设计,以配置化方式驱动模型组合,帮助开发者更便捷地完成从训练到部署的全流程图像分割应用。
- [特点](#特点)
- [特点](#特点)
- [安装](#安装)
- [使用教程](#使用教程)
- [快速入门](#快速入门)
......@@ -31,7 +31,7 @@ PaddleSeg是基于[PaddlePaddle](https://www.paddlepaddle.org.cn)开发的端到
- **模块化设计**
支持U-Net, DeepLabv3+, ICNet, PSPNet, HRNet, Fast-SCNN六种主流分割网络,结合预训练模型和可调节的骨干网络,满足不同性能和精度的要求;选择不同的损失函数如Dice Loss, BCE Loss等方式可以强化小目标和不均衡样本场景下的分割精度。
支持U-Net, DeepLabv3+, ICNet, PSPNet, HRNet, Fast-SCNN六种主流分割网络,结合预训练模型和可调节的骨干网络,满足不同性能和精度的要求;选择不同的损失函数如Dice Loss, Lovasz Loss等方式可以强化小目标和不均衡样本场景下的分割精度。
- **高性能**
......@@ -107,8 +107,8 @@ pip install -r requirements.txt
### 高级功能
* [PaddleSeg的数据增强](./docs/data_aug.md)
* [如何解决二分类中类别不均衡问题](./docs/loss_select.md)
* [特色垂类模型使用](./contrib)
* [PaddleSeg的loss选择](./docs/loss_select.md)
* [PaddleSeg产业实践](./contrib)
* [多进程训练和混合精度训练](./docs/multiple_gpus_train_and_mixed_precision_train.md)
* 使用PaddleSlim进行分割模型压缩([量化](./slim/quantization/README.md), [蒸馏](./slim/distillation/README.md), [剪枝](./slim/prune/README.md), [搜索](./slim/nas/README.md))
## 在线体验
......@@ -162,15 +162,15 @@ A: 降低Batch size,使用Group Norm策略;请注意训练过程中当`DEFAU
* 新增[气象遥感分割方案](./contrib/RemoteSensing),支持积雪识别、云检测等气象遥感场景。
* 新增[Lovasz Loss](docs/lovasz_loss.md),解决数据类别不均衡问题。
* 使用VisualDL 2.0作为训练可视化工具
* 2020.02.25
**`v0.4.0`**
* 新增适用于实时场景且不需要预训练模型的分割网络Fast-SCNN,提供基于Cityscapes的[预训练模型](./docs/model_zoo.md)1个
* 新增LaneNet车道线检测网络,提供[预训练模型](https://github.com/PaddlePaddle/PaddleSeg/tree/release/v0.4.0/contrib/LaneNet#%E4%B8%83-%E5%8F%AF%E8%A7%86%E5%8C%96)一个
* 新增基于PaddleSlim的分割库压缩策略([量化](./slim/quantization/README.md), [蒸馏](./slim/distillation/README.md), [剪枝](./slim/prune/README.md), [搜索](./slim/nas/README.md))
* 2019.12.15
**`v0.3.0`**
......@@ -182,7 +182,7 @@ A: 降低Batch size,使用Group Norm策略;请注意训练过程中当`DEFAU
* 新增Paddle-Lite移动端部署方案,支持人像分割模型的移动端部署。
* 新增不同分割模型的预测[性能数据Benchmark](./deploy/python/docs/PaddleSeg_Infer_Benchmark.md), 便于开发者提供模型选型性能参考。
* 2019.11.04
**`v0.2.0`**
......
......@@ -70,10 +70,30 @@ python video_infer.py --model_dir pretrained_weights/humanseg_lite_inference --v
<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/result.gif" width="20%" height="20%">
根据所选背景进行背景替换,背景可以是一张图片,也可以是一段视频。
```bash
# 通过电脑摄像头进行实时背景替换处理, 也可通过'--background_video_path'传入背景视频
python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --background_image_path data/background.jpg
# 对人像视频进行背景替换处理, 也可通过'--background_video_path'传入背景视频
python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --video_path data/video_test.mp4 --background_image_path data/background.jpg
# 对单张图像进行背景替换
python bg_replace.py --model_dir pretrained_weights/humanseg_lite_inference --image_path data/human_image.jpg --background_image_path data/background.jpg
```
背景替换结果如下:
<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/bg_replace.gif" width="20%" height="20%">
**NOTE**:
视频分割处理时间需要几分钟,请耐心等待。
提供的模型适用于手机摄像头竖屏拍摄场景,宽屏效果会略差一些。
## 训练
使用下述命令基于与训练模型进行Fine-tuning,请确保选用的模型结构`model_type`与模型参数`pretrained_weights`匹配。
```bash
......@@ -122,11 +142,12 @@ python val.py --model_dir output/best_model \
* `--image_shape`: 网络输入图像大小(w, h)
## 预测
使用下述命令进行预测
使用下述命令进行预测, 预测结果默认保存在`./output/result/`文件夹中。
```bash
python infer.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--test_list data/mini_supervisely/test.txt \
--save_dir output/result \
--image_shape 192 192
```
其中参数含义如下:
......
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
import cv2
import numpy as np
from utils.humanseg_postprocess import postprocess, threshold_mask
import models
import transforms
def parse_args():
parser = argparse.ArgumentParser(description='HumanSeg inference for video')
parser.add_argument(
'--model_dir',
dest='model_dir',
help='Model path for inference',
type=str)
parser.add_argument(
'--image_path',
dest='image_path',
help='Image including human',
type=str,
default=None)
parser.add_argument(
'--background_image_path',
dest='background_image_path',
help='Background image for replacing',
type=str,
default=None)
parser.add_argument(
'--video_path',
dest='video_path',
help='Video path for inference',
type=str,
default=None)
parser.add_argument(
'--background_video_path',
dest='background_video_path',
help='Background video path for replacing',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the inference results',
type=str,
default='./output')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
def predict(img, model, test_transforms):
model.arrange_transform(transforms=test_transforms, mode='test')
img, im_info = test_transforms(img)
img = np.expand_dims(img, axis=0)
result = model.exe.run(
model.test_prog,
feed={'image': img},
fetch_list=list(model.test_outputs.values()))
score_map = result[1]
score_map = np.squeeze(score_map, axis=0)
score_map = np.transpose(score_map, (1, 2, 0))
return score_map, im_info
def recover(img, im_info):
keys = list(im_info.keys())
for k in keys[::-1]:
if k == 'shape_before_resize':
h, w = im_info[k][0], im_info[k][1]
img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
elif k == 'shape_before_padding':
h, w = im_info[k][0], im_info[k][1]
img = img[0:h, 0:w]
return img
def bg_replace(score_map, img, bg):
h, w, _ = img.shape
bg = cv2.resize(bg, (w, h))
score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
return comb
def infer(args):
resize_h = args.image_shape[1]
resize_w = args.image_shape[0]
test_transforms = transforms.Compose(
[transforms.Resize((resize_w, resize_h)),
transforms.Normalize()])
model = models.load_model(args.model_dir)
if not osp.exists(args.save_dir):
os.makedirs(args.save_dir)
# 图像背景替换
if args.image_path is not None:
if not osp.exists(args.image_path):
raise ('The --image_path is not existed: {}'.format(
args.image_path))
if args.background_image_path is None:
raise ('The --background_image_path is not set. Please set it')
else:
if not osp.exists(args.background_image_path):
raise ('The --background_image_path is not existed: {}'.format(
args.background_image_path))
img = cv2.imread(args.image_path)
score_map, im_info = predict(img, model, test_transforms)
score_map = score_map[:, :, 1]
score_map = recover(score_map, im_info)
bg = cv2.imread(args.background_image_path)
save_name = osp.basename(args.image_path)
save_path = osp.join(args.save_dir, save_name)
result = bg_replace(score_map, img, bg)
cv2.imwrite(save_path, result)
# 视频背景替换,如果提供背景视频则以背景视频作为背景,否则采用提供的背景图片
else:
is_video_bg = False
if args.background_video_path is not None:
if not osp.exists(args.background_video_path):
raise ('The --background_video_path is not existed: {}'.format(
args.background_video_path))
is_video_bg = True
elif args.background_image_path is not None:
if not osp.exists(args.background_image_path):
raise ('The --background_image_path is not existed: {}'.format(
args.background_image_path))
else:
raise (
'Please offer backgound image or video. You should set --backbground_iamge_paht or --background_video_path'
)
disflow = cv2.DISOpticalFlow_create(
cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((resize_h, resize_w), np.uint8)
prev_cfd = np.zeros((resize_h, resize_w), np.float32)
is_init = True
if args.video_path is not None:
print('Please waite. It is computing......')
if not osp.exists(args.video_path):
raise ('The --video_path is not existed: {}'.format(
args.video_path))
cap_video = cv2.VideoCapture(args.video_path)
fps = cap_video.get(cv2.CAP_PROP_FPS)
width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
save_name = osp.basename(args.video_path)
save_name = save_name.split('.')[0]
save_path = osp.join(args.save_dir, save_name + '.avi')
cap_out = cv2.VideoWriter(
save_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
(width, height))
if is_video_bg:
cap_bg = cv2.VideoCapture(args.background_video_path)
frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
current_frame_bg = 1
else:
img_bg = cv2.imread(args.background_image_path)
while cap_video.isOpened():
ret, frame = cap_video.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
score_map = recover(optflow_map, im_info)
#循环读取背景帧
if is_video_bg:
ret_bg, frame_bg = cap_bg.read()
if ret_bg:
if current_frame_bg == frames_bg:
current_frame_bg = 1
cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
else:
break
current_frame_bg += 1
comb = bg_replace(score_map, frame, frame_bg)
else:
comb = bg_replace(score_map, frame, img_bg)
cap_out.write(comb)
else:
break
if is_video_bg:
cap_bg.release()
cap_video.release()
cap_out.release()
# 当没有输入预测图像和视频的时候,则打开摄像头
else:
cap_video = cv2.VideoCapture(0)
if not cap_video.isOpened():
raise IOError("Error opening video stream or file, "
"--video_path whether existing: {}"
" or camera whether working".format(
args.video_path))
return
if is_video_bg:
cap_bg = cv2.VideoCapture(args.background_video_path)
frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
current_frame_bg = 1
else:
img_bg = cv2.imread(args.background_image_path)
while cap_video.isOpened():
ret, frame = cap_video.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
score_map = recover(optflow_map, im_info)
#循环读取背景帧
if is_video_bg:
ret_bg, frame_bg = cap_bg.read()
if ret_bg:
if current_frame_bg == frames_bg:
current_frame_bg = 1
cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
else:
break
current_frame_bg += 1
comb = bg_replace(score_map, frame, frame_bg)
else:
comb = bg_replace(score_map, frame, img_bg)
cv2.imshow('HumanSegmentation', comb)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
if is_video_bg:
cap_bg.release()
cap_video.release()
if __name__ == "__main__":
args = parse_args()
infer(args)
......@@ -14,13 +14,6 @@
# limitations under the License.
import numpy as np
import cv2
import os
def get_round(data):
round = 0.5 if data >= 0 else -0.5
return (int)(data + round)
def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
......@@ -41,26 +34,28 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
is_track = np.zeros_like(pre_gray)
flow_fw = disflow.calc(pre_gray, cur_gray, None)
flow_bw = disflow.calc(cur_gray, pre_gray, None)
for r in range(h):
for c in range(w):
fxy_fw = flow_fw[r, c]
dx_fw = get_round(fxy_fw[0])
cur_x = dx_fw + c
dy_fw = get_round(fxy_fw[1])
cur_y = dy_fw + r
if cur_x < 0 or cur_x >= w or cur_y < 0 or cur_y >= h:
continue
fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1])
if ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
(dx_fw + dx_bw) * (dx_fw + dx_bw)) >= check_thres:
continue
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1
track_cfd[cur_y, cur_x] = prev_cfd[r, c]
flow_fw = np.round(flow_fw).astype(np.int)
flow_bw = np.round(flow_bw).astype(np.int)
y_list = np.array(range(h))
x_list = np.array(range(w))
yv, xv = np.meshgrid(y_list, x_list)
yv, xv = yv.T, xv.T
cur_x = xv + flow_fw[:, :, 0]
cur_y = yv + flow_fw[:, :, 1]
# 超出边界不跟踪
not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) +
np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])) >= check_thres
track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
is_track[cur_y[~not_track], cur_x[~not_track]] = 1
not_flow = np.all(
np.abs(flow_fw) == 0, axis=-1) * np.all(
np.abs(flow_bw) == 0, axis=-1)
dl_weights[cur_y[not_flow], cur_x[not_flow]] = 0.05
return track_cfd, is_track, dl_weights
......@@ -75,24 +70,27 @@ def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
cur_cfd: 光流跟踪图和人像分割结果融合图
"""
fusion_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0)
for i in range(len(idxs[0])):
x, y = idxs[0][i], idxs[1][i]
dl_score = dl_cfd[x, y]
track_score = track_cfd[x, y]
fusion_cfd[x, y] = dl_weights[x, y] * dl_score + (
1 - dl_weights[x, y]) * track_score
if dl_score > 0.9 or dl_score < 0.1:
if dl_weights[x, y] < 0.1:
fusion_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
else:
fusion_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
else:
fusion_cfd[x, y] = dl_weights[x, y] * dl_score + (
1 - dl_weights[x, y]) * track_score
is_track = is_track.astype(np.bool)
fusion_cfd[is_track] = dl_weights[is_track] * dl_cfd[is_track] + (
1 - dl_weights[is_track]) * track_cfd[is_track]
# 确定区域
index_certain = ((dl_cfd > 0.9) + (dl_cfd < 0.1)) * is_track
index_less01 = (dl_weights < 0.1) * index_certain
fusion_cfd[index_less01] = 0.3 * dl_cfd[index_less01] + 0.7 * track_cfd[
index_less01]
index_larger09 = (dl_weights >= 0.1) * index_certain
fusion_cfd[index_larger09] = 0.4 * dl_cfd[index_larger09] + 0.6 * track_cfd[
index_larger09]
return fusion_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
"""光流优化
Args:
......@@ -105,13 +103,10 @@ def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
Returns:
fusion_cfd : 光流追踪图和预测结果融合图
"""
height, width = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
h, w = scoremap.shape
cur_cfd = scoremap.copy()
if is_init:
is_init = False
if h <= 64 or w <= 64:
disflow.setFinestScale(1)
elif h <= 160 or w <= 160:
......@@ -120,18 +115,9 @@ def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
weights = np.ones((w, h), np.float32) * 0.3
weights = np.ones((h, w), np.float32) * 0.3
track_cfd, is_track, weights = human_seg_tracking(
prev_gray, cur_gray, pre_cfd, weights, disflow)
fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track)
fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
return fusion_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
......@@ -109,7 +109,7 @@ def video_infer(args):
fps = cap.get(cv2.CAP_PROP_FPS)
if args.video_path:
print('Please waite. It is computing......')
# 用于保存预测结果视频
if not osp.exists(args.save_dir):
os.makedirs(args.save_dir)
......@@ -123,8 +123,8 @@ def video_infer(args):
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
scoremap = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
......@@ -132,10 +132,11 @@ def video_infer(args):
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
img_mat = recover(img_mat, im_info)
bg_im = np.ones_like(img_mat) * 255
comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
img_matting = np.repeat(
optflow_map[:, :, np.newaxis], 3, axis=2)
img_matting = recover(img_matting, im_info)
bg_im = np.ones_like(img_matting) * 255
comb = (img_matting * frame + (1 - img_matting) * bg_im).astype(
np.uint8)
out.write(comb)
else:
......@@ -150,20 +151,20 @@ def video_infer(args):
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
scoremap = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, scoremap, prev_gray, prev_cfd, \
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
# optflow_map = optflow_map/255.0
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
img_mat = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
img_mat = recover(img_mat, im_info)
bg_im = np.ones_like(img_mat) * 255
comb = (img_mat * frame + (1 - img_mat) * bg_im).astype(
img_matting = np.repeat(
optflow_map[:, :, np.newaxis], 3, axis=2)
img_matting = recover(img_matting, im_info)
bg_im = np.ones_like(img_matting) * 255
comb = (img_matting * frame + (1 - img_matting) * bg_im).astype(
np.uint8)
cv2.imshow('HumanSegmentation', comb)
if cv2.waitKey(1) & 0xFF == ord('q'):
......
......@@ -48,9 +48,6 @@ endif()
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/include")
include_directories("${PADDLE_DIR}/third_party/install/snappystream/include")
endif()
include_directories("${PADDLE_DIR}/third_party/install/zlib/include")
include_directories("${PADDLE_DIR}/third_party/boost")
include_directories("${PADDLE_DIR}/third_party/eigen3")
if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
link_directories("${PADDLE_DIR}/third_party/install/snappy/lib")
......@@ -59,7 +56,6 @@ if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
link_directories("${PADDLE_DIR}/third_party/install/snappystream/lib")
endif()
link_directories("${PADDLE_DIR}/third_party/install/zlib/lib")
link_directories("${PADDLE_DIR}/third_party/install/protobuf/lib")
link_directories("${PADDLE_DIR}/third_party/install/glog/lib")
link_directories("${PADDLE_DIR}/third_party/install/gflags/lib")
......@@ -171,7 +167,7 @@ if (NOT WIN32)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf yaml-cpp z xxhash
glog gflags protobuf yaml-cpp xxhash
${EXTERNAL_LIB})
if(EXISTS "${PADDLE_DIR}/third_party/install/snappystream/lib")
set(DEPS ${DEPS} snappystream)
......@@ -182,7 +178,7 @@ if (NOT WIN32)
else()
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
opencv_world346 glog libyaml-cppmt gflags_static libprotobuf zlibstatic xxhash ${EXTERNAL_LIB})
opencv_world346 glog libyaml-cppmt gflags_static libprotobuf xxhash ${EXTERNAL_LIB})
set(DEPS ${DEPS} libcmt shlwapi)
if (EXISTS "${PADDLE_DIR}/third_party/install/snappy/lib")
set(DEPS ${DEPS} snappy)
......
......@@ -107,7 +107,6 @@ class DeployConfig:
self.use_pr = deploy_conf["USE_PR"]
class ImageReader:
def __init__(self, configs):
self.config = configs
......@@ -117,19 +116,18 @@ class ImageReader:
def process_worker(self, imgs, idx, use_pr=False):
image_path = imgs[idx]
im = cv2.imread(image_path, -1)
channels = im.shape[2]
ori_h = im.shape[0]
ori_w = im.shape[1]
if channels == 1:
if len(im.shape) == 2:
im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
channels = im.shape[2]
channels = im.shape[2]
if channels != 3 and channels != 4:
print("Only support rgb(gray) or rgba image.")
return -1
ori_h = im.shape[0]
ori_w = im.shape[1]
# resize to eval_crop_size
eval_crop_size = self.config.eval_crop_size
if (ori_h != eval_crop_size[0] or ori_w != eval_crop_size[1]):
if (ori_h != eval_crop_size[1] or ori_w != eval_crop_size[0]):
im = cv2.resize(
im, eval_crop_size, fx=0, fy=0, interpolation=cv2.INTER_LINEAR)
......
# Dice loss
对于二类图像分割任务中,经常出现类别分布不均匀的情况,例如:工业产品的瑕疵检测、道路提取及病变区域提取等。我们可使用dice loss(dice coefficient loss)解决这个问题。
注:dice loss和bce loss仅支持二分类。
## 原理介绍
Dice loss的定义如下:
<p align="center">
<img src="./imgs/dice.png" hspace='10' height="46" width="200"/> <br />
</p>
其中 Y 表示ground truth,P 表示预测结果。| |表示矩阵元素之和。![](./imgs/dice2.png) 表示*Y**P*的共有元素数,
实际通过求两者的逐像素乘积之和进行计算。例如:
<p align="center">
<img src="./imgs/dice3.png" hspace='10' /> <br />
</p>
其中 1 表示前景,0 表示背景。
**Note:** 在标注图片中,务必保证前景像素值为1,背景像素值为0.
Dice系数请参见[维基百科](https://zh.wikipedia.org/wiki/Dice%E7%B3%BB%E6%95%B0)
**为什么在类别不均衡问题上,dice loss效果比softmax loss更好?**
首先来看softmax loss的定义:
<p align="center">
<img src="./imgs/softmax_loss.png" height="130" /> <br />
</p>
其中 y 表示ground truth,p 表示网络输出。
在图像分割中,`softmax loss`评估每一个像素点的类别预测,然后平均所有的像素点。这个本质上就是对图片上的每个像素进行平等的学习。这就造成了一个问题,如果在图像上的多种类别有不平衡的表征,那么训练会由最主流的类别主导。以上面DeepGlobe道路提取的数据为例子,网络将偏向于背景的学习,降低了网络对前景目标的提取能力。
`dice loss(dice coefficient loss)`通过预测和标注的交集除以它们的总体像素进行计算,它将一个类别的所有像素作为一个整体作为考量,而且计算交集在总体中的占比,所以不受大量背景像素的影响,能够取得更好的效果。
在实际应用中`dice loss`往往与`bce loss(binary cross entroy loss)`结合使用,提高模型训练的稳定性。
## PaddleSeg指定训练loss
PaddleSeg通过`cfg.SOLVER.LOSS`参数可以选择训练时的损失函数,
`cfg.SOLVER.LOSS=['dice_loss','bce_loss']`将指定训练loss为`dice loss``bce loss`的组合
## Dice loss解决类别不均衡问题的示例
我们以道路提取任务为例应用dice loss.
在DeepGlobe比赛的Road Extraction中,训练数据道路占比为:4.5%. 如下为其图片样例:
<p align="center">
<img src="./imgs/deepglobe.png" hspace='10'/> <br />
</p>
可以看出道路在整张图片中的比例很小。
### 数据集下载
我们从DeepGlobe比赛的Road Extraction的训练集中随机抽取了800张图片作为训练集,200张图片作为验证集,
制作了一个小型的道路提取数据集[MiniDeepGlobeRoadExtraction](https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip)
### 实验比较
在MiniDeepGlobeRoadExtraction数据集进行了实验比较。
* 数据集下载
```shell
python dataset/download_mini_deepglobe_road_extraction.py
```
* 预训练模型下载
```shell
python pretrained_model/download_model.py deeplabv3p_mobilenetv2-1-0_bn_coco
```
* 配置/数据校验
```shell
python pdseg/check.py --cfg ./configs/deepglobe_road_extraction.yaml
```
* 训练
```shell
python pdseg/train.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS "['dice_loss','bce_loss']"
```
* 评估
```
python pdseg/eval.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS "['dice_loss','bce_loss']"
```
* 结果比较
softmax loss和dice loss + bce loss实验结果如下图所示。
图中橙色曲线为dice loss + bce loss,最高mIoU为76.02%,蓝色曲线为softmax loss, 最高mIoU为73.62%。
<p align="center">
<img src="./imgs/loss_comparison.png" hspace='10' height="208" width="516"/> <br />
</p>
# 如何解决二分类中类别不均衡问题
对于二类图像分割任务中,经常出现类别分布不均匀的情况,例如:工业产品的瑕疵检测、道路提取及病变区域提取等。
目前PaddleSeg提供了三种loss函数,分别为softmax loss(sotfmax with cross entroy loss)、dice loss(dice coefficient loss)和bce loss(binary cross entroy loss). 我们可使用dice loss解决这个问题。
注:dice loss和bce loss仅支持二分类。
## Dice loss
Dice loss的定义如下:
<p align="center">
<img src="./imgs/dice.png" hspace='10' height="46" width="200"/> <br />
</p>
其中 Y 表示ground truth,P 表示预测结果。| |表示矩阵元素之和。![](./imgs/dice2.png) 表示*Y**P*的共有元素数,
实际通过求两者的逐像素乘积之和进行计算。例如:
<p align="center">
<img src="./imgs/dice3.png" hspace='10' /> <br />
</p>
其中 1 表示前景,0 表示背景。
**Note:** 在标注图片中,务必保证前景像素值为1,背景像素值为0.
Dice系数请参见[维基百科](https://zh.wikipedia.org/wiki/Dice%E7%B3%BB%E6%95%B0)
**为什么在类别不均衡问题上,dice loss效果比softmax loss更好?**
首先来看softmax loss的定义:
<p align="center">
<img src="./imgs/softmax_loss.png" height="130" /> <br />
</p>
其中 y 表示ground truth,p 表示网络输出。
在图像分割中,`softmax loss`评估每一个像素点的类别预测,然后平均所有的像素点。这个本质上就是对图片上的每个像素进行平等的学习。这就造成了一个问题,如果在图像上的多种类别有不平衡的表征,那么训练会由最主流的类别主导。以上面DeepGlobe道路提取的数据为例子,网络将偏向于背景的学习,降低了网络对前景目标的提取能力。
`dice loss(dice coefficient loss)`通过预测和标注的交集除以它们的总体像素进行计算,它将一个类别的所有像素作为一个整体作为考量,而且计算交集在总体中的占比,所以不受大量背景像素的影响,能够取得更好的效果。
在实际应用中`dice loss`往往与`bce loss(binary cross entroy loss)`结合使用,提高模型训练的稳定性。
## PaddleSeg指定训练loss
PaddleSeg通过`cfg.SOLVER.LOSS`参数可以选择训练时的损失函数,
`cfg.SOLVER.LOSS=['dice_loss','bce_loss']`将指定训练loss为`dice loss``bce loss`的组合
## Dice loss解决类别不均衡问题的示例
我们以道路提取任务为例应用dice loss.
在DeepGlobe比赛的Road Extraction中,训练数据道路占比为:4.5%. 如下为其图片样例:
<p align="center">
<img src="./imgs/deepglobe.png" hspace='10'/> <br />
</p>
可以看出道路在整张图片中的比例很小。
### 数据集下载
我们从DeepGlobe比赛的Road Extraction的训练集中随机抽取了800张图片作为训练集,200张图片作为验证集,
制作了一个小型的道路提取数据集[MiniDeepGlobeRoadExtraction](https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip)
### 实验比较
在MiniDeepGlobeRoadExtraction数据集进行了实验比较。
* 数据集下载
```shell
python dataset/download_mini_deepglobe_road_extraction.py
# Loss选择
目前PaddleSeg提供了6种损失函数,分别为
- Softmax loss (softmax with cross entropy loss)
- Weighted softmax loss (weighted softmax with cross entropy loss)
- Dice loss (dice coefficient loss)
- Bce loss (binary cross entropy loss)
- Lovasz hinge loss
- Lovasz softmax loss
## 类别不均衡问题
在图像分割任务中,经常出现类别分布不均匀的情况,例如:工业产品的瑕疵检测、道路提取及病变区域提取等。
针对这个问题,您可使用Weighted softmax loss、Dice loss、Lovasz hinge loss和Lovasz softmax loss进行解决。
### Weighted softmax loss
Weighted softmax loss是按类别设置不同权重的softmax loss。
通过设置`cfg.SOLVER.CROSS_ENTROPY_WEIGHT`参数进行使用。
默认为None. 如果设置为'dynamic',会根据每个batch中各个类别的数目,动态调整类别权重。
也可以设置一个静态权重(list的方式),比如有3类,每个类别权重可以设置为[0.1, 2.0, 0.9]. 示例如下
```yaml
SOLVER:
CROSS_ENTROPY_WEIGHT: 'dynamic'
```
* 预训练模型下载
```shell
python pretrained_model/download_model.py deeplabv3p_mobilenetv2-1-0_bn_coco
```
* 配置/数据校验
```shell
python pdseg/check.py --cfg ./configs/deepglobe_road_extraction.yaml
```
* 训练
```shell
python pdseg/train.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS "['dice_loss','bce_loss']"
```
* 评估
```
python pdseg/eval.py --cfg ./configs/deepglobe_road_extraction.yaml --use_gpu SOLVER.LOSS "['dice_loss','bce_loss']"
```
* 结果比较
softmax loss和dice loss + bce loss实验结果如下图所示。
图中橙色曲线为dice loss + bce loss,最高mIoU为76.02%,蓝色曲线为softmax loss, 最高mIoU为73.62%。
<p align="center">
<img src="./imgs/loss_comparison.png" hspace='10' height="208" width="516"/> <br />
</p>
### Dice loss
参见[Dice loss教程](./dice_loss.md)
### Lovasz hinge loss和Lovasz softmax loss
参见[Lovasz loss教程](./lovasz_loss.md)
# Lovasz loss
对于图像分割任务中,经常出现类别分布不均匀的情况,例如:工业产品的瑕疵检测、道路提取及病变区域提取等。
对于图像分割任务中,经常出现类别分布不均匀的情况,例如:工业产品的瑕疵检测、道路提取及病变区域提取等。我们可使用lovasz loss解决这个问题。
我们可使用lovasz loss解决这个问题。Lovasz loss根据分割目标的类别数量可分为两种:lovasz hinge loss适用于二分类问题,lovasz softmax loss适用于多分类问题
Lovasz loss基于子模损失(submodular losses)的凸Lovasz扩展,对神经网络的mean IoU损失进行优化。Lovasz loss根据分割目标的类别数量可分为两种:lovasz hinge loss和lovasz softmax loss. 其中lovasz hinge loss适用于二分类问题,lovasz softmax loss适用于多分类问题。该工作发表在CVPR 2018上,可点击[参考文献](#参考文献)查看具体原理
## Lovasz hinge loss
### 使用方式
### 使用指南
PaddleSeg通过`cfg.SOLVER.LOSS`参数可以选择训练时的损失函数,
`cfg.SOLVER.LOSS=['lovasz_hinge_loss','bce_loss']`将指定训练loss为`lovasz hinge loss``bce loss`的组合。
`cfg.SOLVER.LOSS=['lovasz_hinge_loss','bce_loss']`将指定训练loss为`lovasz hinge loss``bce loss`(binary cross-entropy loss)的组合。
Lovasz hinge loss有3种使用方式:(1)直接训练使用。(2)bce loss结合使用。(3)先使用bec loss进行训练,再使用lovasz hinge loss进行finetuning. 第1种方式不一定达到理想效果,推荐使用后两种方式。本文以第2种方式为例。
### 使用示例
同时,也可以通过`cfg.SOLVER.LOSS_WEIGHT`参数对不同loss进行权重配比,灵活运用于训练调参。如下所示
```yaml
SOLVER:
LOSS: ["lovasz_hinge_loss","bce_loss"]
LOSS_WEIGHT:
LOVASZ_HINGE_LOSS: 0.5
BCE_LOSS: 0.5
```
### 实验对比
我们以道路提取任务为例应用lovasz hinge loss.
在DeepGlobe比赛的Road Extraction中,训练数据道路占比为:4.5%. 如下为其图片样例:
基于MiniDeepGlobeRoadExtraction数据集与bce loss进行了实验对比。
该数据集来源于DeepGlobe比赛的Road Extraction单项,训练数据道路占比为:4.5%. 如下为其图片样例:
<p align="center">
<img src="./imgs/deepglobe.png" hspace='10'/> <br />
</p>
可以看出道路在整张图片中的比例很小。
#### 实验对比
在MiniDeepGlobeRoadExtraction数据集进行了实验对比。
为进行快速体验,这里使用DeepLabv3+模型,backbone为MobileNetV2.
* 数据集下载
我们从DeepGlobe比赛的Road Extraction的训练集中随机抽取了800张图片作为训练集,200张图片作为验证集,
......@@ -64,23 +72,32 @@ lovasz hinge loss + bce loss和softmax loss的对比结果如下图所示。
## Lovasz softmax loss
### 使用方式
### 使用指南
PaddleSeg通过`cfg.SOLVER.LOSS`参数可以选择训练时的损失函数,
`cfg.SOLVER.LOSS=['lovasz_softmax_loss','softmax_loss']`将指定训练loss为`lovasz softmax loss``softmax loss`的组合。
Lovasz softmax loss有3种使用方式:(1)直接训练使用。(2)softmax loss结合使用。(3)先使用softmax loss进行训练,再使用lovasz softmax loss进行finetuning. 第1种方式不一定达到理想效果,推荐使用后两种方式。本文以第2种方式为例。
### 使用示例
我们以Pascal voc为例应用lovasz softmax loss.
同时,也可以通过`cfg.SOLVER.LOSS_WEIGHT`参数对不同loss进行权重配比,灵活运用于训练调参。如下所示
```yaml
SOLVER:
LOSS: ["lovasz_softmax_loss","softmax_loss"]
LOSS_WEIGHT:
LOVASZ_SOFTMAX_LOSS: 0.2
SOFTMAX_LOSS: 0.8
```
### 实验对比
#### 实验对比
接下来以PASCAL VOC 2012数据集为例应用lovasz softmax loss. 我们将lovasz softmax loss与softmax loss进行了实验对比。为进行快速体验,这里使用DeepLabv3+模型,backbone为MobileNetV2.
在Pascal voc数据集上与softmax loss进行了实验对比。
* 数据集下载
<p align="center">
<img src="./imgs/VOC2012.png" width="50%" height="50%" hspace='10'/> <br />
</p>
```shell
python dataset/download_and_convert_voc2012.py
```
......@@ -114,3 +131,7 @@ lovasz softmax loss + softmax loss和softmax loss的对比结果如下图所示
</p>
图中橙色曲线代表lovasz softmax loss + softmax loss,最高mIoU为64.63%,蓝色曲线代表softmax loss, 最高mIoU为63.55%,相比提升1.08个百分点。
## 参考文献
[Berman M, Rannen Triki A, Blaschko M B. The lovász-softmax loss: a tractable surrogate for the optimization of the intersection-over-union measure in neural networks[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 4413-4421.](http://openaccess.thecvf.com/content_cvpr_2018/html/Berman_The_LovaSz-Softmax_Loss_CVPR_2018_paper.html)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册