未验证 提交 8281cc6f 编写于 作者: W wuyefeilin 提交者: GitHub

Add contrib HumanSeg (#229)

* update solver.py and model_builder.py

* update solver.py

* update infer.py

* update model_builder.py for fitting different aug_method

* update model_builder.py

* update model_builder.py

* update export process

* save best model

* update train.py

* update train.py

* update train.py

* eval for last epoch model

* check model path

* check TEST.TEST_MODEL

* reconsitution humanseg

* first refactor humanseg

* update humanseg.py and add main.py

* frist add

* add shufflenet

* first add

* add hrnet

* add tutorial

* add humanseg infer

* add realtime humanseg

* update hrnet.py

* update humanseg.py

* rm cpp deploy

* add infer of resize_bylong

* update

* update infer.py

* update some

* update humanseg.py

* image_segment and video_segment add to class HumanSeg

* add quant

* resolve with_data_parallel problem

* add coeff params

* add demo

* update

* update __init__.py

* add setup.py
上级 5b9a1db6
# HumanSeg
## 环境
将contrib目录加入环境变量PYTHONPATH
## 训练、验证、预测、模型导出参见[turtorial](./turtorial)
......@@ -17,7 +17,7 @@ from . import nets
from . import models
from . import datasets
from . import transforms
from .utils.utils import get_environ_info
from .utils import get_environ_info
env_info = get_environ_info()
......
......@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dataset import Dataset
# 实时人像分割预测部署
本模型基于飞浆开源的人像分割模型,并做了大量的针对视频的光流追踪优化,提供了完整的支持视频流的实时人像分割解决方案,并提供了高性能的`Python`集成部署方案。
## 模型下载
支持的模型文件如下,请根据应用场景选择合适的模型:
|模型文件 | 说明 |
| --- | --- |
|[shv75_deeplab_0303_quant](https://paddleseg.bj.bcebos.com/deploy/models/shv75_0303_quant.zip) | 小模型, 适合轻量级计算环境 |
|[shv75_deeplab_0303](https://paddleseg.bj.bcebos.com/deploy/models/shv75_deeplab_0303.zip)| 小模型,适合轻量级计算环境 |
|[deeplabv3_xception_humanseg](https://paddleseg.bj.bcebos.com/deploy/models/deeplabv3_xception_humanseg.zip) | 服务端GPU环境 |
**注意:下载后解压到合适的路径,后续该路径将做为预测参数用于加载模型。**
## 预测部署
- [Python预测部署](./python)
## 效果预览
<figure class="half">
<img src="https://paddleseg.bj.bcebos.com/deploy/data/input.gif">
<img src="https://paddleseg.bj.bcebos.com/deploy/data/output.gif">
</figure>
# 实时人像分割Python预测部署方案
本方案基于Python实现,最小化依赖并把所有模型加载、数据预处理、预测、光流处理等后处理都封装在文件`infer.py`中,用户可以直接使用或集成到自己项目中。
## 前置依赖
- Windows(7,8,10) / Linux (Ubuntu 16.04) or MacOS 10.1+
- Paddle 1.7+
- Python 3.6+
注意:
1. 仅支持Paddle 1.7以上版本
2. MacOS上不支持GPU预测
其它未涉及情形,能正常安装`Paddle``OpenCV`通常都能正常使用。
## 安装依赖
执行如下命令
```shell
pip install -r requirements.txt
```
## 运行
1. 输入图片进行分割
```
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --img_path /PATH/TO/INPUT/IMAGE
```
预测结果会保存为`result.jpeg`
2. 输入视频进行分割
```shell
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --video_path /PATH/TO/INPUT/VIDEO
```
预测结果会保存在`result.avi`
3. 使用摄像头视频流
```shell
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --use_camera True
```
预测结果会通过可视化窗口实时显示。
**注意:**
`GPU`默认关闭, 如果要使用`GPU`进行加速,则先运行
```
export CUDA_VISIBLE_DEVICES=0
```
然后在前面的预测命令中增加参数`--use_gpu True`即可。
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""实时人像分割Python预测部署"""
import os
import argparse
import numpy as np
import cv2
import paddle.fluid as fluid
def humanseg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
"""计算光流跟踪匹配点和光流图
输入参数:
pre_gray: 上一帧灰度图
cur_gray: 当前帧灰度图
prev_cfd: 上一帧光流图
dl_weights: 融合权重图
disflow: 光流数据结构
返回值:
is_track: 光流点跟踪二值图,即是否具有光流点匹配
track_cfd: 光流跟踪图
"""
check_thres = 8
hgt, wdh = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray)
# 计算前向光流
flow_fw = disflow.calc(pre_gray, cur_gray, None)
# 计算后向光流
flow_bw = disflow.calc(cur_gray, pre_gray, None)
get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -
0.5)
for row in range(hgt):
for col in range(wdh):
# 计算光流处理后对应点坐标
# (row, col) -> (cur_x, cur_y)
fxy_fw = flow_fw[row, col]
dx_fw = get_round(fxy_fw[0])
cur_x = dx_fw + col
dy_fw = get_round(fxy_fw[1])
cur_y = dy_fw + row
if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt:
continue
fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1])
# 光流移动小于阈值
lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
(dx_fw + dx_bw) * (dx_fw + dx_bw))
if lmt >= check_thres:
continue
# 静止点降权
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1
track_cfd[cur_y, cur_x] = prev_cfd[row, col]
return track_cfd, is_track, dl_weights
def humanseg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
"""光流追踪图和人像分割结构融合
输入参数:
track_cfd: 光流追踪图
dl_cfd: 当前帧分割结果
dl_weights: 融合权重图
is_track: 光流点匹配二值图
返回值:
cur_cfd: 光流跟踪图和人像分割结果融合图
"""
cur_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0)
for i in range(len(idxs)):
x, y = idxs[0][i], idxs[1][i]
dl_score = dl_cfd[y, x]
track_score = track_cfd[y, x]
if dl_score > 0.9 or dl_score < 0.1:
if dl_weights[x, y] < 0.1:
cur_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
else:
cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
else:
cur_cfd[x, y] = dl_weights[x, y] * dl_score + (
1 - dl_weights[x, y]) * track_score
return cur_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
"""设置背景和前景阈值mask
输入参数:
img : 原始图像, np.uint8 类型.
thresh_bg : 背景阈值百分比,低于该值置为0.
thresh_fg : 前景阈值百分比,超过该值置为1.
返回值:
dst : 原始图像设置完前景背景阈值mask结果, np.float32 类型.
"""
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
def optflow_handle(cur_gray, scoremap, is_init):
"""光流优化
Args:
cur_gray : 当前帧灰度图
scoremap : 当前帧分割结果
is_init : 是否第一帧
Returns:
dst : 光流追踪图和预测结果融合图, 类型为 np.float32
"""
width, height = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((height, width), np.uint8)
prev_cfd = np.zeros((height, width), np.float32)
cur_cfd = scoremap.copy()
if is_init:
is_init = False
if height <= 64 or width <= 64:
disflow.setFinestScale(1)
elif height <= 160 or width <= 160:
disflow.setFinestScale(2)
else:
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
weights = np.ones((width, height), np.float32) * 0.3
track_cfd, is_track, weights = humanseg_tracking(
prev_gray, cur_gray, prev_cfd, weights, disflow)
fusion_cfd = humanseg_track_fuse(track_cfd, cur_cfd, weights, is_track)
fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
return fusion_cfd
class HumanSeg:
"""人像分割类
封装了人像分割模型的加载,数据预处理,预测,后处理等
"""
def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
self.mean = np.array(mean).reshape((3, 1, 1))
self.scale = np.array(scale).reshape((3, 1, 1))
self.eval_size = eval_size
self.load_model(model_dir, use_gpu)
def load_model(self, model_dir, use_gpu):
"""加载模型并创建predictor
Args:
model_dir: 预测模型路径, 包含 `__model__` 和 `__params__`
use_gpu: 是否使用GPU加速
"""
prog_file = os.path.join(model_dir, '__model__')
params_file = os.path.join(model_dir, '__params__')
config = fluid.core.AnalysisConfig(prog_file, params_file)
if use_gpu:
config.enable_use_gpu(100, 0)
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.disable_glog_info()
config.switch_specify_input_names(True)
config.enable_memory_optim()
self.predictor = fluid.core.create_paddle_predictor(config)
def preprocess(self, image):
"""图像预处理
hwc_rgb 转换为 chw_bgr,并进行归一化
输入参数:
image: 原始图像
返回值:
经过预处理后的图片结果
"""
img_mat = cv2.resize(
image, self.eval_size, interpolation=cv2.INTER_LINEAR)
# HWC -> CHW
img_mat = img_mat.swapaxes(1, 2)
img_mat = img_mat.swapaxes(0, 1)
# Convert to float
img_mat = img_mat[:, :, :].astype('float32')
img_mat = (img_mat / 255. - self.mean) / self.scale
img_mat = img_mat[np.newaxis, :, :, :]
return img_mat
def postprocess(self, image, output_data):
"""对预测结果进行后处理
Args:
image: 原始图,opencv 图片对象
output_data: Paddle预测结果原始数据
Returns:
原图和预测结果融合并做了光流优化的结果图
"""
scoremap = output_data[0, 1, :, :]
scoremap = (scoremap * 255).astype(np.uint8)
ori_h, ori_w = image.shape[0], image.shape[1]
evl_h, evl_w = self.eval_size[0], self.eval_size[1]
# 光流处理
cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (evl_w, evl_h))
optflow_map = optflow_handle(cur_gray, scoremap, False)
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
optflow_map = cv2.resize(optflow_map, (ori_w, ori_h))
optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
bg_im = np.ones_like(optflow_map) * 255
comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(
np.uint8)
return comb
def run_predict(self, image):
"""运行预测并返回可视化结果图
输入参数:
image: 需要预测的原始图, opencv图片对象
返回值:
可视化的预测结果图
"""
im_mat = self.preprocess(image)
im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
output_data = self.predictor.run([im_tensor])[1]
output_data = output_data.as_ndarray()
return self.postprocess(image, output_data)
def image_segment(self, path):
"""对图片文件进行分割
结果保存到`result.jpeg`文件中
"""
img_mat = cv2.imread(path)
img_mat = self.run_predict(img_mat)
cv2.imwrite('result.jpeg', img_mat)
def video_segment(self, path=None):
"""
对视屏流进行分割,
path为None时默认打开摄像头。
"""
if path is None:
cap = cv2.VideoCapture(0)
else:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise IOError("Error opening video stream or file")
return
if path is not None:
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
# 用于保存预测结果视频
out = cv2.VideoWriter('result.avi',
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps, (width, height))
# 开始获取视频帧
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
out.write(img_mat)
else:
break
cap.release()
out.release()
else:
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
cv2.imshow('HumanSegmentation', img_mat)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
cap.release()
def main(args):
"""预测程序入口
完成模型加载, 对视频、摄像头、图片文件等预测过程
"""
model_dir = args.model_dir
use_gpu = args.use_gpu
# 加载模型
mean = [0.5, 0.5, 0.5]
scale = [0.5, 0.5, 0.5]
eval_size = (192, 192)
model = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
if args.use_camera:
# 开启摄像头
model.video_segment()
elif args.video_path:
# 使用视频文件作为输入
model.video_segment(args.video_path)
elif args.img_path:
# 使用图片文件作为输入
model.image_segment(args.img_path)
else:
raise ValueError(
'One of (--model_dir, --video_path, --use_camera) should be given.')
def parse_args():
"""解析命令行参数
"""
parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument(
'--model_dir',
type=str,
default='',
help='path of human segmentation model')
parser.add_argument(
'--img_path', type=str, default='', help='path of input image')
parser.add_argument(
'--video_path', type=str, default='', help='path of input video')
parser.add_argument(
'--use_camera',
type=bool,
default=False,
help='input video stream from camera')
parser.add_argument(
'--use_gpu', type=bool, default=False, help='enable gpu')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""实时人像分割Python预测部署"""
import os
import argparse
import numpy as np
import cv2
import paddle.fluid as fluid
def humanseg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
"""计算光流跟踪匹配点和光流图
输入参数:
pre_gray: 上一帧灰度图
cur_gray: 当前帧灰度图
prev_cfd: 上一帧光流图
dl_weights: 融合权重图
disflow: 光流数据结构
返回值:
is_track: 光流点跟踪二值图,即是否具有光流点匹配
track_cfd: 光流跟踪图
"""
check_thres = 8
hgt, wdh = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray)
# 计算前向光流
flow_fw = disflow.calc(pre_gray, cur_gray, None)
# 计算后向光流
flow_bw = disflow.calc(cur_gray, pre_gray, None)
get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -
0.5)
for row in range(hgt):
for col in range(wdh):
# 计算光流处理后对应点坐标
# (row, col) -> (cur_x, cur_y)
fxy_fw = flow_fw[row, col]
dx_fw = get_round(fxy_fw[0])
cur_x = dx_fw + col
dy_fw = get_round(fxy_fw[1])
cur_y = dy_fw + row
if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt:
continue
fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1])
# 光流移动小于阈值
lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
(dx_fw + dx_bw) * (dx_fw + dx_bw))
if lmt >= check_thres:
continue
# 静止点降权
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1
track_cfd[cur_y, cur_x] = prev_cfd[row, col]
return track_cfd, is_track, dl_weights
def humanseg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
"""光流追踪图和人像分割结构融合
输入参数:
track_cfd: 光流追踪图
dl_cfd: 当前帧分割结果
dl_weights: 融合权重图
is_track: 光流点匹配二值图
返回值:
cur_cfd: 光流跟踪图和人像分割结果融合图
"""
cur_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0)
for i in range(len(idxs)):
x, y = idxs[0][i], idxs[1][i]
dl_score = dl_cfd[y, x]
track_score = track_cfd[y, x]
if dl_score > 0.9 or dl_score < 0.1:
if dl_weights[x, y] < 0.1:
cur_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
else:
cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
else:
cur_cfd[x, y] = dl_weights[x, y] * dl_score + (
1 - dl_weights[x, y]) * track_score
return cur_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
"""设置背景和前景阈值mask
输入参数:
img : 原始图像, np.uint8 类型.
thresh_bg : 背景阈值百分比,低于该值置为0.
thresh_fg : 前景阈值百分比,超过该值置为1.
返回值:
dst : 原始图像设置完前景背景阈值mask结果, np.float32 类型.
"""
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
def optflow_handle(cur_gray, scoremap, is_init):
"""光流优化
Args:
cur_gray : 当前帧灰度图
scoremap : 当前帧分割结果
is_init : 是否第一帧
Returns:
dst : 光流追踪图和预测结果融合图, 类型为 np.float32
"""
width, height = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((height, width), np.uint8)
prev_cfd = np.zeros((height, width), np.float32)
cur_cfd = scoremap.copy()
if is_init:
is_init = False
if height <= 64 or width <= 64:
disflow.setFinestScale(1)
elif height <= 160 or width <= 160:
disflow.setFinestScale(2)
else:
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
weights = np.ones((width, height), np.float32) * 0.3
track_cfd, is_track, weights = humanseg_tracking(
prev_gray, cur_gray, prev_cfd, weights, disflow)
fusion_cfd = humanseg_track_fuse(track_cfd, cur_cfd, weights, is_track)
fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
return fusion_cfd
class HumanSeg:
"""人像分割类
封装了人像分割模型的加载,数据预处理,预测,后处理等
"""
def __init__(self, model_dir, mean, scale, long_size, use_gpu=False):
self.mean = np.array(mean).reshape((3, 1, 1))
self.scale = np.array(scale).reshape((3, 1, 1))
self.long_size = long_size
self.load_model(model_dir, use_gpu)
def load_model(self, model_dir, use_gpu):
"""加载模型并创建predictor
Args:
model_dir: 预测模型路径, 包含 `__model__` 和 `__params__`
use_gpu: 是否使用GPU加速
"""
prog_file = os.path.join(model_dir, '__model__')
params_file = os.path.join(model_dir, '__params__')
config = fluid.core.AnalysisConfig(prog_file, params_file)
if use_gpu:
config.enable_use_gpu(100, 0)
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.disable_glog_info()
config.switch_specify_input_names(True)
config.enable_memory_optim()
self.predictor = fluid.core.create_paddle_predictor(config)
def preprocess(self, image):
"""图像预处理
hwc_rgb 转换为 chw_bgr,并进行归一化
输入参数:
image: 原始图像
返回值:
经过预处理后的图片结果
"""
origin_h, origin_w = image.shape[0], image.shape[1]
scale = float(self.long_size) / max(origin_w, origin_h)
resize_w = int(round(origin_w * scale))
resize_h = int(round(origin_h * scale))
img_mat = cv2.resize(
image, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
pad_h = self.long_size - resize_h
pad_w = self.long_size - resize_w
img_mat = cv2.copyMakeBorder(
img_mat,
0,
pad_h,
0,
pad_w,
cv2.BORDER_CONSTANT,
value=[127.5, 127.5, 127.5])
# HWC -> CHW
img_mat = img_mat.swapaxes(1, 2)
img_mat = img_mat.swapaxes(0, 1)
# Convert to float
img_mat = img_mat[:, :, :].astype('float32')
img_mat = (img_mat / 255. - self.mean) / self.scale
img_mat = img_mat[np.newaxis, :, :, :]
return img_mat
def postprocess(self, image, output_data):
"""对预测结果进行后处理
Args:
image: 原始图,opencv 图片对象
output_data: Paddle预测结果原始数据
Returns:
原图和预测结果融合并做了光流优化的结果图
"""
scoremap = output_data[0, 1, :, :]
scoremap = (scoremap * 255).astype(np.uint8)
# 光流处理
cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
origin_h, origin_w = image.shape[0], image.shape[1]
scale = float(self.long_size) / max(origin_w, origin_h)
resize_w = int(round(origin_w * scale))
resize_h = int(round(origin_h * scale))
cur_gray = cv2.resize(
cur_gray, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
pad_h = self.long_size - resize_h
pad_w = self.long_size - resize_w
cur_gray = cv2.copyMakeBorder(
cur_gray, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=127.5)
optflow_map = optflow_handle(cur_gray, scoremap, False)
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
optflow_map = optflow_map[0:resize_h, 0:resize_w]
optflow_map = cv2.resize(optflow_map, (origin_w, origin_h))
optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
bg_im = np.ones_like(optflow_map) * 255
comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(
np.uint8)
return comb
def run_predict(self, image):
"""运行预测并返回可视化结果图
输入参数:
image: 需要预测的原始图, opencv图片对象
返回值:
可视化的预测结果图
"""
im_mat = self.preprocess(image)
im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
output_data = self.predictor.run([im_tensor])[1]
output_data = output_data.as_ndarray()
return self.postprocess(image, output_data)
def image_segment(self, path):
"""对图片文件进行分割
结果保存到`result.jpeg`文件中
"""
img_mat = cv2.imread(path)
img_mat = self.run_predict(img_mat)
cv2.imwrite('result.jpeg', img_mat)
def video_segment(self, path=None):
"""
对视屏流进行分割,
path为None时默认打开摄像头。
"""
if path is None:
cap = cv2.VideoCapture(0)
else:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise IOError("Error opening video stream or file")
return
if path is not None:
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
# 用于保存预测结果视频
out = cv2.VideoWriter('result.avi',
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps, (width, height))
# 开始获取视频帧
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
out.write(img_mat)
else:
break
cap.release()
out.release()
else:
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
cv2.imshow('HumanSegmentation', img_mat)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
cap.release()
def main(args):
"""预测程序入口
完成模型加载, 对视频、摄像头、图片文件等预测过程
"""
model_dir = args.model_dir
use_gpu = args.use_gpu
# 加载模型
mean = [0.5, 0.5, 0.5]
scale = [0.5, 0.5, 0.5]
long_size = 192
model = HumanSeg(model_dir, mean, scale, long_size, use_gpu)
if args.use_camera:
# 开启摄像头
model.video_segment()
elif args.video_path:
# 使用视频文件作为输入
model.video_segment(args.video_path)
elif args.img_path:
# 使用图片文件作为输入
model.image_segment(args.img_path)
else:
raise ValueError(
'One of (--model_dir, --video_path, --use_camera) should be given.')
def parse_args():
"""解析命令行参数
"""
parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument(
'--model_dir',
type=str,
default='',
help='path of human segmentation model')
parser.add_argument(
'--img_path', type=str, default='', help='path of input image')
parser.add_argument(
'--video_path', type=str, default='', help='path of input video')
parser.add_argument(
'--use_camera',
type=bool,
default=False,
help='input video stream from camera')
parser.add_argument(
'--use_gpu', type=bool, default=False, help='enable gpu')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
opencv-python==4.1.2.30
opencv-contrib-python==4.2.0.32
import argparse
import os
import os.path as osp
import cv2
import numpy as np
import tqdm
import HumanSeg
def parse_args():
parser = argparse.ArgumentParser(
description='HumanSeg inference and visualization')
parser.add_argument(
'--test_model',
dest='test_model',
help='model path for inference',
type=str)
parser.add_argument(
'--data_dir',
dest='data_dir',
help='the root directory of dataset',
type=str)
parser.add_argument(
'--file_list', dest='file_list', help='file list for test', type=str)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='the directory for saveing the inferenc results',
type=str,
default='./result')
return parser.parse_args()
def mkdir(path):
sub_dir = osp.dirname(path)
if not osp.exists(sub_dir):
os.makedirs(sub_dir)
def process(args):
model = HumanSeg.models.load_model(args.test_model)
added_saveed_path = osp.join(args.save_dir, 'added')
mat_saved_path = osp.join(args.save_dir, 'mat')
scoremap_saved_path = osp.join(args.save_dir, 'scoremap')
with open(args.file_list, 'r') as f:
files = f.readlines()
for file in tqdm.tqdm(files):
file = file.strip()
im_file = osp.join(args.data_dir, file)
im = cv2.imread(im_file)
result = model.predict(im)
# save added image
added_image = HumanSeg.utils.visualize(im_file, result, weight=0.6)
added_image_file = osp.join(added_saveed_path, file)
mkdir(added_image_file)
cv2.imwrite(added_image_file, added_image)
# save score map
score_map = result['score_map'][:, :, 1]
score_map = (score_map * 255).astype(np.uint8)
score_map_file = osp.join(scoremap_saved_path, file)
mkdir(score_map_file)
cv2.imwrite(score_map_file, score_map)
# save mat image
score_map = np.expand_dims(score_map, axis=-1)
mat_image = np.concatenate([im, score_map], axis=2)
mat_file = osp.join(mat_saved_path, file)
ext = osp.splitext(mat_file)[-1]
mat_file = mat_file.replace(ext, '.png')
mkdir(mat_file)
cv2.imwrite(mat_file, mat_image)
if __name__ == '__main__':
args = parse_args()
process(args)
from .humanseg import HumanSegMobile
from .humanseg import HumanSegServer
from .humanseg import HumanSegLite
from .humanseg import HRNet
from .load_model import load_model
此差异已折叠。
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os.path as osp
import six
import copy
from collections import OrderedDict
import paddle.fluid as fluid
import HumanSeg
import HumanSeg.utils.logging as logging
def load_model(model_dir):
if not osp.exists(osp.join(model_dir, "model.yml")):
raise Exception("There's not model.yml in {}".format(model_dir))
with open(osp.join(model_dir, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader)
status = info['status']
if not hasattr(HumanSeg.models, info['Model']):
raise Exception("There's no attribute {} in HumanSeg.models".format(
info['Model']))
model = getattr(HumanSeg.models, info['Model'])(**info['_init_params'])
if status == "Normal":
startup_prog = fluid.Program()
model.test_prog = fluid.Program()
with fluid.program_guard(model.test_prog, startup_prog):
with fluid.unique_name.guard():
model.test_inputs, model.test_outputs = model.build_net(
mode='test')
model.test_prog = model.test_prog.clone(for_test=True)
model.exe.run(startup_prog)
import pickle
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
load_dict = pickle.load(f)
fluid.io.set_program_state(model.test_prog, load_dict)
elif status in ['Infer', 'Quant']:
[prog, input_names, outputs] = fluid.io.load_inference_model(
model_dir, model.exe, params_filename='__params__')
model.test_prog = prog
test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
model.test_inputs = OrderedDict()
model.test_outputs = OrderedDict()
for name in input_names:
model.test_inputs[name] = model.test_prog.global_block().var(name)
for i, out in enumerate(outputs):
var_desc = test_outputs_info[i]
model.test_outputs[var_desc[0]] = out
if 'test_transforms' in info:
model.test_transforms = build_transforms(info['test_transforms'])
model.eval_transforms = copy.deepcopy(model.test_transforms)
if '_Attributes' in info:
for k, v in info['_Attributes'].items():
if k in model.__dict__:
model.__dict__[k] = v
logging.info("Model[{}] loaded.".format(info['Model']))
return model
def build_transforms(transforms_info):
import HumanSeg.transforms as T
transforms = list()
for op_info in transforms_info:
op_name = list(op_info.keys())[0]
op_attr = op_info[op_name]
if not hasattr(T, op_name):
raise Exception(
"There's no operator named '{}' in transforms".format(op_name))
transforms.append(getattr(T, op_name)(**op_attr))
eval_transforms = T.Compose(transforms)
return eval_transforms
from .backbone import mobilenet_v2
from .backbone import shufflenet_slim
from .backbone import xception
from .unet import UNet
from .deeplabv3p import DeepLabv3p
from .shufflenet_slim import ShuffleSeg
from .hrnet import HRNet
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from .seg_modules import softmax_with_loss
from .seg_modules import dice_loss
from .seg_modules import bce_loss
from .libs import sigmoid_to_softmax
class HRNet(object):
def __init__(self,
num_classes,
mode='train',
stage1_num_modules=1,
stage1_num_blocks=[4],
stage1_num_channels=[64],
stage2_num_modules=1,
stage2_num_blocks=[4, 4],
stage2_num_channels=[18, 36],
stage3_num_modules=4,
stage3_num_blocks=[4, 4, 4],
stage3_num_channels=[18, 36, 72],
stage4_num_modules=3,
stage4_num_blocks=[4, 4, 4, 4],
stage4_num_channels=[18, 36, 72, 144],
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
# dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.mode = mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
self.stage1_num_modules = stage1_num_modules
self.stage1_num_blocks = stage1_num_blocks
self.stage1_num_channels = stage1_num_channels
self.stage2_num_modules = stage2_num_modules
self.stage2_num_blocks = stage2_num_blocks
self.stage2_num_channels = stage2_num_channels
self.stage3_num_modules = stage3_num_modules
self.stage3_num_blocks = stage3_num_blocks
self.stage3_num_channels = stage3_num_channels
self.stage4_num_modules = stage4_num_modules
self.stage4_num_blocks = stage4_num_blocks
self.stage4_num_channels = stage4_num_channels
def build_net(self, inputs):
image = inputs['image']
logit = self._high_resolution_net(image, self.num_classes)
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if self.mode == 'train':
label = inputs['label']
mask = label != self.ignore_index
return self._get_loss(logit, label, mask)
else:
if self.num_classes == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = fluid.layers.softmax(logit, axis=1)
return pred, logit
return logit
def generate_inputs(self):
inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
elif self.mode == 'eval':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
def _get_loss(self, logit, label, mask):
avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss):
avg_loss += softmax_with_loss(
logit,
label,
mask,
num_classes=self.num_classes,
weight=self.class_weight,
ignore_index=self.ignore_index)
else:
if self.use_dice_loss:
avg_loss += dice_loss(logit, label, mask)
if self.use_bce_loss:
avg_loss += bce_loss(
logit, label, mask, ignore_index=self.ignore_index)
return avg_loss
def _conv_bn_layer(self,
input,
filter_size,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
act=None,
param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
bn = fluid.layers.relu(bn)
return bn
def _basic_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self._conv_bn_layer(
input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = self._conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample:
residual = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def _bottleneck_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
name=name + '_conv1')
conv = self._conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = self._conv_bn_layer(
input=conv,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample:
residual = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def _fuse_layers(self, x, channels, multi_scale_output=True, name=None):
out = []
for i in range(len(channels) if multi_scale_output else 1):
residual = x[i]
shape = fluid.layers.shape(residual)[-2:]
for j in range(len(channels)):
if j > i:
y = self._conv_bn_layer(
x[j],
filter_size=1,
num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(input=y, out_shape=shape)
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = x[j]
for k in range(i - j):
if k == i - j - 1:
y = self._conv_bn_layer(
y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
else:
y = self._conv_bn_layer(
y,
filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
residual = fluid.layers.relu(residual)
out.append(residual)
return out
def _branches(self, x, block_num, channels, name=None):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num[i]):
residual = self._basic_block(
residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1))
out.append(residual)
return out
def _high_resolution_module(self,
x,
blocks,
channels,
multi_scale_output=True,
name=None):
residual = self._branches(x, blocks, channels, name=name)
out = self._fuse_layers(
residual,
channels,
multi_scale_output=multi_scale_output,
name=name)
return out
def _transition_layer(self, x, in_channels, out_channels, name=None):
num_in = len(in_channels)
num_out = len(out_channels)
out = []
for i in range(num_out):
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = self._conv_bn_layer(
x[i],
filter_size=3,
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual)
else:
out.append(x[i])
else:
residual = self._conv_bn_layer(
x[-1],
filter_size=3,
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual)
return out
def _stage(self,
x,
num_modules,
num_blocks,
num_channels,
multi_scale_output=True,
name=None):
out = x
for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False:
out = self._high_resolution_module(
out,
num_blocks,
num_channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else:
out = self._high_resolution_module(
out, num_blocks, num_channels, name=name + '_' + str(i + 1))
return out
def _layer1(self, input, num_modules, num_blocks, num_channels, name=None):
# num_modules 默认为1,是否增加处理,官网实现为[1],是否对齐。
conv = input
for i in range(num_blocks[0]):
conv = self._bottleneck_block(
conv,
num_filters=num_channels[0],
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv
def _high_resolution_net(self, input, num_classes):
x = self._conv_bn_layer(
input=input,
filter_size=3,
num_filters=self.stage1_num_channels[0],
stride=2,
if_act=True,
name='layer1_1')
x = self._conv_bn_layer(
input=x,
filter_size=3,
num_filters=self.stage1_num_channels[0],
stride=2,
if_act=True,
name='layer1_2')
la1 = self._layer1(
x,
self.stage1_num_modules,
self.stage1_num_blocks,
self.stage1_num_channels,
name='layer2')
tr1 = self._transition_layer([la1],
self.stage1_num_channels,
self.stage2_num_channels,
name='tr1')
st2 = self._stage(
tr1,
self.stage2_num_modules,
self.stage2_num_blocks,
self.stage2_num_channels,
name='st2')
tr2 = self._transition_layer(
st2, self.stage2_num_channels, self.stage3_num_channels, name='tr2')
st3 = self._stage(
tr2,
self.stage3_num_modules,
self.stage3_num_blocks,
self.stage3_num_channels,
name='st3')
tr3 = self._transition_layer(
st3, self.stage3_num_channels, self.stage4_num_channels, name='tr3')
st4 = self._stage(
tr3,
self.stage4_num_modules,
self.stage4_num_blocks,
self.stage4_num_channels,
name='st4')
# upsample
shape = fluid.layers.shape(st4[0])[-2:]
st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=shape)
st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=shape)
st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=shape)
out = fluid.layers.concat(st4, axis=1)
last_channels = sum(self.stage4_num_channels)
out = self._conv_bn_layer(
input=out,
filter_size=1,
num_filters=last_channels,
stride=1,
if_act=True,
name='conv-2')
out = fluid.layers.conv2d(
input=out,
num_filters=num_classes,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(initializer=MSRA(), name='conv-1_weights'),
bias_attr=False)
input_shape = fluid.layers.shape(input)[-2:]
out = fluid.layers.resize_bilinear(out, input_shape)
return out
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from .libs import sigmoid_to_softmax
from .seg_modules import softmax_with_loss
from .seg_modules import dice_loss
from .seg_modules import bce_loss
class ShuffleSeg(object):
# def __init__(self):
# self.params = train_parameters
def __init__(self,
num_classes,
mode='train',
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
# dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.mode = mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
def _get_loss(self, logit, label, mask):
avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss):
avg_loss += softmax_with_loss(
logit,
label,
mask,
num_classes=self.num_classes,
weight=self.class_weight,
ignore_index=self.ignore_index)
else:
if self.use_dice_loss:
avg_loss += dice_loss(logit, label, mask)
if self.use_bce_loss:
avg_loss += bce_loss(
logit, label, mask, ignore_index=self.ignore_index)
return avg_loss
def generate_inputs(self):
inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
elif self.mode == 'eval':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
def build_net(self, inputs, class_dim=2):
if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1
image = inputs['image']
## Encoder
conv1 = self.conv_bn(image, 3, 36, 2, 1)
print('encoder 1', conv1.shape)
shortcut = self.conv_bn(
input=conv1, filter_size=1, num_filters=18, stride=1, padding=0)
print('shortcut 1', shortcut.shape)
pool = fluid.layers.pool2d(
input=conv1,
pool_size=3,
pool_type='max',
pool_stride=2,
pool_padding=1)
print('encoder 2', pool.shape)
# Block 1
conv = self.sfnetv2module(pool, stride=2, num_filters=72)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
print('encoder 3', conv.shape)
# Block 2
conv = self.sfnetv2module(conv, stride=2)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
conv = self.sfnetv2module(conv, stride=1)
print('encoder 4', conv.shape)
### decoder
conv = self.depthwise_separable(conv, 3, 64, 1)
shortcut_shape = fluid.layers.shape(shortcut)[2:]
conv_b = fluid.layers.resize_bilinear(conv, shortcut_shape)
concat = fluid.layers.concat([shortcut, conv_b], axis=1)
decode_conv = self.depthwise_separable(concat, 3, 64, 1)
logit = self.output_layer(decode_conv, class_dim)
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if self.mode == 'train':
label = inputs['label']
mask = label != self.ignore_index
return self._get_loss(logit, label, mask)
else:
if self.num_classes == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = fluid.layers.softmax(logit, axis=1)
return pred, logit
return logit
def conv_bn(self,
input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True):
parameter_attr = ParamAttr(learning_rate=1, initializer=MSRA())
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=parameter_attr,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def depthwise_separable(self, input, filter_size, num_filters, stride):
num_filters1 = int(input.shape[1])
num_groups = num_filters1
depthwise_conv = self.conv_bn(
input=input,
filter_size=filter_size,
num_filters=int(num_filters1),
stride=stride,
padding=int(filter_size / 2),
num_groups=num_groups,
use_cudnn=False,
act=None)
pointwise_conv = self.conv_bn(
input=depthwise_conv,
filter_size=1,
num_filters=num_filters,
stride=1,
padding=0)
return pointwise_conv
def sfnetv2module(self, input, stride, num_filters=None):
if stride == 1:
shortcut, branch = fluid.layers.split(
input, num_or_sections=2, dim=1)
if num_filters is None:
in_channels = int(branch.shape[1])
else:
in_channels = int(num_filters / 2)
else:
branch = input
if num_filters is None:
in_channels = int(branch.shape[1])
else:
in_channels = int(num_filters / 2)
shortcut = self.depthwise_separable(input, 3, in_channels, stride)
branch_1x1 = self.conv_bn(
input=branch,
filter_size=1,
num_filters=int(in_channels),
stride=1,
padding=0)
branch_dw1x1 = self.depthwise_separable(branch_1x1, 3, in_channels,
stride)
output = fluid.layers.concat(input=[shortcut, branch_dw1x1], axis=1)
# channel shuffle
# b, c, h, w = output.shape
shape = fluid.layers.shape(output)
c = output.shape[1]
b, h, w = shape[0], shape[2], shape[3]
output = fluid.layers.reshape(x=output, shape=[b, 2, in_channels, h, w])
output = fluid.layers.transpose(x=output, perm=[0, 2, 1, 3, 4])
output = fluid.layers.reshape(x=output, shape=[b, c, h, w])
return output
def output_layer(self, input, out_dim):
param_attr = fluid.param_attr.ParamAttr(
learning_rate=1.,
regularizer=fluid.regularizer.L2Decay(0.),
initializer=fluid.initializer.Xavier())
# deconv
output = fluid.layers.conv2d_transpose(
input=input,
num_filters=out_dim,
filter_size=2,
padding=0,
stride=2,
bias_attr=True,
param_attr=param_attr,
act=None)
return output
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import setuptools
long_descrition = 'HumanSeg'
setuptools.setup(
name='HumanSeg',
version='1.0.0',
author='paddleseg',
description=long_descrition,
long_descrition=long_descrition,
packages='./',
setup_requires=['cython', 'numpy'],
install_requires=['pyyaml', 'tqdm', 'visualdl==1.3.0', 'paddleslim==1.0.1'],
license='Apache 2.0')
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import transforms
from .transforms import *
from . import functional
......@@ -58,13 +58,14 @@ class Compose:
if im_info is None:
im_info = dict()
im = cv2.imread(im).astype('float32')
if isinstance(im, str):
im = cv2.imread(im).astype('float32')
if isinstance(label, str):
label = np.asarray(Image.open(label))
if im is None:
raise ValueError('Can\'t read The image file {}!'.format(im))
if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if label is not None:
label = np.asarray(Image.open(label))
for op in self.transforms:
outputs = op(im, im_info, label)
......
import HumanSeg
model = HumanSeg.models.load_model('output/best_model')
model.export_inference_model('output/export')
import HumanSeg
from HumanSeg.utils import visualize
im_file = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/pexel/img/person_detection__ds6/img/pexels-photo-704264.jpg'
model = HumanSeg.models.load_model('output/best_model')
result = model.predict(im_file)
visualize(im_file, result, save_dir='output/')
import os
import numpy as np
from HumanSeg.datasets.dataset import Dataset
from HumanSeg.models import HumanSegMobile
from HumanSeg.transforms import transforms
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize((192, 192)),
transforms.Normalize()
])
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
transforms.Normalize()])
data_dir = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly'
train_list = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/train.txt'
val_list = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/val.txt'
train_dataset = Dataset(
data_dir=data_dir,
file_list=train_list,
transforms=train_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=True)
eval_dataset = Dataset(
data_dir=data_dir,
file_list=val_list,
transforms=eval_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=False)
model = HumanSegMobile(num_classes=2)
model.train(
num_epochs=100,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
save_interval_epochs=5,
train_batch_size=256,
# resume_weights='/Users/chenguowei01/PycharmProjects/github/PaddleSeg/contrib/HumanSeg/output/epoch_20',
log_interval_steps=2,
save_dir='output',
use_vdl=True,
)
import HumanSeg
from HumanSeg.datasets.dataset import Dataset
from HumanSeg.transforms import transforms
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
transforms.Normalize()])
data_dir = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly'
val_list = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/val.txt'
eval_dataset = Dataset(
data_dir=data_dir,
file_list=val_list,
transforms=eval_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=False)
model = HumanSeg.models.load_model('output/best_model')
model.evaluate(eval_dataset, 2)
......@@ -16,3 +16,4 @@ from . import logging
from . import utils
from .metrics import ConfusionMatrix
from .utils import *
from .post_quantization import HumanSegPostTrainingQuantization
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import paddle.fluid as fluid
import os
import HumanSeg.utils.logging as logging
class HumanSegPostTrainingQuantization(PostTrainingQuantization):
def __init__(self,
executor,
dataset,
program,
inputs,
outputs,
batch_size=10,
batch_nums=None,
scope=None,
algo="KL",
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
'''
The class utilizes post training quantization methon to quantize the
fp32 model. It uses calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Args:
executor(fluid.Executor): The executor to load, run and save the
quantized model.
dataset(Python Iterator): The data Reader.
program(fluid.Program): The paddle program, save the parameters for model.
inputs(dict): The input of prigram.
outputs(dict): The output of program.
batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use
all data provided by sample_generator as calibrate data.
scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
algo(str, optional): If algo=KL, use KL-divergenc method to
get the more precise scale factor. If algo='direct', use
abs_max methon to get the scale factor. Default is KL.
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set
is_full_quantized as False, only apply quantization to the op type
according to the input quantizable_op_type.
is_use_cache_file(bool, optional): If set is_use_cache_file as False,
all temp data will be saved in memory. If set is_use_cache_file as True,
it will save temp data to disk. When the fp32 model is complex or
the number of calibrate data is large, we should set is_use_cache_file
as True. Defalut is False.
cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
the directory for saving temp data. Default is ./temp_post_training.
Returns:
None
'''
self._executor = executor
self._dataset = dataset
self._batch_size = batch_size
self._batch_nums = batch_nums
self._scope = fluid.global_scope() if scope == None else scope
self._algo = algo
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
if self._is_use_cache_file and not os.path.exists(self._cache_dir):
os.mkdir(self._cache_dir)
supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type + \
AddQuantDequantPass._supported_quantizable_op_type
if is_full_quantize:
self._quantizable_op_type = supported_quantizable_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in supported_quantizable_op_type + \
AddQuantDequantPass._activation_type, \
op_type + " is not supported for quantization."
self._place = self._executor.place
self._program = program
self._feed_list = list(inputs.values())
self._fetch_list = list(outputs.values())
self._data_loader = None
self._op_real_in_out_name = _op_real_in_out_name
self._bit_length = 8
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._sampling_data = {}
self._quantized_var_scale_factor = {}
def quantize(self):
'''
Quantize the fp32 model. Use calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Args:
None
Returns:
the program of quantized model.
'''
self._preprocess()
batch_id = 0
for data in self._data_loader():
self._executor.run(
program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data(batch_id)
if batch_id % 5 == 0:
logging.info("run batch: {}".format(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
logging.info("all run batch: ".format(batch_id))
logging.info("calculate scale factor ...")
self._calculate_scale_factor()
logging.info("update the program ...")
self._update_program()
self._save_output_scale()
return self._program
def save_quantized_model(self, save_model_path):
'''
Save the quantized model to the disk.
Args:
save_model_path(str): The path to save the quantized model
Returns:
None
'''
feed_vars_names = [var.name for var in self._feed_list]
fluid.io.save_inference_model(
dirname=save_model_path,
feeded_var_names=feed_vars_names,
target_vars=self._fetch_list,
executor=self._executor,
params_filename='__params__',
main_program=self._program)
def _preprocess(self):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
'''
feed_vars = [fluid.framework._get_var(var.name, self._program) \
for var in self._feed_list]
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_list_generator(
self._dataset.generator(self._batch_size, drop_last=True),
places=self._place)
# collect the variable names for sampling
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
for op in self._program.global_block().ops:
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type == "mul":
if self._is_input_all_not_persistable(
op, persistable_var_names):
op._set_attr("skip_quant", True)
logging.warning(
"Skip quant a mul op for two input variables are not persistable"
)
else:
self._quantized_act_var_name.add(op.input("X")[0])
self._quantized_weight_var_name.add(op.input("Y")[0])
self._quantized_act_var_name.add(op.output("Out")[0])
else:
# process other quantizable op type, the input must all not persistable
if self._is_input_all_not_persistable(
op, persistable_var_names):
input_output_name_list = self._op_real_in_out_name[
op_type]
for input_name in input_output_name_list[0]:
for var_name in op.input(input_name):
self._quantized_act_var_name.add(var_name)
for output_name in input_output_name_list[1]:
for var_name in op.output(output_name):
self._quantized_act_var_name.add(var_name)
# set activation variables to be persistable, so can obtain
# the tensor data in sample_data
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
......@@ -20,6 +20,7 @@ import numpy as np
import six
import yaml
import math
import cv2
from . import logging
......@@ -218,3 +219,58 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
len(vars_to_load), weights_dir))
if fuse_bn:
fuse_bn_weights(exe, main_prog, weights_dir)
def visualize(image, result, save_dir=None, weight=0.6):
"""
Convert segment result to color image, and save added image.
Args:
image: the path of origin image
result: the predict result of image
save_dir: the directory for saving visual image
weight: the image weight of visual image, and the result weight is (1 - weight)
"""
label_map = result['label_map']
color_map = get_color_map_list(256)
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(label_map, color_map[:, 0])
c2 = cv2.LUT(label_map, color_map[:, 1])
c3 = cv2.LUT(label_map, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3))
im = cv2.imread(image)
vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
if save_dir is not None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
image_name = os.path.split(image)[-1]
out_path = os.path.join(save_dir, image_name)
cv2.imwrite(out_path, vis_result)
else:
return vis_result
def get_color_map_list(num_classes):
""" Returns the color map for visualizing the segmentation mask,
which can support arbitrary number of classes.
Args:
num_classes: Number of classes
Returns:
The color map
"""
num_classes += 1
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
color_map = color_map[1:]
return color_map
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册