提交 febaba67 编写于 作者: C chenguowei01

change models name

上级 4968e6c2
......@@ -18,9 +18,9 @@ $ pip install -r requirements.txt
HumanSeg开放了在大规模人像数据上训练的三个预训练模型,满足多种使用场景的需求
| 模型类型 | Checkpoint | Inference Model | Quant Inference Model | 备注 |
| --- | --- | --- | --- | --- |
| HumanSeg-server | [humanseg_server_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server.zip) | [humanseg_server_export](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_export.zip) | -- | 高精度模型,适用于服务端GPU且背景复杂的人像场景 |
| HumanSeg-mobile | [humanseg_mobile_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile.zip) | [humanseg_mobile_export](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_export.zip) | [humanseg_mobile_quant](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip) | 轻量级模型, 适用于移动端或服务端CPU的前置摄像头场景 |
| HumanSeg-lite | [humanseg_lite](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite.zip) | [humanseg_lite_export](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_export.zip) | [humanseg_lite_quant](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_quant.zip) | 超轻量级模型, 适用于手机自拍人像,且有移动端实时分割场景 |
| HumanSeg-server | [humanseg_server_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_ckpt.zip) | [humanseg_server_inference](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_inference.zip) | -- | 高精度模型,适用于服务端GPU且背景复杂的人像场景 |
| HumanSeg-mobile | [humanseg_mobile_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip) | [humanseg_mobile_inference](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_inference.zip) | [humanseg_mobile_quant](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip) | 轻量级模型, 适用于移动端或服务端CPU的前置摄像头场景 |
| HumanSeg-lite | [humanseg_lite_ckpt](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_ckpt.zip) | [humanseg_lite_inference](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_inference.zip) | [humanseg_lite_quant](https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_quant.zip) | 超轻量级模型, 适用于手机自拍人像,且有移动端实时分割场景 |
**NOTE:**
其中Checkpoint为模型权重,用于Fine-tuning场景。
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 人像分割预测部署
本模型基于飞浆开源的人像分割模型,并做了大量的针对视频的光流追踪优化,提供了完整的支持视频流的人像分割解决方案,并提供了高性能的`Python`集成部署方案。
## 模型下载
支持的模型文件如下,请根据应用场景选择合适的模型:
|模型文件 | 说明 |
| --- | --- |
|[humanseg_lite_quant]() | 小模型, 适合轻量级计算环境 |
|[humanseg_lite]()| 小模型,适合轻量级计算环境 |
|[humanseg_mobile_quant]() | 小模型, 适合轻量级计算环境 |
|[humanseg_mobile]()| 小模型,适合轻量级计算环境 |
|[humanseg_server_quant]() | 服务端GPU环境 |
|[humanseg_server]() | 服务端GPU环境 |
**注意:下载后解压到合适的路径,后续该路径将做为预测参数用于加载模型。**
## 预测部署
- [Python预测部署](./python)
## 效果预览
<figure class="half">
<img src="https://paddleseg.bj.bcebos.com/deploy/data/input.gif">
<img src="https://paddleseg.bj.bcebos.com/deploy/data/output.gif">
</figure>
# 人像分割Python预测部署方案
本方案基于Python实现,最小化依赖并把所有模型加载、数据预处理、预测、光流处理等后处理都封装在文件`infer.py`中,用户可以直接使用或集成到自己项目中。
## 前置依赖
- Windows(7,8,10) / Linux (Ubuntu 16.04) or MacOS 10.1+
- Paddle 1.7+
- Python 3.6+
注意:
1. 仅支持Paddle 1.7以上版本
2. MacOS上不支持GPU预测
其它未涉及情形,能正常安装`Paddle``OpenCV`通常都能正常使用。
## 安装依赖
执行如下命令
```shell
pip install -r requirements.txt
```
## 运行
1. 输入图片进行分割
```
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --img_path /PATH/TO/INPUT/IMAGE
```
预测结果会保存为`result.jpeg`
2. 输入视频进行分割
```shell
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --video_path /PATH/TO/INPUT/VIDEO
```
预测结果会保存在`result.avi`
3. 使用摄像头视频流
```shell
python infer.py --model_dir /PATH/TO/INFERENCE/MODEL --use_camera True
```
预测结果会通过可视化窗口实时显示。
**注意:**
`GPU`默认关闭, 如果要使用`GPU`进行加速,则先运行
```
export CUDA_VISIBLE_DEVICES=0
```
然后在前面的预测命令中增加参数`--use_gpu True`即可。
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""实时人像分割Python预测部署"""
import os
import argparse
import numpy as np
import cv2
import paddle.fluid as fluid
def humanseg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
"""计算光流跟踪匹配点和光流图
输入参数:
pre_gray: 上一帧灰度图
cur_gray: 当前帧灰度图
prev_cfd: 上一帧光流图
dl_weights: 融合权重图
disflow: 光流数据结构
返回值:
is_track: 光流点跟踪二值图,即是否具有光流点匹配
track_cfd: 光流跟踪图
"""
check_thres = 8
hgt, wdh = pre_gray.shape[:2]
track_cfd = np.zeros_like(prev_cfd)
is_track = np.zeros_like(pre_gray)
# 计算前向光流
flow_fw = disflow.calc(pre_gray, cur_gray, None)
# 计算后向光流
flow_bw = disflow.calc(cur_gray, pre_gray, None)
get_round = lambda data: (int)(data + 0.5) if data >= 0 else (int)(data -
0.5)
for row in range(hgt):
for col in range(wdh):
# 计算光流处理后对应点坐标
# (row, col) -> (cur_x, cur_y)
fxy_fw = flow_fw[row, col]
dx_fw = get_round(fxy_fw[0])
cur_x = dx_fw + col
dy_fw = get_round(fxy_fw[1])
cur_y = dy_fw + row
if cur_x < 0 or cur_x >= wdh or cur_y < 0 or cur_y >= hgt:
continue
fxy_bw = flow_bw[cur_y, cur_x]
dx_bw = get_round(fxy_bw[0])
dy_bw = get_round(fxy_bw[1])
# 光流移动小于阈值
lmt = ((dy_fw + dy_bw) * (dy_fw + dy_bw) +
(dx_fw + dx_bw) * (dx_fw + dx_bw))
if lmt >= check_thres:
continue
# 静止点降权
if abs(dy_fw) <= 0 and abs(dx_fw) <= 0 and abs(dy_bw) <= 0 and abs(
dx_bw) <= 0:
dl_weights[cur_y, cur_x] = 0.05
is_track[cur_y, cur_x] = 1
track_cfd[cur_y, cur_x] = prev_cfd[row, col]
return track_cfd, is_track, dl_weights
def humanseg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
"""光流追踪图和人像分割结构融合
输入参数:
track_cfd: 光流追踪图
dl_cfd: 当前帧分割结果
dl_weights: 融合权重图
is_track: 光流点匹配二值图
返回值:
cur_cfd: 光流跟踪图和人像分割结果融合图
"""
cur_cfd = dl_cfd.copy()
idxs = np.where(is_track > 0)
for i in range(len(idxs)):
x, y = idxs[0][i], idxs[1][i]
dl_score = dl_cfd[y, x]
track_score = track_cfd[y, x]
if dl_score > 0.9 or dl_score < 0.1:
if dl_weights[x, y] < 0.1:
cur_cfd[x, y] = 0.3 * dl_score + 0.7 * track_score
else:
cur_cfd[x, y] = 0.4 * dl_score + 0.6 * track_score
else:
cur_cfd[x, y] = dl_weights[x, y] * dl_score + (
1 - dl_weights[x, y]) * track_score
return cur_cfd
def threshold_mask(img, thresh_bg, thresh_fg):
"""设置背景和前景阈值mask
输入参数:
img : 原始图像, np.uint8 类型.
thresh_bg : 背景阈值百分比,低于该值置为0.
thresh_fg : 前景阈值百分比,超过该值置为1.
返回值:
dst : 原始图像设置完前景背景阈值mask结果, np.float32 类型.
"""
dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
dst[np.where(dst > 1)] = 1
dst[np.where(dst < 0)] = 0
return dst.astype(np.float32)
def optflow_handle(cur_gray, scoremap, is_init):
"""光流优化
Args:
cur_gray : 当前帧灰度图
scoremap : 当前帧分割结果
is_init : 是否第一帧
Returns:
dst : 光流追踪图和预测结果融合图, 类型为 np.float32
"""
width, height = scoremap.shape[0], scoremap.shape[1]
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
prev_gray = np.zeros((height, width), np.uint8)
prev_cfd = np.zeros((height, width), np.float32)
cur_cfd = scoremap.copy()
if is_init:
is_init = False
if height <= 64 or width <= 64:
disflow.setFinestScale(1)
elif height <= 160 or width <= 160:
disflow.setFinestScale(2)
else:
disflow.setFinestScale(3)
fusion_cfd = cur_cfd
else:
weights = np.ones((width, height), np.float32) * 0.3
track_cfd, is_track, weights = humanseg_tracking(
prev_gray, cur_gray, prev_cfd, weights, disflow)
fusion_cfd = humanseg_track_fuse(track_cfd, cur_cfd, weights, is_track)
fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)
return fusion_cfd
class HumanSeg:
"""人像分割类
封装了人像分割模型的加载,数据预处理,预测,后处理等
"""
def __init__(self, model_dir, mean, scale, eval_size, use_gpu=False):
self.mean = np.array(mean).reshape((3, 1, 1))
self.scale = np.array(scale).reshape((3, 1, 1))
self.eval_size = eval_size
self.load_model(model_dir, use_gpu)
def load_model(self, model_dir, use_gpu):
"""加载模型并创建predictor
Args:
model_dir: 预测模型路径, 包含 `__model__` 和 `__params__`
use_gpu: 是否使用GPU加速
"""
prog_file = os.path.join(model_dir, '__model__')
params_file = os.path.join(model_dir, '__params__')
config = fluid.core.AnalysisConfig(prog_file, params_file)
if use_gpu:
config.enable_use_gpu(100, 0)
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.disable_glog_info()
config.switch_specify_input_names(True)
config.enable_memory_optim()
self.predictor = fluid.core.create_paddle_predictor(config)
def preprocess(self, image):
"""图像预处理
hwc_rgb 转换为 chw_bgr,并进行归一化
输入参数:
image: 原始图像
返回值:
经过预处理后的图片结果
"""
img_mat = cv2.resize(
image, self.eval_size, interpolation=cv2.INTER_LINEAR)
# HWC -> CHW
img_mat = img_mat.swapaxes(1, 2)
img_mat = img_mat.swapaxes(0, 1)
# Convert to float
img_mat = img_mat[:, :, :].astype('float32')
img_mat = (img_mat / 255. - self.mean) / self.scale
img_mat = img_mat[np.newaxis, :, :, :]
return img_mat
def postprocess(self, image, output_data):
"""对预测结果进行后处理
Args:
image: 原始图,opencv 图片对象
output_data: Paddle预测结果原始数据
Returns:
原图和预测结果融合并做了光流优化的结果图
"""
scoremap = output_data[0, 1, :, :]
scoremap = (scoremap * 255).astype(np.uint8)
ori_h, ori_w = image.shape[0], image.shape[1]
evl_h, evl_w = self.eval_size[0], self.eval_size[1]
# 光流处理
cur_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
cur_gray = cv2.resize(cur_gray, (evl_w, evl_h))
optflow_map = optflow_handle(cur_gray, scoremap, False)
optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
optflow_map = threshold_mask(optflow_map, thresh_bg=0.2, thresh_fg=0.8)
optflow_map = cv2.resize(optflow_map, (ori_w, ori_h))
optflow_map = np.repeat(optflow_map[:, :, np.newaxis], 3, axis=2)
bg_im = np.ones_like(optflow_map) * 255
comb = (optflow_map * image + (1 - optflow_map) * bg_im).astype(
np.uint8)
return comb
def run_predict(self, image):
"""运行预测并返回可视化结果图
输入参数:
image: 需要预测的原始图, opencv图片对象
返回值:
可视化的预测结果图
"""
im_mat = self.preprocess(image)
im_tensor = fluid.core.PaddleTensor(im_mat.copy().astype('float32'))
output_data = self.predictor.run([im_tensor])[1]
output_data = output_data.as_ndarray()
return self.postprocess(image, output_data)
def image_segment(self, path):
"""对图片文件进行分割
结果保存到`result.jpeg`文件中
"""
img_mat = cv2.imread(path)
img_mat = self.run_predict(img_mat)
cv2.imwrite('result.jpeg', img_mat)
def video_segment(self, path=None):
"""
对视屏流进行分割,
path为None时默认打开摄像头。
"""
if path is None:
cap = cv2.VideoCapture(0)
else:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise IOError("Error opening video stream or file")
return
if path is not None:
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
# 用于保存预测结果视频
out = cv2.VideoWriter('result.avi',
cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
fps, (width, height))
# 开始获取视频帧
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
out.write(img_mat)
else:
break
cap.release()
out.release()
else:
while cap.isOpened():
ret, frame = cap.read()
if ret:
img_mat = self.run_predict(frame)
cv2.imshow('HumanSegmentation', img_mat)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
else:
break
cap.release()
def main(args):
"""预测程序入口
完成模型加载, 对视频、摄像头、图片文件等预测过程
"""
model_dir = args.model_dir
use_gpu = args.use_gpu
# 加载模型
mean = [0.5, 0.5, 0.5]
scale = [0.5, 0.5, 0.5]
eval_size = (192, 192)
model = HumanSeg(model_dir, mean, scale, eval_size, use_gpu)
if args.use_camera:
# 开启摄像头
model.video_segment()
elif args.video_path:
# 使用视频文件作为输入
model.video_segment(args.video_path)
elif args.img_path:
# 使用图片文件作为输入
model.image_segment(args.img_path)
else:
raise ValueError(
'One of (--model_dir, --video_path, --use_camera) should be given.')
def parse_args():
"""解析命令行参数
"""
parser = argparse.ArgumentParser('Realtime Human Segmentation')
parser.add_argument(
'--model_dir',
type=str,
default='',
help='path of human segmentation model')
parser.add_argument(
'--img_path', type=str, default='', help='path of input image')
parser.add_argument(
'--video_path', type=str, default='', help='path of input video')
parser.add_argument(
'--use_camera',
type=bool,
default=False,
help='input video stream from camera')
parser.add_argument(
'--use_gpu', type=bool, default=False, help='enable gpu')
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
opencv-python==4.1.2.30
opencv-contrib-python==4.2.0.32
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle.fluid as fluid
from .libs import scope
from .libs import bn_relu, conv, max_pool, deconv
from .libs import sigmoid_to_softmax
from .seg_modules import softmax_with_loss
from .seg_modules import dice_loss, bce_loss
class UNet(object):
"""实现Unet模型
`"U-Net: Convolutional Networks for Biomedical Image Segmentation"
<https://arxiv.org/abs/1505.04597>`
Args:
num_classes (int): 类别数
mode (str): 网络运行模式,根据mode构建网络的输入和返回。
当mode为'train'时,输入为image(-1, 3, -1, -1)和label (-1, 1, -1, -1) 返回loss。
当mode为'train'时,输入为image (-1, 3, -1, -1)和label (-1, 1, -1, -1),返回loss,
pred (与网络输入label 相同大小的预测结果,值代表相应的类别),label,mask(非忽略值的mask,
与label相同大小,bool类型)。
当mode为'test'时,输入为image(-1, 3, -1, -1)返回pred (-1, 1, -1, -1)和
logit (-1, num_classes, -1, -1) 通道维上代表每一类的概率值。
upsample_mode (str): UNet decode时采用的上采样方式,取值为'bilinear'时利用双线行差值进行上菜样,
当输入其他选项时则利用反卷积进行上菜样,默认为'bilinear'。
use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。
use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。
class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。
Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
ValueError: class_weight为list, 但长度不等于num_class。
class_weight为str, 但class_weight.low()不等于dynamic。
TypeError: class_weight不为None时,其类型不是list或str。
"""
def __init__(self,
num_classes,
mode='train',
upsample_mode='bilinear',
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
# dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise Exception(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.mode = mode
self.upsample_mode = upsample_mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
def _double_conv(self, data, out_ch):
param_attr = fluid.ParamAttr(
name='weights',
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0),
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33))
with scope("conv0"):
data = bn_relu(
conv(
data, out_ch, 3, stride=1, padding=1,
param_attr=param_attr))
with scope("conv1"):
data = bn_relu(
conv(
data, out_ch, 3, stride=1, padding=1,
param_attr=param_attr))
return data
def _down(self, data, out_ch):
# 下采样:max_pool + 2个卷积
with scope("down"):
data = max_pool(data, 2, 2, 0)
data = self._double_conv(data, out_ch)
return data
def _up(self, data, short_cut, out_ch):
# 上采样:data上采样(resize或deconv), 并与short_cut concat
param_attr = fluid.ParamAttr(
name='weights',
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0),
initializer=fluid.initializer.XavierInitializer(),
)
with scope("up"):
if self.upsample_mode == 'bilinear':
short_cut_shape = fluid.layers.shape(short_cut)
data = fluid.layers.resize_bilinear(data, short_cut_shape[2:])
else:
data = deconv(
data,
out_ch // 2,
filter_size=2,
stride=2,
padding=0,
param_attr=param_attr)
data = fluid.layers.concat([data, short_cut], axis=1)
data = self._double_conv(data, out_ch)
return data
def _encode(self, data):
# 编码器设置
short_cuts = []
with scope("encode"):
with scope("block1"):
data = self._double_conv(data, 64)
short_cuts.append(data)
with scope("block2"):
data = self._down(data, 128)
short_cuts.append(data)
with scope("block3"):
data = self._down(data, 256)
short_cuts.append(data)
with scope("block4"):
data = self._down(data, 512)
short_cuts.append(data)
with scope("block5"):
data = self._down(data, 512)
return data, short_cuts
def _decode(self, data, short_cuts):
# 解码器设置,与编码器对称
with scope("decode"):
with scope("decode1"):
data = self._up(data, short_cuts[3], 256)
with scope("decode2"):
data = self._up(data, short_cuts[2], 128)
with scope("decode3"):
data = self._up(data, short_cuts[1], 64)
with scope("decode4"):
data = self._up(data, short_cuts[0], 64)
return data
def _get_logit(self, data, num_classes):
# 根据类别数设置最后一个卷积层输出
param_attr = fluid.ParamAttr(
name='weights',
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0),
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
with scope("logit"):
data = conv(
data,
num_classes,
3,
stride=1,
padding=1,
param_attr=param_attr)
return data
def _get_loss(self, logit, label, mask):
avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss):
avg_loss += softmax_with_loss(
logit,
label,
mask,
num_classes=self.num_classes,
weight=self.class_weight,
ignore_index=self.ignore_index)
else:
if self.use_dice_loss:
avg_loss += dice_loss(logit, label, mask)
if self.use_bce_loss:
avg_loss += bce_loss(
logit, label, mask, ignore_index=self.ignore_index)
return avg_loss
def generate_inputs(self):
inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
elif self.mode == 'eval':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
def build_net(self, inputs):
# 在两类分割情况下,当loss函数选择dice_loss或bce_loss的时候,最后logit输出通道数设置为1
if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1
image = inputs['image']
encode_data, short_cuts = self._encode(image)
decode_data = self._decode(encode_data, short_cuts)
logit = self._get_logit(decode_data, self.num_classes)
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if self.mode == 'train':
label = inputs['label']
mask = label != self.ignore_index
return self._get_loss(logit, label, mask)
elif self.mode == 'eval':
label = inputs['label']
mask = label != self.ignore_index
loss = self._get_loss(logit, label, mask)
return loss, pred, label, mask
else:
if self.num_classes == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = fluid.layers.softmax(logit, axis=1)
return pred, logit
......@@ -22,20 +22,20 @@ sys.path.append(TEST_PATH)
from test_utils import download_file_and_uncompress
model_urls = {
"humanseg_server":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server.zip",
"humanseg_server_export":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_export.zip",
"humanseg_mobile":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile.zip",
"humanseg_mobile_export":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_export.zip",
"humanseg_server_ckpt":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_ckpt.zip",
"humanseg_server_inference":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_server_inference.zip",
"humanseg_mobile_ckpt":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_ckpt.zip",
"humanseg_mobile_inference":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_inference.zip",
"humanseg_mobile_quant":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip",
"humanseg_lite":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite.zip",
"humanseg_lite_export":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_export.zip",
"humanseg_lite_ckpt":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_ckpt.zip",
"humanseg_lite_inference":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_inference.zip",
"humanseg_lite_quant":
"https://paddleseg.bj.bcebos.com/humanseg/models/humanseg_lite_quant.zip",
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册