提交 96764865 编写于 作者: F FlyingQianMM

Applications -> examples

上级 829431c5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import humanseg_postprocess
......@@ -7,8 +7,6 @@
**前置依赖**
* paddlepaddle >= 1.8.0
* python >= 3.5
* cython
* pycocotools
```
pip install paddlex -i https://mirror.baidu.com/pypi/simple
......@@ -95,6 +93,10 @@ python bg_replace.py --model_dir pretrain_weights/humanseg_mobile_inference --im
## 训练
使用下述命令基于与训练模型进行Fine-tuning,请确保选用的模型结构`model_type`与模型参数`pretrain_weights`匹配。
```bash
# 指定GPU卡号(以0号卡为例)
export CUDA_VISIBLE_DEVICES=0
# 若不使用GPU,则将CUDA_VISIBLE_DEVICES指定为空
# export CUDA_VISIBLE_DEVICES=
python train.py --model_type HumanSegMobile \
--save_dir output/ \
--data_dir data/mini_supervisely \
......@@ -177,7 +179,3 @@ python quant_offline.py --model_dir output/best_model \
* `--quant_list`: 量化数据集列表路径,一般直接选择训练集或验证集
* `--save_dir`: 量化模型保存路径
* `--image_shape`: 网络输入图像大小(w, h)
## AIStudio在线教程
我们在AI Studio平台上提供了人像分割在线体验的教程,[点击体验](https://aistudio.baidu.com/aistudio/projectdetail/475345)
......@@ -16,7 +16,8 @@
import numpy as np
def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
def cal_optical_flow_tracking(pre_gray, cur_gray, prev_cfd, dl_weights,
disflow):
"""计算光流跟踪匹配点和光流图
输入参数:
pre_gray: 上一帧灰度图
......@@ -59,7 +60,7 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
return track_cfd, is_track, dl_weights
def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
def fuse_optical_flow_tracking(track_cfd, dl_cfd, dl_weights, is_track):
"""光流追踪图和人像分割结构融合
输入参数:
track_cfd: 光流追踪图
......@@ -116,9 +117,9 @@ def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
fusion_cfd = cur_cfd
else:
weights = np.ones((h, w), np.float32) * 0.3
track_cfd, is_track, weights = human_seg_tracking(
track_cfd, is_track, weights = cal_optical_flow_tracking(
prev_gray, cur_gray, pre_cfd, weights, disflow)
fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights,
is_track)
fusion_cfd = fuse_optical_flow_tracking(track_cfd, cur_cfd, weights,
is_track)
return fusion_cfd
......@@ -13,11 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# 选择使用0号卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# 使用CPU
#os.environ['CUDA_VISIBLE_DEVICES'] = ''
import argparse
import paddlex as pdx
......
......@@ -19,7 +19,7 @@ import os.path as osp
import cv2
import numpy as np
from utils.humanseg_postprocess import postprocess, threshold_mask
from postprocess import postprocess, threshold_mask
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册