humanseg_postprocess.py 4.7 KB
Newer Older
W
wuyefeilin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16 17
import numpy as np
import cv2
W
wuyefeilin 已提交
18 19
import os

20

W
wuyefeilin 已提交
21 22 23
def get_round(data):
    """Round ``data`` to the nearest integer, with halves going away from zero.

    ``int()`` truncates toward zero. Shifting by +0.5 for non-negative inputs,
    or by -0.5 for negative ones, therefore yields rounding half away from
    zero: 2.5 -> 3, -2.5 -> -3, 0.4 -> 0.

    Args:
        data: a real number (Python or NumPy scalar).

    Returns:
        int: the rounded value.
    """
    if data >= 0:
        return int(data + 0.5)
    return int(data - 0.5)
24

W
wuyefeilin 已提交
25 26

def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
    """Track the previous frame's confidence map into the current frame.

    Dense optical flow is computed forward (previous -> current) and backward
    (current -> previous) with ``disflow``. Each pixel of the previous frame is
    moved by its forward flow, rounded half away from zero. The pixel is kept
    only if two conditions hold:
      * it lands inside the image, and
      * the forward/backward round trip is consistent: the squared error
        between forward flow and negated backward flow is below 8.

    Args:
        pre_gray: previous-frame grayscale image (height, width in the first
            two axes).
        cur_gray: current-frame grayscale image.
        prev_cfd: previous-frame confidence map, indexed ``[row, col]``.
        dl_weights: fusion weight map, indexed ``[row, col]``. It is MODIFIED
            IN PLACE: tracked pixels whose rounded flow is zero in both
            directions get weight 0.05.
        disflow: optical-flow object exposing ``calc(prev, next, flow)``,
            e.g. ``cv2.DISOpticalFlow``. It is assumed to return an
            ``(h, w, 2)`` array holding ``(dx, dy)`` in the last axis.

    Returns:
        tuple ``(track_cfd, is_track, dl_weights)``:
            track_cfd: ``prev_cfd`` values moved to their tracked positions,
                zero elsewhere. When several source pixels land on the same
                target, the last one in row-major order wins.
            is_track: same shape/dtype as ``pre_gray``; 1 where a tracked
                pixel landed, 0 elsewhere.
            dl_weights: the same (in-place updated) weight array.
    """
    check_thres = 8
    h, w = pre_gray.shape[:2]
    track_cfd = np.zeros_like(prev_cfd)
    is_track = np.zeros_like(pre_gray)
    flow_fw = disflow.calc(pre_gray, cur_gray, None)
    flow_bw = disflow.calc(cur_gray, pre_gray, None)

    # Source coordinates of every pixel. Boolean-mask indexing below keeps
    # them in row-major (C) order, which matches a row-then-column scan.
    src_r, src_c = np.mgrid[0:h, 0:w]
    dx_fw = _round_half_away(flow_fw[..., 0])
    dy_fw = _round_half_away(flow_fw[..., 1])
    cur_x = src_c + dx_fw
    cur_y = src_r + dy_fw

    # Drop pixels whose forward flow leaves the image.
    inside = (cur_x >= 0) & (cur_x < w) & (cur_y >= 0) & (cur_y < h)
    src_r, src_c = src_r[inside], src_c[inside]
    cur_x, cur_y = cur_x[inside], cur_y[inside]
    dx_fw, dy_fw = dx_fw[inside], dy_fw[inside]

    # Forward/backward consistency: the backward flow at the landing position
    # should (approximately) undo the forward flow.
    fxy_bw = flow_bw[cur_y, cur_x]
    dx_bw = _round_half_away(fxy_bw[..., 0])
    dy_bw = _round_half_away(fxy_bw[..., 1])
    err = (dy_fw + dy_bw) ** 2 + (dx_fw + dx_bw) ** 2
    keep = err < check_thres
    src_r, src_c = src_r[keep], src_c[keep]
    cur_x, cur_y = cur_x[keep], cur_y[keep]
    dx_fw, dy_fw = dx_fw[keep], dy_fw[keep]
    dx_bw, dy_bw = dx_bw[keep], dy_bw[keep]

    # Pixels that did not move at all get a low weight.
    still = (dx_fw == 0) & (dy_fw == 0) & (dx_bw == 0) & (dy_bw == 0)
    dl_weights[cur_y[still], cur_x[still]] = 0.05
    is_track[cur_y, cur_x] = 1

    # Several sources may land on the same target. NumPy does not define
    # which value wins for repeated indices in a fancy assignment, so keep
    # only the last occurrence (row-major) of each target explicitly.
    target = cur_y * w + cur_x
    _, first_in_reversed = np.unique(target[::-1], return_index=True)
    last = target.size - 1 - first_in_reversed
    track_cfd[cur_y[last], cur_x[last]] = prev_cfd[src_r[last], src_c[last]]

    return track_cfd, is_track, dl_weights


def _round_half_away(values):
    """Elementwise rounding half away from zero, returned as int64 (vectorized get_round)."""
    values = np.asarray(values, dtype=np.float64)
    return np.where(values >= 0, np.trunc(values + 0.5),
                    np.trunc(values - 0.5)).astype(np.int64)


W
wuyefeilin 已提交
67
def human_seg_track_fuse(track_cfd, dl_cfd, dl_weights, is_track):
    """Fuse the optical-flow tracked map with the current segmentation result.

    Only pixels where ``is_track > 0`` are fused; every other pixel keeps its
    ``dl_cfd`` value. For a tracked pixel with segmentation score ``d``,
    tracked score ``t`` and weight ``a``:
      * if ``d > 0.9`` or ``d < 0.1`` (confident segmentation):
          ``0.3 * d + 0.7 * t`` when ``a < 0.1``, otherwise ``0.4 * d + 0.6 * t``;
      * otherwise: ``a * d + (1 - a) * t``.

    Args:
        track_cfd: optical-flow tracked map, indexed ``[row, col]``.
        dl_cfd: current-frame segmentation result. It is not modified.
        dl_weights: fusion weight map, indexed ``[row, col]``.
        is_track: tracking mask (non-zero where a tracked point landed).

    Returns:
        fusion_cfd: a copy of ``dl_cfd`` (same dtype) with tracked pixels
        replaced by the fused value.
    """
    fusion_cfd = dl_cfd.copy()
    idxs = np.where(is_track > 0)
    rows, cols = idxs[0], idxs[1]
    if rows.size == 0:
        return fusion_cfd

    # Compute in float64; the assignment below casts back to dl_cfd's dtype.
    dl_score = dl_cfd[rows, cols].astype(np.float64)
    track_score = track_cfd[rows, cols].astype(np.float64)
    weight = dl_weights[rows, cols].astype(np.float64)

    confident = (dl_score > 0.9) | (dl_score < 0.1)
    confident_mix = np.where(weight < 0.1,
                             0.3 * dl_score + 0.7 * track_score,
                             0.4 * dl_score + 0.6 * track_score)
    weighted_mix = weight * dl_score + (1 - weight) * track_score
    fusion_cfd[rows, cols] = np.where(confident, confident_mix, weighted_mix)
    return fusion_cfd
94 95


W
wuyefeilin 已提交
96
def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
    """Refine the current segmentation score map using optical-flow tracking.

    Args:
        cur_gray: current-frame grayscale image.
        scoremap: current-frame segmentation score map, 2-D ``(h, w)``.
        prev_gray: previous-frame grayscale image.
        pre_cfd: previous frame's fused result (what this function returned
            for the previous frame).
        disflow: optical-flow object, e.g. from ``cv2.DISOpticalFlow_create``.
            On the first frame its finest pyramid scale is configured from
            the frame size. Later frames reuse it, so pass the same object
            for every frame. If ``None``, a fresh ULTRAFAST DIS instance is
            created for this call.
        is_init: True for the first frame of a sequence. The caller is
            responsible for clearing it on subsequent frames.

    Returns:
        fusion_cfd: score map fused with the tracked previous result, then
        smoothed with a 3x3 Gaussian blur.
    """
    h, w = scoremap.shape
    cur_cfd = scoremap.copy()
    # Previously the caller's object was always replaced here, which made the
    # setFinestScale() configuration below a no-op for subsequent frames.
    if disflow is None:
        disflow = cv2.DISOpticalFlow_create(
            cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)

    if is_init:
        # Finest pyramid level at which flow is computed: larger frames
        # stop at a coarser level.
        if h <= 64 or w <= 64:
            disflow.setFinestScale(1)
        elif h <= 160 or w <= 160:
            disflow.setFinestScale(2)
        else:
            disflow.setFinestScale(3)
        fusion_cfd = cur_cfd
    else:
        # Weights are indexed [row, col], so the shape must be (h, w).
        weights = np.ones((h, w), np.float32) * 0.3
        track_cfd, is_track, weights = human_seg_tracking(
            prev_gray, cur_gray, pre_cfd, weights, disflow)
        fusion_cfd = human_seg_track_fuse(track_cfd, cur_cfd, weights, is_track)

    fusion_cfd = cv2.GaussianBlur(fusion_cfd, (3, 3), 0)

    return fusion_cfd


W
wuyefeilin 已提交
133 134 135 136 137
def threshold_mask(img, thresh_bg, thresh_fg):
    """Map pixel intensities linearly onto a soft [0, 1] mask.

    ``img`` is scaled by 1/255. Values at or below ``thresh_bg`` map to 0,
    values at or above ``thresh_fg`` map to 1, and values in between are
    interpolated linearly.

    NOTE(review): equal thresholds divide by zero (inf/nan results); callers
    presumably pass ``thresh_bg < thresh_fg``.

    Args:
        img: image array (assumed 8-bit range 0..255).
        thresh_bg: lower threshold in normalized [0, 1] units.
        thresh_fg: upper threshold in normalized [0, 1] units.

    Returns:
        float32 array of the same shape with values clipped to [0, 1].
    """
    span = thresh_fg - thresh_bg
    ramp = (img / 255.0 - thresh_bg) / span
    return np.clip(ramp, 0, 1).astype(np.float32)