Commit 31262487 authored by FlyingQianMM

add human segmentation

Parent 355832ee
# HumanSeg Human Segmentation Models
This tutorial builds on PaddleSeg's core segmentation networks and walks through the full human segmentation workflow: pretrained models, fine-tuning, and video segmentation inference and deployment.
## Installation
**Prerequisites**
* paddlepaddle >= 1.8.0
* python >= 3.5
* cython
* pycocotools
```
pip install paddlex -i https://mirror.baidu.com/pypi/simple
```
For installation issues, refer to [PaddleX Installation](https://paddlex.readthedocs.io/zh_CN/latest/install.html).
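To verify the installation, a quick sanity check (a minimal sketch; it simply prints the installed version):
```python
import paddlex

print(paddlex.__version__)
```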
## Pretrained Models
HumanSeg provides two pretrained models trained on large-scale human image data, covering a range of usage scenarios.
| Model Type | Checkpoint | Inference Model | Quant Inference Model | Notes |
| --- | --- | --- | --- | --- |
| HumanSeg-server | [humanseg_server_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_ckpt.zip) | [humanseg_server_inference](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_inference.zip) | -- | High-accuracy model for server-side GPU inference on scenes with complex backgrounds; architecture: DeepLabv3+/Xception65; input size (512, 512) |
| HumanSeg-mobile | [humanseg_mobile_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip) | [humanseg_mobile_inference](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_inference.zip) | [humanseg_mobile_quant](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip) | Lightweight model for front-camera scenes on mobile devices or server-side CPU; architecture: HRNet_w18_small_v1; input size (192, 192) |
Model performance:

| Model | Model Size | Inference Latency |
| --- | --- | --- |
| humanseg_server_inference | 158M | - |
| humanseg_mobile_inference | 5.8M | 42.35ms |
| humanseg_mobile_quant | 1.6M | 24.93ms |

Latency was measured on a Xiaomi phone (CPU: Snapdragon 855, RAM: 6GB, input size: 192x192).
**NOTE:**
* Checkpoint contains the model weights and is used for fine-tuning.
* Inference Model and Quant Inference Model are deployment models, containing the `__model__` computation graph, the `__params__` model parameters, and the `model.yaml` basic model configuration.
* Inference Model targets server-side CPU and GPU deployment; Quant Inference Model is the quantized version, intended for mobile and other edge devices via Paddle Lite.

Run the following script to download the HumanSeg pretrained models:
```bash
python pretrain_weights/download_pretrain_weights.py
```
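Once downloaded, an inference model can also be loaded directly through the PaddleX Python API. A minimal sketch (it assumes `data/human_image.jpg`, the sample image used later in this tutorial, is available):
```python
import paddlex as pdx

# Load the downloaded inference model (the directory unzipped by the script above)
model = pdx.load_model('pretrain_weights/humanseg_mobile_inference')
# Single-image prediction; for segmentation models the result is expected to
# hold a label map and a score map
result = model.predict('data/human_image.jpg')
print(result.keys())
```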
## Download Test Data
We provide a small random subset of **Supervisely Persons**, the human segmentation dataset released by [supervise.ly](https://supervise.ly/), converted into a format PaddleSeg can load directly. Run the following command for a quick download; it also fetches `video_test.mp4`, a test video of a person shot with a phone's front camera.
```bash
python data/download_data.py
```
## Quick Start: Human Segmentation on Video Streams
Predictions from the DIS (Dense Inverse Search) optical flow algorithm are fused with the segmentation results to improve human segmentation on video streams.
```bash
# Real-time segmentation from the computer's camera
python video_infer.py --model_dir pretrain_weights/humanseg_lite_inference
# Segment a human video
python video_infer.py --model_dir pretrain_weights/humanseg_lite_inference --video_path data/video_test.mp4
```
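The DIS flow itself comes from OpenCV; below is a minimal sketch of the flow call that `utils/humanseg_postprocess.py` builds on (the two grayscale frames here are synthetic stand-ins):
```python
import cv2
import numpy as np

# DIS optical flow estimator with the fastest preset, as used in this example
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((192, 192), np.uint8)     # previous frame (grayscale)
cur_gray = np.full((192, 192), 128, np.uint8)  # current frame (grayscale)
# Dense forward flow: an (H, W, 2) array of per-pixel (dx, dy) displacements
flow = disflow.calc(prev_gray, cur_gray, None)
print(flow.shape)  # (192, 192, 2)
```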
The video segmentation results look like this:
<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/result.gif" width="20%" height="20%">
Replace the background with one of your choice; the background can be a single image or a video.
```bash
# Real-time background replacement from the computer's camera; a background video can be passed via '--background_video_path' instead
python bg_replace.py --model_dir pretrain_weights/humanseg_lite_inference --background_image_path data/background.jpg
# Replace the background of a human video; a background video can be passed via '--background_video_path' instead
python bg_replace.py --model_dir pretrain_weights/humanseg_lite_inference --video_path data/video_test.mp4 --background_image_path data/background.jpg
# Replace the background of a single image
python bg_replace.py --model_dir pretrain_weights/humanseg_lite_inference --image_path data/human_image.jpg --background_image_path data/background.jpg
```
Background replacement results look like this:
<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/bg_replace.gif" width="20%" height="20%">
**NOTE**:
* Video processing takes a few minutes; please be patient.
* The provided models target portrait footage from phone front cameras; results on landscape footage will be slightly worse.
## Training
Use the following command to fine-tune from a pretrained model. Make sure the chosen model architecture `model_type` matches the pretrained weights `pretrain_weights`.
```bash
python train.py --model_type HumanSegMobile \
--save_dir output/ \
--data_dir data/mini_supervisely \
--train_list data/mini_supervisely/train.txt \
--val_list data/mini_supervisely/val.txt \
--pretrain_weights pretrain_weights/humanseg_mobile_ckpt \
--batch_size 8 \
--learning_rate 0.001 \
--num_epochs 10 \
--image_shape 192 192
```
Parameter descriptions:
* `--model_type`: model type; one of HumanSegMobile or HumanSegServer
* `--save_dir`: directory for saving models
* `--data_dir`: dataset root directory
* `--train_list`: path of the training set list file
* `--val_list`: path of the validation set list file
* `--pretrain_weights`: path of the pretrained model
* `--batch_size`: mini-batch size
* `--learning_rate`: initial learning rate
* `--num_epochs`: number of training epochs
* `--image_shape`: network input image size (w, h)
For more command-line help, run:
```bash
python train.py --help
```
**NOTE**
You can quickly try a different model by changing `--model_type` together with the matching `--pretrain_weights`.
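Under the hood, `train.py` (included later in this commit) maps each `--model_type` to a PaddleX segmentation class:
```python
import paddlex as pdx

# HumanSegMobile -> HRNet-W18-Small-V1 (lightweight)
mobile_model = pdx.seg.HRNet(num_classes=2, width='18_small_v1')
# HumanSegServer -> DeepLabv3+ with an Xception65 backbone (high accuracy)
server_model = pdx.seg.DeepLabv3p(num_classes=2, backbone='Xception65')
```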
## Evaluation
Use the following command to evaluate a model:
```bash
python eval.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--val_list data/mini_supervisely/val.txt \
--image_shape 192 192
```
Parameter descriptions:
* `--model_dir`: model directory
* `--data_dir`: dataset root directory
* `--val_list`: path of the validation set list file
* `--image_shape`: network input image size (w, h)
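Evaluation can also be driven from Python. A minimal sketch of what `eval.py` does, assuming the demo dataset downloaded above:
```python
import paddlex as pdx
from paddlex.seg import transforms

eval_transforms = transforms.Compose(
    [transforms.Resize([192, 192]), transforms.Normalize()])
eval_dataset = pdx.datasets.SegDataset(
    data_dir='data/mini_supervisely',
    file_list='data/mini_supervisely/val.txt',
    transforms=eval_transforms)
model = pdx.load_model('output/best_model')
# Returns a dict of segmentation metrics (e.g. mIoU)
metrics = model.evaluate(eval_dataset, batch_size=8)
print(metrics)
```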
## Inference
Use the following command to run inference. Results are saved to the `./output/result/` directory by default.
```bash
python infer.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--test_list data/mini_supervisely/test.txt \
--save_dir output/result \
--image_shape 192 192
```
Parameter descriptions:
* `--model_dir`: model directory
* `--data_dir`: dataset root directory
* `--test_list`: path of the test set list file
* `--save_dir`: directory for saving the inference results
* `--image_shape`: network input image size (w, h)
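A single image can likewise be predicted from Python. A minimal sketch (the visualization call and its `weight` blending parameter follow PaddleX 1.x conventions and are assumptions here):
```python
import paddlex as pdx

model = pdx.load_model('output/best_model')
# For segmentation models the result is expected to hold a label map and a score map
result = model.predict('data/human_image.jpg')
# Overlay the predicted mask on the input image for a quick visual check
pdx.seg.visualize('data/human_image.jpg', result, weight=0.6,
                  save_dir='./output/result')
```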
## Model Export
```bash
paddlex --export_inference --model_dir output/best_model \
--save_dir output/export
```
Parameter descriptions:
* `--model_dir`: model directory
* `--save_dir`: directory for saving the exported model
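The exported directory can then be loaded with PaddleX's deployment predictor. A minimal sketch, assuming PaddleX 1.x's `paddlex.deploy.Predictor` API:
```python
import paddlex as pdx

# Load the exported inference model (__model__, __params__ and model.yaml)
predictor = pdx.deploy.Predictor('./output/export')
result = predictor.predict('data/human_image.jpg')
```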
## Offline Quantization
```bash
python quant_offline.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--quant_list data/mini_supervisely/val.txt \
--save_dir output/quant_offline \
--image_shape 192 192
```
Parameter descriptions:
* `--model_dir`: directory of the model to quantize
* `--data_dir`: dataset root directory
* `--quant_list`: path of the calibration set list file; typically the training or validation list
* `--save_dir`: directory for saving the quantized model
* `--image_shape`: network input image size (w, h)
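Internally, `quant_offline.py` (included later in this commit) runs a calibration dataset through `paddlex.slim.export_quant_model`:
```python
import paddlex as pdx
from paddlex.seg import transforms

quant_transforms = transforms.Compose(
    [transforms.Resize([192, 192]), transforms.Normalize()])
quant_dataset = pdx.datasets.SegDataset(
    data_dir='data/mini_supervisely',
    file_list='data/mini_supervisely/val.txt',
    transforms=quant_transforms)
model = pdx.load_model('output/best_model')
# Calibrate on 10 batches of size 1, then save the quantized model
pdx.slim.export_quant_model(model, quant_dataset, 1, 10, 'output/quant_offline')
```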
## AI Studio Online Tutorial
We provide an online human segmentation tutorial on the AI Studio platform. [Try it out](https://aistudio.baidu.com/aistudio/projectdetail/475345).
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
import cv2
import numpy as np
from utils.humanseg_postprocess import postprocess, threshold_mask
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms
def parse_args():
parser = argparse.ArgumentParser(
        description='HumanSeg background replacement')
parser.add_argument(
'--model_dir',
dest='model_dir',
help='Model path for inference',
type=str)
parser.add_argument(
'--image_path',
dest='image_path',
        help='Image containing a human',
type=str,
default=None)
parser.add_argument(
'--background_image_path',
dest='background_image_path',
help='Background image for replacing',
type=str,
default=None)
parser.add_argument(
'--video_path',
dest='video_path',
help='Video path for inference',
type=str,
default=None)
parser.add_argument(
'--background_video_path',
dest='background_video_path',
help='Background video path for replacing',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the inference results',
type=str,
default='./output')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
def predict(img, model, test_transforms):
model.arrange_transforms(transforms=test_transforms, mode='test')
img, im_info = test_transforms(img.astype('float32'))
img = np.expand_dims(img, axis=0)
result = model.exe.run(model.test_prog,
feed={'image': img},
fetch_list=list(model.test_outputs.values()))
score_map = result[1]
score_map = np.squeeze(score_map, axis=0)
score_map = np.transpose(score_map, (1, 2, 0))
return score_map, im_info
def recover(img, im_info):
for info in im_info[::-1]:
if info[0] == 'resize':
w, h = info[1][1], info[1][0]
img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
elif info[0] == 'padding':
            w, h = info[1][1], info[1][0]  # info[1] holds the (h, w) shape before padding
img = img[0:h, 0:w, :]
return img
def bg_replace(score_map, img, bg):
h, w, _ = img.shape
bg = cv2.resize(bg, (w, h))
score_map = np.repeat(score_map[:, :, np.newaxis], 3, axis=2)
comb = (score_map * img + (1 - score_map) * bg).astype(np.uint8)
return comb
def infer(args):
resize_h = args.image_shape[1]
resize_w = args.image_shape[0]
test_transforms = transforms.Compose(
[transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
model = pdx.load_model(args.model_dir)
if not osp.exists(args.save_dir):
os.makedirs(args.save_dir)
    # Replace the background of a single image
if args.image_path is not None:
if not osp.exists(args.image_path):
            raise Exception('The --image_path does not exist: {}'.format(
args.image_path))
if args.background_image_path is None:
raise Exception(
'The --background_image_path is not set. Please set it')
else:
if not osp.exists(args.background_image_path):
raise Exception(
                    'The --background_image_path does not exist: {}'.format(
args.background_image_path))
img = cv2.imread(args.image_path)
score_map, im_info = predict(img, model, test_transforms)
score_map = score_map[:, :, 1]
score_map = recover(score_map, im_info)
bg = cv2.imread(args.background_image_path)
save_name = osp.basename(args.image_path)
save_path = osp.join(args.save_dir, save_name)
result = bg_replace(score_map, img, bg)
cv2.imwrite(save_path, result)
    # Replace the background in a video: use the background video if provided, otherwise use the background image
else:
is_video_bg = False
if args.background_video_path is not None:
if not osp.exists(args.background_video_path):
raise Exception(
                    'The --background_video_path does not exist: {}'.format(
args.background_video_path))
is_video_bg = True
elif args.background_image_path is not None:
if not osp.exists(args.background_image_path):
raise Exception(
                    'The --background_image_path does not exist: {}'.format(
args.background_image_path))
else:
            raise Exception(
                'Please provide a background image or video by setting '
                '--background_image_path or --background_video_path')
disflow = cv2.DISOpticalFlow_create(
cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((resize_h, resize_w), np.uint8)
prev_cfd = np.zeros((resize_h, resize_w), np.float32)
is_init = True
if args.video_path is not None:
logging.info('Please wait. It is computing......')
if not osp.exists(args.video_path):
                raise Exception('The --video_path does not exist: {}'.format(
args.video_path))
cap_video = cv2.VideoCapture(args.video_path)
fps = cap_video.get(cv2.CAP_PROP_FPS)
width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
save_name = osp.basename(args.video_path)
save_name = save_name.split('.')[0]
save_path = osp.join(args.save_dir, save_name + '.avi')
cap_out = cv2.VideoWriter(
save_path,
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
(width, height))
if is_video_bg:
cap_bg = cv2.VideoCapture(args.background_video_path)
frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
current_frame_bg = 1
else:
img_bg = cv2.imread(args.background_image_path)
while cap_video.isOpened():
ret, frame = cap_video.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
score_map = recover(optflow_map, im_info)
                    # Read background frames in a loop, rewinding at the end
if is_video_bg:
ret_bg, frame_bg = cap_bg.read()
if ret_bg:
if current_frame_bg == frames_bg:
current_frame_bg = 1
cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
else:
break
current_frame_bg += 1
comb = bg_replace(score_map, frame, frame_bg)
else:
comb = bg_replace(score_map, frame, img_bg)
cap_out.write(comb)
else:
break
if is_video_bg:
cap_bg.release()
cap_video.release()
cap_out.release()
        # If neither an image nor a video is provided, open the camera
else:
cap_video = cv2.VideoCapture(0)
            if not cap_video.isOpened():
                raise IOError('Error opening the camera stream. Please check '
                              'whether the camera is working.')
if is_video_bg:
cap_bg = cv2.VideoCapture(args.background_video_path)
frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
current_frame_bg = 1
else:
img_bg = cv2.imread(args.background_image_path)
while cap_video.isOpened():
ret, frame = cap_video.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
score_map = recover(optflow_map, im_info)
                    # Read background frames in a loop, rewinding at the end
if is_video_bg:
ret_bg, frame_bg = cap_bg.read()
if ret_bg:
if current_frame_bg == frames_bg:
current_frame_bg = 1
cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
else:
break
current_frame_bg += 1
comb = bg_replace(score_map, frame, frame_bg)
else:
comb = bg_replace(score_map, frame, img_bg)
cv2.imshow('HumanSegmentation', comb)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
if is_video_bg:
cap_bg.release()
cap_video.release()
if __name__ == "__main__":
args = parse_args()
infer(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
sys.path.append(TEST_PATH)
import paddlex as pdx
def download_data(savepath):
url = "https://paddleseg.bj.bcebos.com/humanseg/data/mini_supervisely.zip"
pdx.utils.download_and_decompress(url=url, path=savepath)
url = "https://paddleseg.bj.bcebos.com/humanseg/data/video_test.zip"
pdx.utils.download_and_decompress(url=url, path=savepath)
if __name__ == "__main__":
download_data(LOCAL_PATH)
print("Data download finish!")
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms
def parse_args():
    parser = argparse.ArgumentParser(description='HumanSeg evaluation')
parser.add_argument(
'--model_dir',
dest='model_dir',
        help='Model path for evaluation',
type=str,
default='output/best_model')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='The root directory of dataset',
type=str)
parser.add_argument(
'--val_list',
dest='val_list',
help='Val list file of dataset',
type=str,
default=None)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=128)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
def dict2str(dict_input):
out = ''
for k, v in dict_input.items():
try:
v = round(float(v), 6)
except:
pass
out = out + '{}={}, '.format(k, v)
return out.strip(', ')
def evaluate(args):
eval_transforms = transforms.Compose(
[transforms.Resize(args.image_shape), transforms.Normalize()])
eval_dataset = pdx.datasets.SegDataset(
data_dir=args.data_dir,
file_list=args.val_list,
transforms=eval_transforms)
model = pdx.load_model(args.model_dir)
metrics = model.evaluate(eval_dataset, args.batch_size)
logging.info('[EVAL] Finished, {} .'.format(dict2str(metrics)))
if __name__ == '__main__':
args = parse_args()
evaluate(args)
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
sys.path.append(TEST_PATH)
import paddlex as pdx
model_urls = {
"humanseg_server_ckpt":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_ckpt.zip",
"humanseg_server_inference":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_inference.zip",
"humanseg_mobile_ckpt":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip",
"humanseg_mobile_inference":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_inference.zip",
"humanseg_mobile_quant":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip",
"humanseg_lite_ckpt":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_ckpt.zip",
"humanseg_lite_inference":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_inference.zip",
"humanseg_lite_quant":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_quant.zip",
}
if __name__ == "__main__":
for model_name, url in model_urls.items():
pdx.utils.download_and_decompress(url=url, path=LOCAL_PATH)
print("Pretrained Model download success!")
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddlex as pdx
from paddlex.seg import transforms
def parse_args():
    parser = argparse.ArgumentParser(description='HumanSeg offline quantization')
parser.add_argument(
'--model_dir',
dest='model_dir',
        help='Model path for quantization',
type=str,
default='output/best_model')
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=1)
parser.add_argument(
'--batch_nums',
dest='batch_nums',
        help='Number of batches for quantization',
type=int,
default=10)
parser.add_argument(
'--data_dir',
dest='data_dir',
        help='The root directory of the dataset',
type=str)
parser.add_argument(
'--quant_list',
dest='quant_list',
        help='Image file list for model quantization; it can be val.txt or train.txt',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the quant model',
type=str,
default='./output/quant_offline')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
def quant(args):
eval_transforms = transforms.Compose(
[transforms.Resize(args.image_shape), transforms.Normalize()])
eval_dataset = pdx.datasets.SegDataset(
data_dir=args.data_dir,
file_list=args.quant_list,
transforms=eval_transforms)
model = pdx.load_model(args.model_dir)
pdx.slim.export_quant_model(model, eval_dataset, args.batch_size,
args.batch_nums, args.save_dir)
if __name__ == '__main__':
args = parse_args()
    quant(args)
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# Use GPU card 0
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# To use the CPU instead:
#os.environ['CUDA_VISIBLE_DEVICES'] = ''
import argparse
import paddlex as pdx
from paddlex.seg import transforms
# Download and extract the human segmentation dataset
human_seg_data = 'https://paddlex.bj.bcebos.com/humanseg/data/human_seg_data.zip'
pdx.utils.download_and_decompress(human_seg_data, path='./')
# Download and extract the pretrained human segmentation weights
pretrain_weights = 'https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip'
pdx.utils.download_and_decompress(
pretrain_weights, path='./output/human_seg/pretrain')
# Define the transforms for training and validation
train_transforms = transforms.Compose([
transforms.Resize([192, 192]), transforms.RandomHorizontalFlip(),
transforms.Normalize()
])
eval_transforms = transforms.Compose(
[transforms.Resize([192, 192]), transforms.Normalize()])
# Define the datasets used for training and validation
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='human_seg_data',
file_list='human_seg_data/train_list.txt',
label_list='human_seg_data/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
data_dir='human_seg_data',
file_list='human_seg_data/val_list.txt',
label_list='human_seg_data/labels.txt',
transforms=eval_transforms)
# Initialize the model and start training
# Training metrics can be inspected with VisualDL
# Launch VisualDL with: visualdl --logdir output/human_seg/vdl_log --port 8001
# Then open http://0.0.0.0:8001 in a browser
# (0.0.0.0 works for local access; for a remote server, use that machine's IP)
# https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet
num_classes = len(train_dataset.labels)
model = pdx.seg.HRNet(num_classes=num_classes, width='18_small_v1')
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
learning_rate=0.001,
pretrain_weights='./output/human_seg/pretrain/humanseg_mobile_ckpt',
save_dir='output/human_seg',
use_vdl=True)
MODEL_TYPE = ['HumanSegMobile', 'HumanSegServer']
def parse_args():
parser = argparse.ArgumentParser(description='HumanSeg training')
parser.add_argument(
'--model_type',
dest='model_type',
help="Model type for traing, which is one of ('HumanSegMobile', 'HumanSegServer')",
type=str,
default='HumanSegMobile')
parser.add_argument(
'--data_dir',
dest='data_dir',
help='The root directory of dataset',
type=str)
parser.add_argument(
'--train_list',
dest='train_list',
help='Train list file of dataset',
type=str)
parser.add_argument(
'--val_list',
dest='val_list',
help='Val list file of dataset',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
type=str,
default='./output')
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
parser.add_argument(
'--num_epochs',
dest='num_epochs',
        help='Number of epochs for training',
type=int,
default=100)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=128)
parser.add_argument(
'--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=0.01)
parser.add_argument(
'--pretrain_weights',
dest='pretrain_weights',
        help='The path of pretrained weights',
type=str,
default=None)
parser.add_argument(
'--resume_checkpoint',
dest='resume_checkpoint',
help='The path of resume checkpoint',
type=str,
default=None)
parser.add_argument(
'--use_vdl',
dest='use_vdl',
help='Whether to use visualdl',
action='store_true')
parser.add_argument(
'--save_interval_epochs',
dest='save_interval_epochs',
        help='Interval, in epochs, between model snapshots',
type=int,
default=5)
return parser.parse_args()
def train(args):
train_transforms = transforms.Compose([
transforms.Resize(args.image_shape), transforms.RandomHorizontalFlip(),
transforms.Normalize()
])
eval_transforms = transforms.Compose(
[transforms.Resize(args.image_shape), transforms.Normalize()])
train_dataset = pdx.datasets.SegDataset(
data_dir=args.data_dir,
file_list=args.train_list,
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
data_dir=args.data_dir,
file_list=args.val_list,
transforms=eval_transforms)
if args.model_type == 'HumanSegMobile':
model = pdx.seg.HRNet(
num_classes=args.num_classes, width='18_small_v1')
elif args.model_type == 'HumanSegServer':
model = pdx.seg.DeepLabv3p(
num_classes=args.num_classes, backbone='Xception65')
    else:
        raise ValueError(
            "--model_type: {} is invalid. It should be one of "
            "('HumanSegMobile', 'HumanSegServer')".format(args.model_type))
model.train(
num_epochs=args.num_epochs,
train_dataset=train_dataset,
train_batch_size=args.batch_size,
eval_dataset=eval_dataset,
save_interval_epochs=args.save_interval_epochs,
learning_rate=args.learning_rate,
pretrain_weights=args.pretrain_weights,
resume_checkpoint=args.resume_checkpoint,
save_dir=args.save_dir,
use_vdl=args.use_vdl)
if __name__ == '__main__':
args = parse_args()
train(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import humanseg_postprocess
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
    """Compute optical-flow point matches and the flow-tracked confidence map.
    Args:
        pre_gray: previous frame in grayscale
        cur_gray: current frame in grayscale
        prev_cfd: fused confidence map of the previous frame
        dl_weights: fusion weight map
        disflow: DIS optical flow handle
    Returns:
        track_cfd: confidence map propagated by optical flow
        is_track: binary map marking pixels with a valid flow match
        dl_weights: updated fusion weight map
    """
check_thres = 8
h, w = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray)
flow_fw = disflow.calc(pre_gray, cur_gray, None)
flow_bw = disflow.calc(cur_gray, pre_gray, None)
    # Round flow vectors to integer pixel offsets for indexing
    flow_fw = np.round(flow_fw).astype(int)
    flow_bw = np.round(flow_bw).astype(int)
y_list = np.array(range(h))
x_list = np.array(range(w))
yv, xv = np.meshgrid(y_list, x_list)
yv, xv = yv.T, xv.T
cur_x = xv + flow_fw[:, :, 0]
cur_y = yv + flow_fw[:, :, 1]
    # Do not track pixels whose flow goes out of the image bounds
not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) +
np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])
) >= check_thres
track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
is_track[cur_y[~not_track], cur_x[~not_track]] = 1
not_flow = np.all(np.abs(flow_fw) == 0,
axis=-1) * np.all(np.abs(flow_bw) == 0, axis=-1)
dl_weights[cur_y[not_flow], cur_x[not_flow]] = 0.05
return track_cfd, is_track, dl_weights
def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
    """Fuse the flow-tracked map with the current segmentation result.
    Args:
        track_cfd: confidence map propagated by optical flow
        dl_cfd: segmentation confidence map of the current frame
        dl_weights: fusion weight map
        is_track: binary map of matched flow points
    Returns:
        fusion_cfd: fusion of the flow-tracked map and the segmentation result
    """
fusion_cfd = dl_cfd.copy()
    is_track = is_track.astype(bool)
fusion_cfd[is_track] = dl_weights[is_track] * dl_cfd[is_track] + (
1 - dl_weights[is_track]) * track_cfd[is_track]
    # High-confidence regions (clearly foreground or clearly background)
index_certain = ((dl_cfd > 0.9) + (dl_cfd < 0.1)) * is_track
index_less01 = (dl_weights < 0.1) * index_certain
fusion_cfd[index_less01] = 0.3 * dl_cfd[index_less01] + 0.7 * track_cfd[
index_less01]
index_larger09 = (dl_weights >= 0.1) * index_certain
fusion_cfd[index_larger09] = 0.4 * dl_cfd[
index_larger09] + 0.6 * track_cfd[index_larger09]
return fusion_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
    """Refine the segmentation result with optical flow.
    Args:
        cur_gray: current frame in grayscale
        scoremap: segmentation confidence map of the current frame
        prev_gray: previous frame in grayscale
        pre_cfd: fused confidence map of the previous frame
        disflow: DIS optical flow handle
        is_init: whether this is the first frame
    Returns:
        fusion_cfd: fusion of the flow-tracked map and the predicted result
    """
h, w = scoremap.shape
cur_cfd = scoremap.copy()
if is_init:
if h <= 64 or w <= 64:
disflow.setFinestScale(1)
elif h <= 160 or w <= 160:
disflow.setFinestScale(2)
else:
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
weights = np.ones((h, w), np.float32) * 0.3
track_cfd, is_track, weights = human_seg_tracking(
prev_gray, cur_gray, pre_cfd, weights, disflow)
fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights,
is_track)
return fusion_cfd
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
import cv2
import numpy as np
from utils.humanseg_postprocess import postprocess, threshold_mask
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms
def parse_args():
parser = argparse.ArgumentParser(
description='HumanSeg inference for video')
parser.add_argument(
'--model_dir',
dest='model_dir',
help='Model path for inference',
type=str)
parser.add_argument(
'--video_path',
dest='video_path',
        help='Video path for inference; the camera is used if the path does not exist',
type=str,
default=None)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the inference results',
type=str,
default='./output')
parser.add_argument(
"--image_shape",
dest="image_shape",
help="The image shape for net inputs.",
nargs=2,
default=[192, 192],
type=int)
return parser.parse_args()
def predict(img, model, test_transforms):
model.arrange_transforms(transforms=test_transforms, mode='test')
img, im_info = test_transforms(img.astype('float32'))
img = np.expand_dims(img, axis=0)
result = model.exe.run(model.test_prog,
feed={'image': img},
fetch_list=list(model.test_outputs.values()))
score_map = result[1]
score_map = np.squeeze(score_map, axis=0)
score_map = np.transpose(score_map, (1, 2, 0))
return score_map, im_info
def recover(img, im_info):
for info in im_info[::-1]:
if info[0] == 'resize':
w, h = info[1][1], info[1][0]
img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
elif info[0] == 'padding':
            w, h = info[1][1], info[1][0]  # info[1] holds the (h, w) shape before padding
img = img[0:h, 0:w, :]
return img
def video_infer(args):
resize_h = args.image_shape[1]
resize_w = args.image_shape[0]
test_transforms = transforms.Compose(
[transforms.Resize((resize_w, resize_h)), transforms.Normalize()])
model = pdx.load_model(args.model_dir)
if not args.video_path:
cap = cv2.VideoCapture(0)
else:
cap = cv2.VideoCapture(args.video_path)
    if not cap.isOpened():
        raise IOError('Error opening video stream or file. Please check '
                      'whether --video_path exists: {} or whether the '
                      'camera is working.'.format(args.video_path))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((resize_h, resize_w), np.uint8)
prev_cfd = np.zeros((resize_h, resize_w), np.float32)
is_init = True
fps = cap.get(cv2.CAP_PROP_FPS)
if args.video_path:
logging.info("Please wait. It is computing......")
        # Writer for saving the predicted result video
if not osp.exists(args.save_dir):
os.makedirs(args.save_dir)
out = cv2.VideoWriter(
osp.join(args.save_dir, 'result.avi'),
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
        # Start reading video frames
while cap.isOpened():
ret, frame = cap.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
img_matting = np.repeat(
optflow_map[:, :, np.newaxis], 3, axis=2)
img_matting = recover(img_matting, im_info)
bg_im = np.ones_like(img_matting) * 255
comb = (img_matting * frame +
(1 - img_matting) * bg_im).astype(np.uint8)
out.write(comb)
else:
break
cap.release()
out.release()
else:
while cap.isOpened():
ret, frame = cap.read()
if ret:
score_map, im_info = predict(frame, model, test_transforms)
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
score_map = 255 * score_map[:, :, 1]
optflow_map = postprocess(cur_gray, score_map, prev_gray, prev_cfd, \
disflow, is_init)
prev_gray = cur_gray.copy()
prev_cfd = optflow_map.copy()
is_init = False
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(
optflow_map, thresh_bg=0.2, thresh_fg=0.8)
img_matting = np.repeat(
optflow_map[:, :, np.newaxis], 3, axis=2)
img_matting = recover(img_matting, im_info)
bg_im = np.ones_like(img_matting) * 255
comb = (img_matting * frame +
(1 - img_matting) * bg_im).astype(np.uint8)
cv2.imshow('HumanSegmentation', comb)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
cap.release()
if __name__ == "__main__":
args = parse_args()
video_infer(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@ class SegDataset(Dataset):
     Args:
         data_dir (str): Root directory of the dataset.
         file_list (str): Path of the list file describing the dataset's images and their annotation files (each line holds paths relative to data_dir).
-        label_list (str): Path of the file describing the classes contained in the dataset.
+        label_list (str): Path of the file describing the classes contained in the dataset. Defaults to None.
         transforms (list): Preprocessing/augmentation operators applied to each sample.
         num_workers (int): Number of threads or processes used for preprocessing. Defaults to 4.
         buffer_size (int): Queue buffer length, in samples, used during preprocessing. Defaults to 100.
@@ -40,7 +40,7 @@ class SegDataset(Dataset):
     def __init__(self,
                  data_dir,
                  file_list,
-                 label_list,
+                 label_list=None,
                  transforms=None,
                  num_workers='auto',
                  buffer_size=100,
@@ -56,10 +56,11 @@ class SegDataset(Dataset):
         self.labels = list()
         self._epoch = 0
-        with open(label_list, encoding=get_encoding(label_list)) as f:
-            for line in f:
-                item = line.strip()
-                self.labels.append(item)
+        if label_list is not None:
+            with open(label_list, encoding=get_encoding(label_list)) as f:
+                for line in f:
+                    item = line.strip()
+                    self.labels.append(item)
         with open(file_list, encoding=get_encoding(file_list)) as f:
             for line in f:
@@ -69,8 +70,8 @@ class SegDataset(Dataset):
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_label = osp.join(data_dir, items[1])
                 if not osp.exists(full_path_im):
-                    raise IOError(
-                        'The image file {} does not exist!'.format(full_path_im))
+                    raise IOError('The image file {} does not exist!'.format(
+                        full_path_im))
                 if not osp.exists(full_path_label):
                     raise IOError('The image file {} does not exist!'.format(
                         full_path_label))