Unverified commit 8281cc6f authored by wuyefeilin, committed by GitHub

Add contrib HumanSeg (#229)

* update solver.py and model_builder.py

* update solver.py

* update infer.py

* update model_builder.py for fitting different aug_method

* update model_builder.py

* update model_builder.py

* update export process

* save best model

* update train.py

* update train.py

* update train.py

* eval for last epoch model

* check model path

* check TEST.TEST_MODEL

* reconstitution humanseg

* first refactor humanseg

* update humanseg.py and add main.py

* first add

* add shufflenet

* first add

* add hrnet

* add tutorial

* add humanseg infer

* add realtime humanseg

* update hrnet.py

* update humanseg.py

* rm cpp deploy

* add infer of resize_bylong

* update

* update infer.py

* update some

* update humanseg.py

* image_segment and video_segment add to class HumanSeg

* add quant

* resolve with_data_parallel problem

* add coeff params

* add demo

* update

* update __init__.py

* add setup.py
Parent 5b9a1db6
# HumanSeg
## Environment
Add the contrib directory to the PYTHONPATH environment variable; an equivalent in-script alternative is sketched below.
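A minimal sketch of the in-script alternative (the checkout path is a placeholder for your local PaddleSeg clone):
```python
import sys

# make the contrib packages (e.g. HumanSeg) importable without PYTHONPATH;
# the path below is a placeholder for your local checkout
sys.path.insert(0, '/path/to/PaddleSeg/contrib')

import HumanSeg  # should now resolve
```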
## For training, validation, prediction, and model export, see the [tutorial](./turtorial)
@@ -17,7 +17,7 @@ from . import nets
 from . import models
 from . import datasets
 from . import transforms
-from .utils.utils import get_environ_info
+from .utils import get_environ_info

 env_info = get_environ_info()
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .dataset import Dataset
# Real-Time Human Segmentation Inference Deployment
This model is based on the human segmentation model open-sourced by PaddlePaddle, with extensive optical-flow tracking optimizations for video. It provides a complete real-time human segmentation solution for video streams, together with a high-performance `Python` deployment option.
## Model Download
The supported model files are listed below; choose one according to your application scenario:

|Model file | Description |
| --- | --- |
|[shv75_deeplab_0303_quant](https://paddleseg.bj.bcebos.com/deploy/models/shv75_0303_quant.zip) | Small quantized model, suitable for lightweight computing environments |
|[shv75_deeplab_0303](https://paddleseg.bj.bcebos.com/deploy/models/shv75_deeplab_0303.zip)| Small model, suitable for lightweight computing environments |
|[deeplabv3_xception_humanseg](https://paddleseg.bj.bcebos.com/deploy/models/deeplabv3_xception_humanseg.zip) | For server-side GPU environments |

**Note: after downloading, unzip the archive to a suitable path; this path is later passed as a prediction argument to load the model.**
## Inference Deployment
- [Python inference deployment](./python)
## Preview
<figure class="half">
<img src="https://paddleseg.bj.bcebos.com/deploy/data/input.gif">
<img src="https://paddleseg.bj.bcebos.com/deploy/data/output.gif">
</figure>
# Real-Time Human Segmentation Python Deployment
This solution is implemented in Python, minimizes dependencies, and encapsulates model loading, data preprocessing, prediction, and post-processing such as optical-flow handling in the file `infer.py`, so you can use it directly or integrate it into your own project.
## Prerequisites
- Windows (7, 8, 10) / Linux (Ubuntu 16.04) / macOS 10.1+
- Paddle 1.7+
- Python 3.6+

Notes:
1. Only Paddle 1.7 and above is supported.
2. GPU prediction is not supported on macOS.

For other environments not covered above, if `Paddle` and `OpenCV` install normally, the solution should generally work as well.
## Install Dependencies
Run the following command:
```shell
pip install -r requirements.txt
```
## Run
1. Segment an input image:
```shell
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --img_path /PATH/TO/INPUT/IMAGE
```
The result is saved as `result.jpeg`.
2. Segment an input video:
```shell
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --video_path /PATH/TO/INPUT/VIDEO
```
The result is saved as `result.avi`.
3. Use a camera video stream:
```shell
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --use_camera True
```
Results are displayed in real time in a visualization window.
**Note:**
`GPU` is disabled by default. To accelerate inference with a `GPU`, first run
```shell
export CUDA_VISIBLE_DEVICES=0
```
and then add `--use_gpu True` to the prediction commands above.
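The `HumanSeg` class in `infer.py` can also be driven directly from Python; a minimal sketch (constructor values taken from `main()` in `infer.py`; the model path and image name are placeholders, and the deployment directory is assumed to be importable):
```python
import cv2
from infer import HumanSeg  # assumes this directory is on sys.path

# same defaults as main() in infer.py
model = HumanSeg(
    model_dir='/PATH/TO/INFERENCE/MODEL',
    mean=[0.5, 0.5, 0.5],
    scale=[0.5, 0.5, 0.5],
    eval_size=(192, 192),
    use_gpu=False)

frame = cv2.imread('input.jpg')      # any BGR image
result = model.run_predict(frame)    # person kept, background turned white
cv2.imwrite('output.jpg', result)
```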
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""实时人像分割Python预测部署"""
import os
import argparse
import numpy as np
import cv2
import paddle.fluid as fluid
def humanseg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
    """Compute optical-flow tracking matches and the tracked confidence map.

    Args:
        pre_gray: previous frame, grayscale
        cur_gray: current frame, grayscale
        prev_cfd: confidence map of the previous frame
        dl_weights: fusion weight map
        disflow: DIS optical-flow object
    Returns:
        track_cfd: tracked confidence map
        is_track: binary map of pixels with a valid optical-flow match
        dl_weights: updated fusion weight map
    """
check_thres = 8
hgt, wdh = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray)
    # forward optical flow
    flow_fw = disflow.calc(pre_gray, cur_gray, None)
    # backward optical flow
    flow_bw = disflow.calc(cur_gray, pre_gray, None)
    # round to the nearest integer, away from zero
    get_round = lambda data: int(data + 0.5) if data >= 0 else int(data - 0.5)
for row in range(hgt):
for col in range(wdh):
            # compute the flow-displaced coordinates of this point:
# (row, col) -> (cur_x, cur_y)
fxy_fw = flow_fw[row, col]
dx_fw = get_round(fxy_fw[0])
cur_x = dx_fw + col
dy_fw = get_round(fxy_fw[1])
cur_y = dy_fw + row
if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt:
continue
fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1])
            # forward-backward consistency: keep only points whose round-trip
            # flow displacement stays below the threshold
lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
(dx_fw + dx_bw) * (dx_fw + dx_bw))
if lmt >= check_thres:
continue
            # down-weight static points
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1
track_cfd[cur_y, cur_x] = prev_cfd[row, col]
return track_cfd, is_track, dl_weights
def humanseg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
    """Fuse the optical-flow tracking map with the segmentation result.

    Args:
        track_cfd: tracked confidence map
        dl_cfd: segmentation confidence of the current frame
        dl_weights: fusion weight map
        is_track: binary map of optical-flow matches
    Returns:
        cur_cfd: fusion of the tracking map and the segmentation result
    """
cur_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0)
    for i in range(len(idxs[0])):  # np.where returns (row_idxs, col_idxs)
        x, y = idxs[0][i], idxs[1][i]
        dl_score = dl_cfd[x, y]
        track_score = track_cfd[x, y]
if dl_score > 0.9 or dl_score < 0.1:
if dl_weights[x, y] < 0.1:
cur_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
else:
cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
else:
cur_cfd[x, y] = dl_weights[x, y] * dl_score + (
1 - dl_weights[x, y]) * track_score
return cur_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
    """Build a soft mask by applying background and foreground thresholds.

    Args:
        img: input image, np.uint8
        thresh_bg: background threshold ratio; values below it become 0
        thresh_fg: foreground threshold ratio; values above it become 1
    Returns:
        dst: thresholded mask, np.float32
    """
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
def optflow_handle(cur_gray, scoremap, is_init):
    """Optical-flow refinement.

    Args:
        cur_gray: current frame, grayscale
        scoremap: segmentation result of the current frame
        is_init: whether this is the first frame
    Returns:
        dst: fusion of the tracking map and the prediction, np.float32
    """
    # note: shape[0] is the height, shape[1] the width
    height, width = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((height, width), np.uint8)
prev_cfd = np.zeros((height, width), np.float32)
cur_cfd = scoremap.copy()
if is_init:
is_init = False
if height <= 64 or width <= 64:
disflow.setFinestScale(1)
elif height <= 160 or width <= 160:
disflow.setFinestScale(2)
else:
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
        weights = np.ones((height, width), np.float32) * 0.3
track_cfd, is_track, weights = humanseg_tracking(
prev_gray, cur_gray, prev_cfd, weights, disflow)
fusion_cfd = humanseg_track_fuse(track_cfd, cur_cfd, weights, is_track)
fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
return fusion_cfd
class HumanSeg:
    """Human segmentation model wrapper.

    Encapsulates model loading, preprocessing, prediction, and post-processing.
    """
def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
self.mean = np.array(mean).reshape((3, 1, 1))
self.scale = np.array(scale).reshape((3, 1, 1))
self.eval_size = eval_size
self.load_model(model_dir, use_gpu)
    def load_model(self, model_dir, use_gpu):
        """Load the model and create the predictor.

        Args:
            model_dir: inference model directory containing `__model__` and `__params__`
            use_gpu: whether to use GPU acceleration
        """
prog_file = os.path.join(model_dir, '__model__')
params_file = os.path.join(model_dir, '__params__')
config = fluid.core.AnalysisConfig(prog_file, params_file)
if use_gpu:
config.enable_use_gpu(100, 0)
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.disable_glog_info()
config.switch_specify_input_names(True)
config.enable_memory_optim()
self.predictor = fluid.core.create_paddle_predictor(config)
    def preprocess(self, image):
        """Preprocess the image.

        Resize to eval_size, convert HWC (BGR) to CHW, and normalize.
        Args:
            image: original image
        Returns:
            the preprocessed image
        """
img_mat = cv2.resize(
image, self.eval_size, interpolation=cv2.INTER_LINEAR)
# HWC -> CHW
img_mat = img_mat.swapaxes(1, 2)
img_mat = img_mat.swapaxes(0, 1)
# Convert to float
img_mat = img_mat[:, :, :].astype('float32')
img_mat = (img_mat / 255. - self.mean) / self.scale
img_mat = img_mat[np.newaxis, :, :, :]
return img_mat
    def postprocess(self, image, output_data):
        """Post-process the prediction.

        Args:
            image: original image, an OpenCV image object
            output_data: raw Paddle prediction output
        Returns:
            the original image blended with the optical-flow-refined prediction
        """
scoremap = output_data[0, 1, :, :]
scoremap = (scoremap * 255).astype(np.uint8)
ori_h, ori_w = image.shape[0], image.shape[1]
evl_h, evl_w = self.eval_size[0], self.eval_size[1]
        # optical-flow refinement
cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (evl_w, evl_h))
optflow_map = optflow_handle(cur_gray, scoremap, False)
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
optflow_map = cv2.resize(optflow_map, (ori_w, ori_h))
optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
bg_im = np.ones_like(optflow_map) * 255
comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(
np.uint8)
return comb
    def run_predict(self, image):
        """Run prediction and return the visualized result.

        Args:
            image: original image to predict, an OpenCV image object
        Returns:
            the visualized prediction
        """
im_mat = self.preprocess(image)
im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
output_data = self.predictor.run([im_tensor])[1]
output_data = output_data.as_ndarray()
return self.postprocess(image, output_data)
    def image_segment(self, path):
        """Segment an image file.

        The result is saved to `result.jpeg`.
        """
img_mat = cv2.imread(path)
img_mat = self.run_predict(img_mat)
cv2.imwrite('result.jpeg', img_mat)
    def video_segment(self, path=None):
        """Segment a video stream.

        When path is None, the camera is opened by default.
        """
        if path is None:
            cap = cv2.VideoCapture(0)
        else:
            cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise IOError("Error opening video stream or file")
if path is not None:
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
            # writer for saving the prediction video
out = cv2.VideoWriter('result.avi',
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps, (width, height))
            # start reading video frames
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
out.write(img_mat)
else:
break
cap.release()
out.release()
else:
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
cv2.imshow('HumanSegmentation', img_mat)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
cap.release()
def main(args):
    """Entry point for prediction.

    Loads the model and runs prediction on a video, camera stream, or image file.
    """
model_dir = args.model_dir
use_gpu = args.use_gpu
    # load the model
mean = [0.5, 0.5, 0.5]
scale = [0.5, 0.5, 0.5]
eval_size = (192, 192)
model = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
    if args.use_camera:
        # open the camera
        model.video_segment()
    elif args.video_path:
        # use a video file as input
        model.video_segment(args.video_path)
    elif args.img_path:
        # use an image file as input
        model.image_segment(args.img_path)
    else:
        raise ValueError(
            'One of (--img_path, --video_path, --use_camera) should be given.')
def parse_args():
    """Parse command-line arguments."""
parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument(
'--model_dir',
type=str,
default='',
help='path of human segmentation model')
parser.add_argument(
'--img_path', type=str, default='', help='path of input image')
parser.add_argument(
'--video_path', type=str, default='', help='path of input video')
    parser.add_argument(
        '--use_camera',
        type=bool,  # note: argparse parses any non-empty string as True
        default=False,
        help='input video stream from camera')
    parser.add_argument(
        '--use_gpu', type=bool, default=False, help='enable gpu')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""实时人像分割Python预测部署"""
import os
import argparse
import numpy as np
import cv2
import paddle.fluid as fluid
def humanseg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
    """Compute optical-flow tracking matches and the tracked confidence map.

    Args:
        pre_gray: previous frame, grayscale
        cur_gray: current frame, grayscale
        prev_cfd: confidence map of the previous frame
        dl_weights: fusion weight map
        disflow: DIS optical-flow object
    Returns:
        track_cfd: tracked confidence map
        is_track: binary map of pixels with a valid optical-flow match
        dl_weights: updated fusion weight map
    """
check_thres = 8
hgt, wdh = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray)
    # forward optical flow
    flow_fw = disflow.calc(pre_gray, cur_gray, None)
    # backward optical flow
    flow_bw = disflow.calc(cur_gray, pre_gray, None)
    # round to the nearest integer, away from zero
    get_round = lambda data: int(data + 0.5) if data >= 0 else int(data - 0.5)
for row in range(hgt):
for col in range(wdh):
            # compute the flow-displaced coordinates of this point:
# (row, col) -> (cur_x, cur_y)
fxy_fw = flow_fw[row, col]
dx_fw = get_round(fxy_fw[0])
cur_x = dx_fw + col
dy_fw = get_round(fxy_fw[1])
cur_y = dy_fw + row
if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt:
continue
fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1])
            # forward-backward consistency: keep only points whose round-trip
            # flow displacement stays below the threshold
lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
(dx_fw + dx_bw) * (dx_fw + dx_bw))
if lmt >= check_thres:
continue
            # down-weight static points
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1
track_cfd[cur_y, cur_x] = prev_cfd[row, col]
return track_cfd, is_track, dl_weights
def humanseg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
    """Fuse the optical-flow tracking map with the segmentation result.

    Args:
        track_cfd: tracked confidence map
        dl_cfd: segmentation confidence of the current frame
        dl_weights: fusion weight map
        is_track: binary map of optical-flow matches
    Returns:
        cur_cfd: fusion of the tracking map and the segmentation result
    """
cur_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0)
    for i in range(len(idxs[0])):  # np.where returns (row_idxs, col_idxs)
        x, y = idxs[0][i], idxs[1][i]
        dl_score = dl_cfd[x, y]
        track_score = track_cfd[x, y]
if dl_score > 0.9 or dl_score < 0.1:
if dl_weights[x, y] < 0.1:
cur_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
else:
cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
else:
cur_cfd[x, y] = dl_weights[x, y] * dl_score + (
1 - dl_weights[x, y]) * track_score
return cur_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
    """Build a soft mask by applying background and foreground thresholds.

    Args:
        img: input image, np.uint8
        thresh_bg: background threshold ratio; values below it become 0
        thresh_fg: foreground threshold ratio; values above it become 1
    Returns:
        dst: thresholded mask, np.float32
    """
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
def optflow_handle(cur_gray, scoremap, is_init):
    """Optical-flow refinement.

    Args:
        cur_gray: current frame, grayscale
        scoremap: segmentation result of the current frame
        is_init: whether this is the first frame
    Returns:
        dst: fusion of the tracking map and the prediction, np.float32
    """
    # note: shape[0] is the height, shape[1] the width
    height, width = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((height, width), np.uint8)
prev_cfd = np.zeros((height, width), np.float32)
cur_cfd = scoremap.copy()
if is_init:
is_init = False
if height <= 64 or width <= 64:
disflow.setFinestScale(1)
elif height <= 160 or width <= 160:
disflow.setFinestScale(2)
else:
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
        weights = np.ones((height, width), np.float32) * 0.3
track_cfd, is_track, weights = humanseg_tracking(
prev_gray, cur_gray, prev_cfd, weights, disflow)
fusion_cfd = humanseg_track_fuse(track_cfd, cur_cfd, weights, is_track)
fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
return fusion_cfd
class HumanSeg:
    """Human segmentation model wrapper.

    Encapsulates model loading, preprocessing, prediction, and post-processing.
    """
def __init__(self, model_dir, mean, scale, long_size, use_gpu=False):
self.mean = np.array(mean).reshape((3, 1, 1))
self.scale = np.array(scale).reshape((3, 1, 1))
self.long_size = long_size
self.load_model(model_dir, use_gpu)
    def load_model(self, model_dir, use_gpu):
        """Load the model and create the predictor.

        Args:
            model_dir: inference model directory containing `__model__` and `__params__`
            use_gpu: whether to use GPU acceleration
        """
prog_file = os.path.join(model_dir, '__model__')
params_file = os.path.join(model_dir, '__params__')
config = fluid.core.AnalysisConfig(prog_file, params_file)
if use_gpu:
config.enable_use_gpu(100, 0)
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.disable_glog_info()
config.switch_specify_input_names(True)
config.enable_memory_optim()
self.predictor = fluid.core.create_paddle_predictor(config)
    def preprocess(self, image):
        """Preprocess the image.

        Resize so the long side equals long_size, pad to a square,
        convert HWC (BGR) to CHW, and normalize.
        Args:
            image: original image
        Returns:
            the preprocessed image
        """
origin_h, origin_w = image.shape[0], image.shape[1]
scale = float(self.long_size) / max(origin_w, origin_h)
resize_w = int(round(origin_w * scale))
resize_h = int(round(origin_h * scale))
img_mat = cv2.resize(
image, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
pad_h = self.long_size - resize_h
pad_w = self.long_size - resize_w
img_mat = cv2.copyMakeBorder(
img_mat,
0,
pad_h,
0,
pad_w,
cv2.BORDER_CONSTANT,
value=[127.5, 127.5, 127.5])
# HWC -> CHW
img_mat = img_mat.swapaxes(1, 2)
img_mat = img_mat.swapaxes(0, 1)
# Convert to float
img_mat = img_mat[:, :, :].astype('float32')
img_mat = (img_mat / 255. - self.mean) / self.scale
img_mat = img_mat[np.newaxis, :, :, :]
return img_mat
    def postprocess(self, image, output_data):
        """Post-process the prediction.

        Args:
            image: original image, an OpenCV image object
            output_data: raw Paddle prediction output
        Returns:
            the original image blended with the optical-flow-refined prediction
        """
scoremap = output_data[0, 1, :, :]
scoremap = (scoremap * 255).astype(np.uint8)
        # optical-flow refinement
cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
origin_h, origin_w = image.shape[0], image.shape[1]
scale = float(self.long_size) / max(origin_w, origin_h)
resize_w = int(round(origin_w * scale))
resize_h = int(round(origin_h * scale))
cur_gray = cv2.resize(
cur_gray, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR)
pad_h = self.long_size - resize_h
pad_w = self.long_size - resize_w
cur_gray = cv2.copyMakeBorder(
cur_gray, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=127.5)
optflow_map = optflow_handle(cur_gray, scoremap, False)
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
optflow_map = optflow_map[0:resize_h, 0:resize_w]
optflow_map = cv2.resize(optflow_map, (origin_w, origin_h))
optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
bg_im = np.ones_like(optflow_map) * 255
comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(
np.uint8)
return comb
    def run_predict(self, image):
        """Run prediction and return the visualized result.

        Args:
            image: original image to predict, an OpenCV image object
        Returns:
            the visualized prediction
        """
im_mat = self.preprocess(image)
im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
output_data = self.predictor.run([im_tensor])[1]
output_data = output_data.as_ndarray()
return self.postprocess(image, output_data)
    def image_segment(self, path):
        """Segment an image file.

        The result is saved to `result.jpeg`.
        """
img_mat = cv2.imread(path)
img_mat = self.run_predict(img_mat)
cv2.imwrite('result.jpeg', img_mat)
    def video_segment(self, path=None):
        """Segment a video stream.

        When path is None, the camera is opened by default.
        """
        if path is None:
            cap = cv2.VideoCapture(0)
        else:
            cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise IOError("Error opening video stream or file")
if path is not None:
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
            # writer for saving the prediction video
out = cv2.VideoWriter('result.avi',
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps, (width, height))
            # start reading video frames
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
out.write(img_mat)
else:
break
cap.release()
out.release()
else:
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
cv2.imshow('HumanSegmentation', img_mat)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
cap.release()
def main(args):
    """Entry point for prediction.

    Loads the model and runs prediction on a video, camera stream, or image file.
    """
model_dir = args.model_dir
use_gpu = args.use_gpu
    # load the model
mean = [0.5, 0.5, 0.5]
scale = [0.5, 0.5, 0.5]
long_size = 192
model = HumanSeg(model_dir, mean, scale, long_size, use_gpu)
    if args.use_camera:
        # open the camera
        model.video_segment()
    elif args.video_path:
        # use a video file as input
        model.video_segment(args.video_path)
    elif args.img_path:
        # use an image file as input
        model.image_segment(args.img_path)
    else:
        raise ValueError(
            'One of (--img_path, --video_path, --use_camera) should be given.')
def parse_args():
    """Parse command-line arguments."""
parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument(
'--model_dir',
type=str,
default='',
help='path of human segmentation model')
parser.add_argument(
'--img_path', type=str, default='', help='path of input image')
parser.add_argument(
'--video_path', type=str, default='', help='path of input video')
    parser.add_argument(
        '--use_camera',
        type=bool,  # note: argparse parses any non-empty string as True
        default=False,
        help='input video stream from camera')
    parser.add_argument(
        '--use_gpu', type=bool, default=False, help='enable gpu')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
opencv-python==4.1.2.30
opencv-contrib-python==4.2.0.32
import argparse
import os
import os.path as osp
import cv2
import numpy as np
import tqdm
import HumanSeg
def parse_args():
parser = argparse.ArgumentParser(
description='HumanSeg inference and visualization')
parser.add_argument(
'--test_model',
dest='test_model',
help='model path for inference',
type=str)
parser.add_argument(
'--data_dir',
dest='data_dir',
help='the root directory of dataset',
type=str)
parser.add_argument(
'--file_list', dest='file_list', help='file list for test', type=str)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='the directory for saveing the inferenc results',
type=str,
default='./result')
return parser.parse_args()
def mkdir(path):
sub_dir = osp.dirname(path)
if not osp.exists(sub_dir):
os.makedirs(sub_dir)
def process(args):
model = HumanSeg.models.load_model(args.test_model)
    added_saved_path = osp.join(args.save_dir, 'added')
mat_saved_path = osp.join(args.save_dir, 'mat')
scoremap_saved_path = osp.join(args.save_dir, 'scoremap')
with open(args.file_list, 'r') as f:
files = f.readlines()
for file in tqdm.tqdm(files):
file = file.strip()
im_file = osp.join(args.data_dir, file)
im = cv2.imread(im_file)
result = model.predict(im)
# save added image
added_image = HumanSeg.utils.visualize(im_file, result, weight=0.6)
        added_image_file = osp.join(added_saved_path, file)
mkdir(added_image_file)
cv2.imwrite(added_image_file, added_image)
# save score map
score_map = result['score_map'][:, :, 1]
score_map = (score_map * 255).astype(np.uint8)
score_map_file = osp.join(scoremap_saved_path, file)
mkdir(score_map_file)
cv2.imwrite(score_map_file, score_map)
# save mat image
score_map = np.expand_dims(score_map, axis=-1)
mat_image = np.concatenate([im, score_map], axis=2)
mat_file = osp.join(mat_saved_path, file)
ext = osp.splitext(mat_file)[-1]
mat_file = mat_file.replace(ext, '.png')
mkdir(mat_file)
cv2.imwrite(mat_file, mat_image)
if __name__ == '__main__':
args = parse_args()
process(args)
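# Example invocation of this script (a sketch; the script name and paths are
# placeholders, flags are those defined in parse_args above):
#   python infer.py --test_model ./output/best_model --data_dir ./data \
#       --file_list ./data/test_list.txt --save_dir ./result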
 from .humanseg import HumanSegMobile
+from .humanseg import HumanSegServer
+from .humanseg import HumanSegLite
+from .humanseg import HRNet
+from .load_model import load_model
@@ -24,6 +24,7 @@ import time
 import tqdm
 import cv2
 import yaml
+import paddleslim as slim
 import HumanSeg
 import HumanSeg.utils.logging as logging
@@ -42,16 +43,10 @@ def dict2str(dict_input):
     return out.strip(', ')

-class HumanSegMobile(object):
-    # DeepLab mobilenet
+class SegModel(object):
     def __init__(self,
                  num_classes=2,
-                 backbone='MobileNetV2_x1.0',
-                 output_stride=16,
-                 aspp_with_sep_conv=True,
-                 decoder_use_sep_conv=True,
-                 encoder_with_aspp=False,
-                 enable_decoder=False,
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
@@ -63,19 +58,6 @@ class HumanSegMobile(object):
                 "dice loss and bce loss is only applicable to binary classification"
             )
-        self.output_stride = output_stride
-
-        if backbone not in [
-                'Xception65', 'Xception41', 'MobileNetV2_x0.25',
-                'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
-                'MobileNetV2_x2.0'
-        ]:
-            raise ValueError(
-                "backbone: {} is set wrong. it should be one of "
-                "('Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',"
-                " 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0')".
-                format(backbone))
         if class_weight is not None:
             if isinstance(class_weight, list):
                 if len(class_weight) != num_classes:
@@ -91,16 +73,11 @@ class HumanSegMobile(object):
                     'Expect class_weight is a list or string but receive {}'.
                     format(type(class_weight)))
-        self.backbone = backbone
         self.num_classes = num_classes
         self.use_bce_loss = use_bce_loss
         self.use_dice_loss = use_dice_loss
         self.class_weight = class_weight
         self.ignore_index = ignore_index
-        self.aspp_with_sep_conv = aspp_with_sep_conv
-        self.decoder_use_sep_conv = decoder_use_sep_conv
-        self.encoder_with_aspp = encoder_with_aspp
-        self.enable_decoder = enable_decoder
         self.sync_bn = sync_bn
         self.labels = None
@@ -118,6 +95,9 @@ class HumanSegMobile(object):
         self.train_outputs = None
         self.test_outputs = None
         self.train_data_loader = None
+        self.eval_metrics = None
+        # current model status
+        self.status = 'Normal'

     def _get_single_car_bs(self, batch_size):
         if batch_size % len(self.places) == 0:
@@ -130,29 +110,7 @@ class HumanSegMobile(object):
     def build_net(self, mode='train'):
         """Build the network as appropriate for each situation."""
-        model = HumanSeg.nets.DeepLabv3p(
-            self.num_classes,
-            mode=mode,
-            backbone=self.backbone,
-            output_stride=self.output_stride,
-            aspp_with_sep_conv=self.aspp_with_sep_conv,
-            decoder_use_sep_conv=self.decoder_use_sep_conv,
-            encoder_with_aspp=self.encoder_with_aspp,
-            enable_decoder=self.enable_decoder,
-            use_bce_loss=self.use_bce_loss,
-            use_dice_loss=self.use_dice_loss,
-            class_weight=self.class_weight,
-            ignore_index=self.ignore_index)
-        inputs = model.generate_inputs()
-        model_out = model.build_net(inputs)
-        outputs = OrderedDict()
-        if mode == 'train':
-            self.optimizer.minimize(model_out)
-            outputs['loss'] = model_out
-        else:
-            outputs['pred'] = model_out[0]
-            outputs['logit'] = model_out[1]
-        return inputs, outputs
+        pass

     def build_program(self):
         # build training network
@@ -276,27 +234,134 @@ class HumanSegMobile(object):
         if osp.exists(save_dir):
             os.remove(save_dir)
         os.makedirs(save_dir)
-        fluid.save(self.train_prog, osp.join(save_dir, 'model'))
         model_info = self.get_model_info()
+        if self.status == 'Normal':
+            fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+        elif self.status == 'Quant':
+            float_prog, _ = slim.quant.convert(
+                self.test_prog, self.exe.place, save_int8=True)
+            test_input_names = [
+                var.name for var in list(self.test_inputs.values())
+            ]
+            test_outputs = list(self.test_outputs.values())
+            fluid.io.save_inference_model(
+                dirname=save_dir,
+                executor=self.exe,
+                params_filename='__params__',
+                feeded_var_names=test_input_names,
+                target_vars=test_outputs,
+                main_program=float_prog)
+            model_info['_ModelInputsOutputs'] = dict()
+            model_info['_ModelInputsOutputs']['test_inputs'] = [
+                [k, v.name] for k, v in self.test_inputs.items()
+            ]
+            model_info['_ModelInputsOutputs']['test_outputs'] = [
+                [k, v.name] for k, v in self.test_outputs.items()
+            ]
+        model_info['status'] = self.status
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
                 mode='w') as f:
             yaml.dump(model_info, f)
+        # The flag of model for saving successfully
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("Model saved in {}.".format(save_dir))

     def export_inference_model(self, save_dir):
-        pass
+        test_input_names = [var.name for var in list(self.test_inputs.values())]
+        test_outputs = list(self.test_outputs.values())
+        fluid.io.save_inference_model(
+            dirname=save_dir,
+            executor=self.exe,
+            params_filename='__params__',
+            feeded_var_names=test_input_names,
+            target_vars=test_outputs,
+            main_program=self.test_prog)
+        model_info = self.get_model_info()
+        model_info['status'] = 'Infer'
+        # Save input and output description of the model
+        model_info['_ModelInputsOutputs'] = dict()
+        model_info['_ModelInputsOutputs']['test_inputs'] = [
+            [k, v.name] for k, v in self.test_inputs.items()
+        ]
+        model_info['_ModelInputsOutputs']['test_outputs'] = [
+            [k, v.name] for k, v in self.test_outputs.items()
+        ]
+        with open(
+                osp.join(save_dir, 'model.yml'), encoding='utf-8',
+                mode='w') as f:
+            yaml.dump(model_info, f)
+        # The flag of model for saving successfully
+        open(osp.join(save_dir, '.success'), 'w').close()
+        logging.info("Model for inference deploy saved in {}.".format(save_dir))

-    def export_quant_model(self):
-        pass
+    def export_quant_model(self,
+                           dataset,
+                           save_dir,
+                           batch_size=1,
+                           batch_nums=10,
+                           cache_dir="./temp"):
+        self.arrange_transform(transforms=dataset.transforms, mode='quant')
+        dataset.num_samples = batch_size * batch_nums
+        try:
+            from HumanSeg.utils import HumanSegPostTrainingQuantization
+        except:
+            raise Exception(
+                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
+            )
+        is_use_cache_file = True
+        if cache_dir is None:
+            is_use_cache_file = False
+        post_training_quantization = HumanSegPostTrainingQuantization(
+            executor=self.exe,
+            dataset=dataset,
+            program=self.test_prog,
+            inputs=self.test_inputs,
+            outputs=self.test_outputs,
+            batch_size=batch_size,
+            batch_nums=batch_nums,
+            scope=None,
+            algo='KL',
+            quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+            is_full_quantize=False,
+            is_use_cache_file=is_use_cache_file,
+            cache_dir=cache_dir)
+        post_training_quantization.quantize()
+        post_training_quantization.save_quantized_model(save_dir)
+        model_info = self.get_model_info()
+        model_info['status'] = 'Quant'
+        # Save input and output description of the model
+        model_info['_ModelInputsOutputs'] = dict()
+        model_info['_ModelInputsOutputs']['test_inputs'] = [
+            [k, v.name] for k, v in self.test_inputs.items()
+        ]
+        model_info['_ModelInputsOutputs']['test_outputs'] = [
+            [k, v.name] for k, v in self.test_outputs.items()
+        ]
+        with open(
+                osp.join(save_dir, 'model.yml'), encoding='utf-8',
+                mode='w') as f:
+            yaml.dump(model_info, f)
+        # The flag of model for saving successfully
+        open(osp.join(save_dir, '.success'), 'w').close()
+        logging.info("Model for quant saved in {}.".format(save_dir))

     def default_optimizer(self,
                           learning_rate,
                           num_epochs,
                           num_steps_each_epoch,
-                          lr_decay_power=0.9):
+                          lr_decay_power=0.9,
+                          regularization_coeff=4e-5):
         decay_step = num_epochs * num_steps_each_epoch
         lr_decay = fluid.layers.polynomial_decay(
             learning_rate,
@@ -307,7 +372,7 @@ class HumanSegMobile(object):
             lr_decay,
             momentum=0.9,
             regularization=fluid.regularizer.L2Decay(
-                regularization_coeff=4e-05))
+                regularization_coeff=regularization_coeff))
         return optimizer

     def train(self,
@@ -323,7 +388,9 @@ class HumanSegMobile(object):
               optimizer=None,
               learning_rate=0.01,
               lr_decay_power=0.9,
-              use_vdl=False):
+              regularization_coeff=4e-5,
+              use_vdl=False,
+              quant=False):
         self.labels = train_dataset.labels
         self.train_transforms = train_dataset.transforms
         self.train_init = locals()
@@ -335,13 +402,27 @@ class HumanSegMobile(object):
                 learning_rate=learning_rate,
                 num_epochs=num_epochs,
                 num_steps_each_epoch=num_steps_each_epoch,
-                lr_decay_power=lr_decay_power)
+                lr_decay_power=lr_decay_power,
+                regularization_coeff=regularization_coeff)
         self.optimizer = optimizer
         self.build_program()
         self.net_initialize(
             startup_prog=fluid.default_startup_program(),
             pretrain_weights=pretrain_weights,
             resume_weights=resume_weights)
+        # quantization-aware training
+        if quant:
+            # with for_test=False the return type is fluid.CompiledProgram
+            # with for_test=True the return type is fluid.Program
+            self.train_prog = slim.quant.quant_aware(
+                self.train_prog, self.exe.place, for_test=False)
+            self.test_prog = slim.quant.quant_aware(
+                self.test_prog, self.exe.place, for_test=True)
+            # self.parallel_train_prog = self.train_prog.with_data_parallel(
+            #     loss_name=self.train_outputs['loss'].name)
+            self.status = 'Quant'
         if self.begin_epoch >= num_epochs:
             raise ValueError(
                 ("begin epoch[{}] is larger than num_epochs[{}]").format(
@@ -374,11 +455,19 @@ class HumanSegMobile(object):
         build_strategy.sync_batch_norm = self.sync_bn
         exec_strategy = fluid.ExecutionStrategy()
         exec_strategy.num_iteration_per_drop_scope = 1
-        self.parallel_train_prog = fluid.CompiledProgram(
-            self.train_prog).with_data_parallel(
+        if quant:
+            build_strategy.fuse_all_reduce_ops = False
+            build_strategy.sync_batch_norm = False
+            self.parallel_train_prog = self.train_prog.with_data_parallel(
                 loss_name=self.train_outputs['loss'].name,
                 build_strategy=build_strategy,
                 exec_strategy=exec_strategy)
+        else:
+            self.parallel_train_prog = fluid.CompiledProgram(
+                self.train_prog).with_data_parallel(
+                    loss_name=self.train_outputs['loss'].name,
+                    build_strategy=build_strategy,
+                    exec_strategy=exec_strategy)
         total_num_steps = math.floor(
             train_dataset.num_samples / train_batch_size)
@@ -478,8 +567,6 @@ class HumanSegMobile(object):
                         eval_dataset=eval_dataset,
                         batch_size=eval_batch_size,
                         epoch_id=i + 1)
-                    logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
-                        i + 1, dict2str(self.eval_metrics)))
                     # save the best model
                     current_miou = self.eval_metrics['miou']
                     if current_miou > best_miou:
@@ -567,18 +654,25 @@ class HumanSegMobile(object):
             zip(['miou', 'category_iou', 'macc', 'category_acc', 'kappa'],
                 [miou, category_iou, macc, category_acc,
                  conf_mat.kappa()]))
+        logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
+            epoch_id, dict2str(metrics)))
         return metrics

     def predict(self, im_file, transforms=None):
         """Predict.
         Args:
-            img_file(str): path of the image to predict
+            img_file(str|np.ndarray): the image to predict
             transforms(paddlex.cv.transforms): data preprocessing transforms
         Returns:
             dict: with keys 'label_map' and 'score_map'; 'label_map' stores the predicted
                 grayscale label map whose pixel values are class ids, and 'score_map'
                 stores the per-class probabilities, shape=(h, w, num_classes)
         """
+        if isinstance(im_file, str):
+            if not osp.exists(im_file):
+                raise ValueError(
+                    'The Image file does not exist: {}'.format(im_file))
         if transforms is None and not hasattr(self, 'test_transforms'):
             raise Exception("transforms need to be defined, now is None.")
@@ -612,18 +706,231 @@ class HumanSegMobile(object):
         return {'label_map': pred, 'score_map': logit}

-class HumanSegLite(object):
-    # DeepLab ShuffleNet
-    def train(self):
-        pass
-
-    def evaluate(self):
-        pass
-
-    def predict(self):
-        pass
+class HumanSegMobile(SegModel):
+    # DeepLab mobilenet
+    def __init__(self,
+                 num_classes=2,
+                 backbone='MobileNetV2_x1.0',
+                 output_stride=16,
+                 aspp_with_sep_conv=True,
+                 decoder_use_sep_conv=True,
+                 encoder_with_aspp=False,
+                 enable_decoder=False,
+                 use_bce_loss=False,
+                 use_dice_loss=False,
+                 class_weight=None,
+                 ignore_index=255,
+                 sync_bn=True):
+        super().__init__(
+            num_classes=num_classes,
+            use_bce_loss=use_bce_loss,
+            use_dice_loss=use_dice_loss,
+            class_weight=class_weight,
+            ignore_index=ignore_index,
+            sync_bn=sync_bn)
+        self.init_params = locals()
+
+        self.output_stride = output_stride
+
+        if backbone not in [
+                'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0',
+                'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
+        ]:
+            raise ValueError(
+                "backbone: {} is set wrong. it should be one of "
+                "('MobileNetV2_x0.25', 'MobileNetV2_x0.5',"
+                " 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0')".
+                format(backbone))
+        self.backbone = backbone
+        self.aspp_with_sep_conv = aspp_with_sep_conv
+        self.decoder_use_sep_conv = decoder_use_sep_conv
+        self.encoder_with_aspp = encoder_with_aspp
+        self.enable_decoder = enable_decoder
+        self.sync_bn = sync_bn
+
+    def build_net(self, mode='train'):
+        model = HumanSeg.nets.DeepLabv3p(
+            self.num_classes,
+            mode=mode,
+            backbone=self.backbone,
+            output_stride=self.output_stride,
+            aspp_with_sep_conv=self.aspp_with_sep_conv,
+            decoder_use_sep_conv=self.decoder_use_sep_conv,
+            encoder_with_aspp=self.encoder_with_aspp,
+            enable_decoder=self.enable_decoder,
+            use_bce_loss=self.use_bce_loss,
+            use_dice_loss=self.use_dice_loss,
+            class_weight=self.class_weight,
+            ignore_index=self.ignore_index)
+        inputs = model.generate_inputs()
+        model_out = model.build_net(inputs)
+        outputs = OrderedDict()
+        if mode == 'train':
+            self.optimizer.minimize(model_out)
+            outputs['loss'] = model_out
+        else:
+            outputs['pred'] = model_out[0]
+            outputs['logit'] = model_out[1]
+        return inputs, outputs

-class HumanSegServer(object):
+class HumanSegLite(SegModel):
+    # DeepLab ShuffleNet
+    def build_net(self, mode='train'):
+        """Build the network as appropriate for each situation."""
+        model = HumanSeg.nets.ShuffleSeg(
+            self.num_classes,
+            mode=mode,
+            use_bce_loss=self.use_bce_loss,
+            use_dice_loss=self.use_dice_loss,
+            class_weight=self.class_weight,
+            ignore_index=self.ignore_index)
+        inputs = model.generate_inputs()
+        model_out = model.build_net(inputs)
+        outputs = OrderedDict()
+        if mode == 'train':
+            self.optimizer.minimize(model_out)
+            outputs['loss'] = model_out
+        else:
+            outputs['pred'] = model_out[0]
+            outputs['logit'] = model_out[1]
+        return inputs, outputs

+class HumanSegServer(SegModel):
     # DeepLab Xception
-    pass
+    def __init__(self,
+                 num_classes=2,
+                 backbone='Xception65',
+                 output_stride=16,
+                 aspp_with_sep_conv=True,
+                 decoder_use_sep_conv=True,
+                 encoder_with_aspp=True,
+                 enable_decoder=True,
+                 use_bce_loss=False,
+                 use_dice_loss=False,
+                 class_weight=None,
+                 ignore_index=255,
+                 sync_bn=True):
+        super().__init__(
+            num_classes=num_classes,
+            use_bce_loss=use_bce_loss,
+            use_dice_loss=use_dice_loss,
+            class_weight=class_weight,
+            ignore_index=ignore_index,
+            sync_bn=sync_bn)
+        self.init_params = locals()
+        self.output_stride = output_stride
+        if backbone not in ['Xception65', 'Xception41']:
+            raise ValueError("backbone: {} is set wrong. it should be one of "
+                             "('Xception65', 'Xception41')".format(backbone))
+        self.backbone = backbone
+        self.aspp_with_sep_conv = aspp_with_sep_conv
+        self.decoder_use_sep_conv = decoder_use_sep_conv
+        self.encoder_with_aspp = encoder_with_aspp
+        self.enable_decoder = enable_decoder
+        self.sync_bn = sync_bn
+
+    def build_net(self, mode='train'):
+        model = HumanSeg.nets.DeepLabv3p(
+            self.num_classes,
+            mode=mode,
+            backbone=self.backbone,
+            output_stride=self.output_stride,
+            aspp_with_sep_conv=self.aspp_with_sep_conv,
+            decoder_use_sep_conv=self.decoder_use_sep_conv,
+            encoder_with_aspp=self.encoder_with_aspp,
+            enable_decoder=self.enable_decoder,
+            use_bce_loss=self.use_bce_loss,
+            use_dice_loss=self.use_dice_loss,
+            class_weight=self.class_weight,
+            ignore_index=self.ignore_index)
+        inputs = model.generate_inputs()
+        model_out = model.build_net(inputs)
+        outputs = OrderedDict()
+        if mode == 'train':
+            self.optimizer.minimize(model_out)
+            outputs['loss'] = model_out
+        else:
+            outputs['pred'] = model_out[0]
+            outputs['logit'] = model_out[1]
+        return inputs, outputs
+
+class HRNet(SegModel):
+    def __init__(self,
+                 num_classes=2,
+                 stage1_num_modules=1,
+                 stage1_num_blocks=[4],
+                 stage1_num_channels=[64],
+                 stage2_num_modules=1,
+                 stage2_num_blocks=[4, 4],
+                 stage2_num_channels=[18, 36],
+                 stage3_num_modules=4,
+                 stage3_num_blocks=[4, 4, 4],
+                 stage3_num_channels=[18, 36, 72],
+                 stage4_num_modules=3,
+                 stage4_num_blocks=[4, 4, 4, 4],
+                 stage4_num_channels=[18, 36, 72, 144],
+                 use_bce_loss=False,
+                 use_dice_loss=False,
+                 class_weight=None,
+                 ignore_index=255,
+                 sync_bn=True):
+        super().__init__(
+            num_classes=num_classes,
+            use_bce_loss=use_bce_loss,
+            use_dice_loss=use_dice_loss,
+            class_weight=class_weight,
+            ignore_index=ignore_index,
+            sync_bn=sync_bn)
+        self.init_params = locals()
+        self.stage1_num_modules = stage1_num_modules
+        self.stage1_num_blocks = stage1_num_blocks
+        self.stage1_num_channels = stage1_num_channels
+        self.stage2_num_modules = stage2_num_modules
+        self.stage2_num_blocks = stage2_num_blocks
+        self.stage2_num_channels = stage2_num_channels
+        self.stage3_num_modules = stage3_num_modules
+        self.stage3_num_blocks = stage3_num_blocks
+        self.stage3_num_channels = stage3_num_channels
+        self.stage4_num_modules = stage4_num_modules
+        self.stage4_num_blocks = stage4_num_blocks
+        self.stage4_num_channels = stage4_num_channels
+
+    def build_net(self, mode='train'):
+        """Build the network as appropriate for each situation."""
+        model = HumanSeg.nets.HRNet(
+            self.num_classes,
+            mode=mode,
+            stage1_num_modules=self.stage1_num_modules,
+            stage1_num_blocks=self.stage1_num_blocks,
+            stage1_num_channels=self.stage1_num_channels,
+            stage2_num_modules=self.stage2_num_modules,
+            stage2_num_blocks=self.stage2_num_blocks,
+            stage2_num_channels=self.stage2_num_channels,
+            stage3_num_modules=self.stage3_num_modules,
+            stage3_num_blocks=self.stage3_num_blocks,
+            stage3_num_channels=self.stage3_num_channels,
+            stage4_num_modules=self.stage4_num_modules,
+            stage4_num_blocks=self.stage4_num_blocks,
+            stage4_num_channels=self.stage4_num_channels,
+            use_bce_loss=self.use_bce_loss,
+            use_dice_loss=self.use_dice_loss,
+            class_weight=self.class_weight,
+            ignore_index=self.ignore_index)
+        inputs = model.generate_inputs()
+        model_out = model.build_net(inputs)
+        outputs = OrderedDict()
+        if mode == 'train':
+            self.optimizer.minimize(model_out)
+            outputs['loss'] = model_out
+        else:
+            outputs['pred'] = model_out[0]
+            outputs['logit'] = model_out[1]
+        return inputs, outputs
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os.path as osp
import six
import copy
from collections import OrderedDict
import paddle.fluid as fluid
import HumanSeg
import HumanSeg.utils.logging as logging
def load_model(model_dir):
if not osp.exists(osp.join(model_dir, "model.yml")):
raise Exception("There's not model.yml in {}".format(model_dir))
with open(osp.join(model_dir, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader)
status = info['status']
if not hasattr(HumanSeg.models, info['Model']):
raise Exception("There's no attribute {} in HumanSeg.models".format(
info['Model']))
model = getattr(HumanSeg.models, info['Model'])(**info['_init_params'])
if status == "Normal":
startup_prog = fluid.Program()
model.test_prog = fluid.Program()
with fluid.program_guard(model.test_prog, startup_prog):
with fluid.unique_name.guard():
model.test_inputs, model.test_outputs = model.build_net(
mode='test')
model.test_prog = model.test_prog.clone(for_test=True)
model.exe.run(startup_prog)
import pickle
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
load_dict = pickle.load(f)
fluid.io.set_program_state(model.test_prog, load_dict)
elif status in ['Infer', 'Quant']:
[prog, input_names, outputs] = fluid.io.load_inference_model(
model_dir, model.exe, params_filename='__params__')
model.test_prog = prog
test_outputs_info = info['_ModelInputsOutputs']['test_outputs']
model.test_inputs = OrderedDict()
model.test_outputs = OrderedDict()
for name in input_names:
model.test_inputs[name] = model.test_prog.global_block().var(name)
for i, out in enumerate(outputs):
var_desc = test_outputs_info[i]
model.test_outputs[var_desc[0]] = out
if 'test_transforms' in info:
model.test_transforms = build_transforms(info['test_transforms'])
model.eval_transforms = copy.deepcopy(model.test_transforms)
if '_Attributes' in info:
for k, v in info['_Attributes'].items():
if k in model.__dict__:
model.__dict__[k] = v
logging.info("Model[{}] loaded.".format(info['Model']))
return model
def build_transforms(transforms_info):
import HumanSeg.transforms as T
transforms = list()
for op_info in transforms_info:
op_name = list(op_info.keys())[0]
op_attr = op_info[op_name]
if not hasattr(T, op_name):
raise Exception(
"There's no operator named '{}' in transforms".format(op_name))
transforms.append(getattr(T, op_name)(**op_attr))
eval_transforms = T.Compose(transforms)
return eval_transforms
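# For reference, build_transforms expects the 'test_transforms' entry of
# model.yml deserialized into a list of one-key dicts. A sketch (the operator
# names and attributes below are illustrative and must exist in
# HumanSeg.transforms):
#
#   transforms_info = [
#       {'Resize': {'target_size': (192, 192)}},
#       {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}},
#   ]
#   eval_transforms = build_transforms(transforms_info)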
 from .backbone import mobilenet_v2
+from .backbone import shufflenet_slim
 from .backbone import xception
 from .unet import UNet
 from .deeplabv3p import DeepLabv3p
+from .shufflenet_slim import ShuffleSeg
+from .hrnet import HRNet
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from .seg_modules import softmax_with_loss
from .seg_modules import dice_loss
from .seg_modules import bce_loss
from .libs import sigmoid_to_softmax
class HRNet(object):
def __init__(self,
num_classes,
mode='train',
stage1_num_modules=1,
stage1_num_blocks=[4],
stage1_num_channels=[64],
stage2_num_modules=1,
stage2_num_blocks=[4, 4],
stage2_num_channels=[18, 36],
stage3_num_modules=4,
stage3_num_blocks=[4, 4, 4],
stage3_num_channels=[18, 36, 72],
stage4_num_modules=3,
stage4_num_blocks=[4, 4, 4, 4],
stage4_num_channels=[18, 36, 72, 144],
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
        # dice_loss / bce_loss apply only to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss is only applicable to binary classification"
            )
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.mode = mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
self.stage1_num_modules = stage1_num_modules
self.stage1_num_blocks = stage1_num_blocks
self.stage1_num_channels = stage1_num_channels
self.stage2_num_modules = stage2_num_modules
self.stage2_num_blocks = stage2_num_blocks
self.stage2_num_channels = stage2_num_channels
self.stage3_num_modules = stage3_num_modules
self.stage3_num_blocks = stage3_num_blocks
self.stage3_num_channels = stage3_num_channels
self.stage4_num_modules = stage4_num_modules
self.stage4_num_blocks = stage4_num_blocks
self.stage4_num_channels = stage4_num_channels
def build_net(self, inputs):
image = inputs['image']
logit = self._high_resolution_net(image, self.num_classes)
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if self.mode == 'train':
label = inputs['label']
mask = label != self.ignore_index
return self._get_loss(logit, label, mask)
else:
if self.num_classes == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = fluid.layers.softmax(logit, axis=1)
return pred, logit
def generate_inputs(self):
inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
elif self.mode == 'eval':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
def _get_loss(self, logit, label, mask):
avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss):
avg_loss += softmax_with_loss(
logit,
label,
mask,
num_classes=self.num_classes,
weight=self.class_weight,
ignore_index=self.ignore_index)
else:
if self.use_dice_loss:
avg_loss += dice_loss(logit, label, mask)
if self.use_bce_loss:
avg_loss += bce_loss(
logit, label, mask, ignore_index=self.ignore_index)
return avg_loss
def _conv_bn_layer(self,
input,
filter_size,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
act=None,
param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
bn = fluid.layers.relu(bn)
return bn
def _basic_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self._conv_bn_layer(
input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = self._conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample:
residual = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def _bottleneck_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
name=name + '_conv1')
conv = self._conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = self._conv_bn_layer(
input=conv,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample:
residual = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def _fuse_layers(self, x, channels, multi_scale_output=True, name=None):
out = []
for i in range(len(channels) if multi_scale_output else 1):
residual = x[i]
shape = fluid.layers.shape(residual)[-2:]
for j in range(len(channels)):
if j > i:
y = self._conv_bn_layer(
x[j],
filter_size=1,
num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(input=y, out_shape=shape)
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = x[j]
for k in range(i - j):
if k == i - j - 1:
y = self._conv_bn_layer(
y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
else:
y = self._conv_bn_layer(
y,
filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
residual = fluid.layers.relu(residual)
out.append(residual)
return out
def _branches(self, x, block_num, channels, name=None):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num[i]):
residual = self._basic_block(
residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1))
out.append(residual)
return out
def _high_resolution_module(self,
x,
blocks,
channels,
multi_scale_output=True,
name=None):
residual = self._branches(x, blocks, channels, name=name)
out = self._fuse_layers(
residual,
channels,
multi_scale_output=multi_scale_output,
name=name)
return out
def _transition_layer(self, x, in_channels, out_channels, name=None):
num_in = len(in_channels)
num_out = len(out_channels)
out = []
for i in range(num_out):
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = self._conv_bn_layer(
x[i],
filter_size=3,
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual)
else:
out.append(x[i])
else:
residual = self._conv_bn_layer(
x[-1],
filter_size=3,
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual)
return out
def _stage(self,
x,
num_modules,
num_blocks,
num_channels,
multi_scale_output=True,
name=None):
out = x
for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False:
out = self._high_resolution_module(
out,
num_blocks,
num_channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else:
out = self._high_resolution_module(
out, num_blocks, num_channels, name=name + '_' + str(i + 1))
return out
def _layer1(self, input, num_modules, num_blocks, num_channels, name=None):
        # num_modules defaults to 1; consider whether extra handling is needed.
        # The reference implementation uses [1]; check alignment.
conv = input
for i in range(num_blocks[0]):
conv = self._bottleneck_block(
conv,
num_filters=num_channels[0],
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv
def _high_resolution_net(self, input, num_classes):
x = self._conv_bn_layer(
input=input,
filter_size=3,
num_filters=self.stage1_num_channels[0],
stride=2,
if_act=True,
name='layer1_1')
x = self._conv_bn_layer(
input=x,
filter_size=3,
num_filters=self.stage1_num_channels[0],
stride=2,
if_act=True,
name='layer1_2')
la1 = self._layer1(
x,
self.stage1_num_modules,
self.stage1_num_blocks,
self.stage1_num_channels,
name='layer2')
tr1 = self._transition_layer([la1],
self.stage1_num_channels,
self.stage2_num_channels,
name='tr1')
st2 = self._stage(
tr1,
self.stage2_num_modules,
self.stage2_num_blocks,
self.stage2_num_channels,
name='st2')
tr2 = self._transition_layer(
st2, self.stage2_num_channels, self.stage3_num_channels, name='tr2')
st3 = self._stage(
tr2,
self.stage3_num_modules,
self.stage3_num_blocks,
self.stage3_num_channels,
name='st3')
tr3 = self._transition_layer(
st3, self.stage3_num_channels, self.stage4_num_channels, name='tr3')
st4 = self._stage(
tr3,
self.stage4_num_modules,
self.stage4_num_blocks,
self.stage4_num_channels,
name='st4')
# upsample
shape = fluid.layers.shape(st4[0])[-2:]
st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=shape)
st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=shape)
st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=shape)
out = fluid.layers.concat(st4, axis=1)
last_channels = sum(self.stage4_num_channels)
out = self._conv_bn_layer(
input=out,
filter_size=1,
num_filters=last_channels,
stride=1,
if_act=True,
name='conv-2')
out = fluid.layers.conv2d(
input=out,
num_filters=num_classes,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(initializer=MSRA(), name='conv-1_weights'),
bias_attr=False)
input_shape = fluid.layers.shape(input)[-2:]
out = fluid.layers.resize_bilinear(out, input_shape)
return out
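# The `_fuse_layers` method above is the core of HRNet's multi-resolution
# exchange: output branch i sums contributions from every branch j, bilinearly
# upsampling lower-resolution features (j > i) and repeatedly stride-2
# convolving higher-resolution ones (j < i). A minimal, runnable NumPy sketch
# of just the resolution routing -- the helpers below are illustrative
# stand-ins for the actual conv/BN/resize ops, so shapes are the only thing
# this demonstrates:
import numpy as np

def upsample2x(x):
    # nearest-neighbor stand-in for fluid.layers.resize_bilinear
    return x.repeat(2, axis=-2).repeat(2, axis=-1)

def downsample2x(x):
    # stride-2 slicing as a stand-in for the 3x3, stride-2 conv
    return x[..., ::2, ::2]

hi = np.ones((1, 18, 48, 48), np.float32)   # branch 0: 1/4 input resolution
lo = np.ones((1, 36, 24, 24), np.float32)   # branch 1: 1/8 input resolution
fused_hi = hi + upsample2x(lo)[:, :18]      # j > i: upsample, project channels
fused_lo = lo + np.concatenate([downsample2x(hi)] * 2, axis=1)  # j < i: downsample
print(fused_hi.shape, fused_lo.shape)       # (1, 18, 48, 48) (1, 36, 24, 24)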
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from .libs import sigmoid_to_softmax
from .seg_modules import softmax_with_loss
from .seg_modules import dice_loss
from .seg_modules import bce_loss
class ShuffleSeg(object):
def __init__(self,
num_classes,
mode='train',
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
        # dice loss and bce loss are only applicable to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss are only applicable to binary classification"
            )
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "If class_weight is a string, it must be 'dynamic'!")
            else:
                raise TypeError(
                    'Expected class_weight to be a list or string, but received {}'.
                    format(type(class_weight)))
self.num_classes = num_classes
self.mode = mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
def _get_loss(self, logit, label, mask):
avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss):
avg_loss += softmax_with_loss(
logit,
label,
mask,
num_classes=self.num_classes,
weight=self.class_weight,
ignore_index=self.ignore_index)
else:
if self.use_dice_loss:
avg_loss += dice_loss(logit, label, mask)
if self.use_bce_loss:
avg_loss += bce_loss(
logit, label, mask, ignore_index=self.ignore_index)
return avg_loss
def generate_inputs(self):
inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode in ('train', 'eval'):
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
    def build_net(self, inputs):
        if self.use_dice_loss or self.use_bce_loss:
            # dice/bce losses operate on a single-channel logit
            self.num_classes = 1
image = inputs['image']
        ## Encoder
        conv1 = self.conv_bn(image, 3, 36, 2, 1)
        shortcut = self.conv_bn(
            input=conv1, filter_size=1, num_filters=18, stride=1, padding=0)
        pool = fluid.layers.pool2d(
            input=conv1,
            pool_size=3,
            pool_type='max',
            pool_stride=2,
            pool_padding=1)
        # Block 1
        conv = self.sfnetv2module(pool, stride=2, num_filters=72)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
        # Block 2
        conv = self.sfnetv2module(conv, stride=2)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
        conv = self.sfnetv2module(conv, stride=1)
### decoder
conv = self.depthwise_separable(conv, 3, 64, 1)
shortcut_shape = fluid.layers.shape(shortcut)[2:]
conv_b = fluid.layers.resize_bilinear(conv, shortcut_shape)
concat = fluid.layers.concat([shortcut, conv_b], axis=1)
decode_conv = self.depthwise_separable(concat, 3, 64, 1)
        logit = self.output_layer(decode_conv, self.num_classes)
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if self.mode == 'train':
label = inputs['label']
mask = label != self.ignore_index
return self._get_loss(logit, label, mask)
else:
if self.num_classes == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = fluid.layers.softmax(logit, axis=1)
return pred, logit
def conv_bn(self,
input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True):
parameter_attr = ParamAttr(learning_rate=1, initializer=MSRA())
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=parameter_attr,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def depthwise_separable(self, input, filter_size, num_filters, stride):
num_filters1 = int(input.shape[1])
num_groups = num_filters1
depthwise_conv = self.conv_bn(
input=input,
filter_size=filter_size,
num_filters=int(num_filters1),
stride=stride,
padding=int(filter_size / 2),
num_groups=num_groups,
use_cudnn=False,
act=None)
pointwise_conv = self.conv_bn(
input=depthwise_conv,
filter_size=1,
num_filters=num_filters,
stride=1,
padding=0)
return pointwise_conv
def sfnetv2module(self, input, stride, num_filters=None):
        if stride == 1:
            shortcut, branch = fluid.layers.split(
                input, num_or_sections=2, dim=1)
        else:
            branch = input
        if num_filters is None:
            in_channels = int(branch.shape[1])
        else:
            in_channels = int(num_filters / 2)
        if stride != 1:
            shortcut = self.depthwise_separable(input, 3, in_channels, stride)
branch_1x1 = self.conv_bn(
input=branch,
filter_size=1,
num_filters=int(in_channels),
stride=1,
padding=0)
branch_dw1x1 = self.depthwise_separable(branch_1x1, 3, in_channels,
stride)
output = fluid.layers.concat(input=[shortcut, branch_dw1x1], axis=1)
        # channel shuffle: regroup the concatenated channels so the two
        # branches exchange information (batch/height/width are dynamic,
        # the channel count is static)
shape = fluid.layers.shape(output)
c = output.shape[1]
b, h, w = shape[0], shape[2], shape[3]
output = fluid.layers.reshape(x=output, shape=[b, 2, in_channels, h, w])
output = fluid.layers.transpose(x=output, perm=[0, 2, 1, 3, 4])
output = fluid.layers.reshape(x=output, shape=[b, c, h, w])
return output
def output_layer(self, input, out_dim):
param_attr = fluid.param_attr.ParamAttr(
learning_rate=1.,
regularizer=fluid.regularizer.L2Decay(0.),
initializer=fluid.initializer.Xavier())
# deconv
output = fluid.layers.conv2d_transpose(
input=input,
num_filters=out_dim,
filter_size=2,
padding=0,
stride=2,
bias_attr=True,
param_attr=param_attr,
act=None)
return output
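# The reshape/transpose/reshape at the end of `sfnetv2module` is ShuffleNet
# V2's channel shuffle: after the two branches are concatenated, channels are
# interleaved across the two groups so that the next split sees features from
# both paths. A minimal NumPy sketch of the same trick (illustrative helper,
# not part of this module):
import numpy as np

def channel_shuffle(x, groups=2):
    n, c, h, w = x.shape
    x = x.reshape(n, groups, c // groups, h, w)  # split channels into groups
    x = x.transpose(0, 2, 1, 3, 4)               # swap the group and channel axes
    return x.reshape(n, c, h, w)                 # flatten back: channels interleaved

x = np.arange(16).reshape(1, 4, 2, 2)
print(channel_shuffle(x)[0, :, 0, 0])  # [ 0  8  4 12] -> channel order 0, 2, 1, 3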
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import setuptools
long_description = 'HumanSeg'
setuptools.setup(
    name='HumanSeg',
    version='1.0.0',
    author='paddleseg',
    description=long_description,
    long_description=long_description,
    packages=setuptools.find_packages(),
    setup_requires=['cython', 'numpy'],
    install_requires=['pyyaml', 'tqdm', 'visualdl==1.3.0', 'paddleslim==1.0.1'],
    license='Apache 2.0')
...@@ -12,5 +12,5 @@ ...@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import transforms from .transforms import *
from . import functional from . import functional
...@@ -58,13 +58,14 @@ class Compose: ...@@ -58,13 +58,14 @@ class Compose:
if im_info is None: if im_info is None:
im_info = dict() im_info = dict()
im = cv2.imread(im).astype('float32') if isinstance(im, str):
im = cv2.imread(im).astype('float32')
if isinstance(label, str):
label = np.asarray(Image.open(label))
if im is None: if im is None:
            raise ValueError('Cannot read the image file {}!'.format(im))             raise ValueError('Cannot read the image file {}!'.format(im))
if self.to_rgb: if self.to_rgb:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
if label is not None:
label = np.asarray(Image.open(label))
for op in self.transforms: for op in self.transforms:
outputs = op(im, im_info, label) outputs = op(im, im_info, label)
......
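# With the isinstance checks added above, Compose now accepts either an image
# path or an already-decoded array. A hypothetical call, with the transform
# list borrowed from the tutorial scripts below ('demo.jpg' is a placeholder):
import cv2
from HumanSeg.transforms import transforms

t = transforms.Compose([transforms.Resize((192, 192)), transforms.Normalize()])
out_from_path = t('demo.jpg')                                 # str: decoded inside Compose
out_from_array = t(cv2.imread('demo.jpg').astype('float32'))  # ndarray: used as-is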
# Export a trained checkpoint as an inference model.
import HumanSeg
model = HumanSeg.models.load_model('output/best_model')
model.export_inference_model('output/export')
# Run prediction on a single image and save the visualization.
import HumanSeg
from HumanSeg.utils import visualize
im_file = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/pexel/img/person_detection__ds6/img/pexels-photo-704264.jpg'
model = HumanSeg.models.load_model('output/best_model')
result = model.predict(im_file)
visualize(im_file, result, save_dir='output/')
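# Note: per its definition in HumanSeg/utils, visualize writes the blended
# image when save_dir is given and returns the ndarray when it is None, so a
# hypothetical in-memory variant of the script above would be:
#   vis = visualize(im_file, result)  # save_dir=None -> returns the image
#   cv2.imwrite('result_vis.jpg', vis)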
# Train HumanSegMobile on a human segmentation dataset.
from HumanSeg.datasets.dataset import Dataset
from HumanSeg.models import HumanSegMobile
from HumanSeg.transforms import transforms
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize((192, 192)),
transforms.Normalize()
])
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
transforms.Normalize()])
data_dir = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly'
train_list = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/train.txt'
val_list = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/val.txt'
train_dataset = Dataset(
data_dir=data_dir,
file_list=train_list,
transforms=train_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=True)
eval_dataset = Dataset(
data_dir=data_dir,
file_list=val_list,
transforms=eval_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=False)
model = HumanSegMobile(num_classes=2)
model.train(
num_epochs=100,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
save_interval_epochs=5,
train_batch_size=256,
# resume_weights='/Users/chenguowei01/PycharmProjects/github/PaddleSeg/contrib/HumanSeg/output/epoch_20',
log_interval_steps=2,
save_dir='output',
use_vdl=True,
)
# Evaluate a trained model on the validation set.
import HumanSeg
from HumanSeg.datasets.dataset import Dataset
from HumanSeg.transforms import transforms
eval_transforms = transforms.Compose(
[transforms.Resize((192, 192)),
transforms.Normalize()])
data_dir = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly'
val_list = '/ssd1/chenguowei01/dataset/humanseg/supervise.ly/val.txt'
eval_dataset = Dataset(
data_dir=data_dir,
file_list=val_list,
transforms=eval_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=False)
model = HumanSeg.models.load_model('output/best_model')
model.evaluate(eval_dataset, 2)  # second argument: evaluation batch size
...@@ -16,3 +16,4 @@ from . import logging ...@@ -16,3 +16,4 @@ from . import logging
from . import utils from . import utils
from .metrics import ConfusionMatrix from .metrics import ConfusionMatrix
from .utils import * from .utils import *
from .post_quantization import HumanSegPostTrainingQuantization
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import paddle.fluid as fluid
import os
import HumanSeg.utils.logging as logging
class HumanSegPostTrainingQuantization(PostTrainingQuantization):
def __init__(self,
executor,
dataset,
program,
inputs,
outputs,
batch_size=10,
batch_nums=None,
scope=None,
algo="KL",
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
        '''
        The class utilizes the post training quantization method to quantize
        the fp32 model. It uses calibration data to calculate the scale factors
        of quantized variables, and inserts fake quant/dequant ops to obtain
        the quantized model.
        Args:
            executor(fluid.Executor): The executor to load, run and save the
                quantized model.
            dataset(Python Iterator): The data reader.
            program(fluid.Program): The paddle program that holds the model
                parameters.
            inputs(dict): The inputs of the program.
            outputs(dict): The outputs of the program.
            batch_size(int, optional): The batch size of DataLoader. Default is 10.
            batch_nums(int, optional): If batch_nums is not None, the number of
                calibration samples is batch_size*batch_nums. If batch_nums is
                None, use all data provided by sample_generator as calibration
                data.
            scope(fluid.Scope, optional): The scope of the program, used to load
                and save variables. If scope=None, get the scope by
                global_scope().
            algo(str, optional): If algo='KL', use the KL-divergence method to
                get a more precise scale factor. If algo='direct', use the
                abs_max method to get the scale factor. Default is 'KL'.
            quantizable_op_type(list[str], optional): The types of ops that
                will be quantized. Default is ["conv2d", "depthwise_conv2d",
                "mul"].
            is_full_quantize(bool, optional): If is_full_quantize is True,
                apply quantization to all supported quantizable op types. If it
                is False, only apply quantization to the op types given by
                quantizable_op_type.
            is_use_cache_file(bool, optional): If is_use_cache_file is False,
                all temp data is kept in memory. If it is True, temp data is
                saved to disk. When the fp32 model is complex or the amount of
                calibration data is large, is_use_cache_file should be set to
                True. Default is False.
            cache_dir(str, optional): When is_use_cache_file is True, set
                cache_dir as the directory for saving temp data. Default is
                ./temp_post_training.
        Returns:
            None
        '''
self._executor = executor
self._dataset = dataset
self._batch_size = batch_size
self._batch_nums = batch_nums
        self._scope = fluid.global_scope() if scope is None else scope
self._algo = algo
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
if self._is_use_cache_file and not os.path.exists(self._cache_dir):
os.mkdir(self._cache_dir)
supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type + \
AddQuantDequantPass._supported_quantizable_op_type
if is_full_quantize:
self._quantizable_op_type = supported_quantizable_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in supported_quantizable_op_type + \
AddQuantDequantPass._activation_type, \
op_type + " is not supported for quantization."
self._place = self._executor.place
self._program = program
self._feed_list = list(inputs.values())
self._fetch_list = list(outputs.values())
self._data_loader = None
self._op_real_in_out_name = _op_real_in_out_name
self._bit_length = 8
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._sampling_data = {}
self._quantized_var_scale_factor = {}
def quantize(self):
        '''
        Quantize the fp32 model. Use calibration data to calculate the scale
        factors of the quantized variables, and insert fake quant/dequant ops
        to obtain the quantized model.
        Args:
            None
        Returns:
            the program of the quantized model.
        '''
self._preprocess()
batch_id = 0
for data in self._data_loader():
self._executor.run(
program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data(batch_id)
if batch_id % 5 == 0:
logging.info("run batch: {}".format(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
logging.info("all run batch: ".format(batch_id))
logging.info("calculate scale factor ...")
self._calculate_scale_factor()
logging.info("update the program ...")
self._update_program()
self._save_output_scale()
return self._program
def save_quantized_model(self, save_model_path):
'''
Save the quantized model to the disk.
Args:
save_model_path(str): The path to save the quantized model
Returns:
None
'''
feed_vars_names = [var.name for var in self._feed_list]
fluid.io.save_inference_model(
dirname=save_model_path,
feeded_var_names=feed_vars_names,
target_vars=self._fetch_list,
executor=self._executor,
params_filename='__params__',
main_program=self._program)
def _preprocess(self):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
'''
feed_vars = [fluid.framework._get_var(var.name, self._program) \
for var in self._feed_list]
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_list_generator(
self._dataset.generator(self._batch_size, drop_last=True),
places=self._place)
# collect the variable names for sampling
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
for op in self._program.global_block().ops:
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type == "mul":
if self._is_input_all_not_persistable(
op, persistable_var_names):
op._set_attr("skip_quant", True)
                        logging.warning(
                            "Skip quantizing a mul op whose input variables are all non-persistable"
                        )
else:
self._quantized_act_var_name.add(op.input("X")[0])
self._quantized_weight_var_name.add(op.input("Y")[0])
self._quantized_act_var_name.add(op.output("Out")[0])
else:
# process other quantizable op type, the input must all not persistable
if self._is_input_all_not_persistable(
op, persistable_var_names):
input_output_name_list = self._op_real_in_out_name[
op_type]
for input_name in input_output_name_list[0]:
for var_name in op.input(input_name):
self._quantized_act_var_name.add(var_name)
for output_name in input_output_name_list[1]:
for var_name in op.output(output_name):
self._quantized_act_var_name.add(var_name)
        # set activation variables to be persistable, so their tensor data
        # can be read back in _sample_data
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
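# A hedged usage sketch of HumanSegPostTrainingQuantization (the executor
# setup and the `infer_program`, `inputs`, `outputs`, `calib_dataset` objects
# below are hypothetical caller-provided glue, not code from this module):
#   import paddle.fluid as fluid
#   exe = fluid.Executor(fluid.CPUPlace())
#   quantizer = HumanSegPostTrainingQuantization(
#       executor=exe,
#       dataset=calib_dataset,  # must expose generator(batch_size, drop_last)
#       program=infer_program,  # the fp32 inference program
#       inputs=inputs,          # dict of feed variables
#       outputs=outputs,        # dict of fetch variables
#       batch_size=10,
#       batch_nums=10)
#   quantizer.quantize()
#   quantizer.save_quantized_model('output/quant_model')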
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import six import six
import yaml import yaml
import math import math
import cv2
from . import logging from . import logging
...@@ -218,3 +219,58 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False): ...@@ -218,3 +219,58 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
len(vars_to_load), weights_dir)) len(vars_to_load), weights_dir))
if fuse_bn: if fuse_bn:
fuse_bn_weights(exe, main_prog, weights_dir) fuse_bn_weights(exe, main_prog, weights_dir)
def visualize(image, result, save_dir=None, weight=0.6):
"""
    Convert the segmentation result to a pseudo-color image, blend it with the
    original image, and save or return the blended result.
    Args:
        image: the path of the original image
        result: the prediction result of the image
        save_dir: the directory for saving the visualized image; if None, the
            blended image is returned instead of being written to disk
        weight: the weight of the original image in the blend; the pseudo-color
            map gets weight (1 - weight)
"""
label_map = result['label_map']
color_map = get_color_map_list(256)
color_map = np.array(color_map).astype("uint8")
# Use OpenCV LUT for color mapping
c1 = cv2.LUT(label_map, color_map[:, 0])
c2 = cv2.LUT(label_map, color_map[:, 1])
c3 = cv2.LUT(label_map, color_map[:, 2])
pseudo_img = np.dstack((c1, c2, c3))
im = cv2.imread(image)
vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
if save_dir is not None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
image_name = os.path.split(image)[-1]
out_path = os.path.join(save_dir, image_name)
cv2.imwrite(out_path, vis_result)
else:
return vis_result
def get_color_map_list(num_classes):
""" Returns the color map for visualizing the segmentation mask,
which can support arbitrary number of classes.
Args:
num_classes: Number of classes
Returns:
The color map
"""
num_classes += 1
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
color_map = color_map[1:]
return color_map
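# A quick worked example of the palette above: the loop assigns each label the
# standard PASCAL VOC color by spreading the label's bits across the three
# channels, and color_map[1:] drops the all-black background entry, so index 0
# of the returned list is the color generated for label value 1.
if __name__ == '__main__':
    print(get_color_map_list(3))  # [[128, 0, 0], [0, 128, 0], [128, 128, 0]]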