# Copyright (c) 2019-present, Baidu, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ############################################################################## from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None): if not isinstance(sigmas, np.ndarray): sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0 vars = (sigmas * 2) ** 2 xg = g[0::3] yg = g[1::3] vg = g[2::3] ious = np.zeros((d.shape[0])) for n_d in range(0, d.shape[0]): xd = d[n_d, 0::3] yd = d[n_d, 1::3] vd = d[n_d, 2::3] dx = xd - xg dy = yd - yg e = (dx ** 2 + dy ** 2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2 if in_vis_thre is not None: ind = list(vg > in_vis_thre) and list(vd > in_vis_thre) e = e[ind] ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0 return ious def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None): """ greedily select boxes with high confidence and overlap with current maximum <= thresh rule out overlap >= thresh, overlap = oks :param kpts_db :param thresh: retain overlap < thresh :return: indexes to keep """ if len(kpts_db) == 0: return [] scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))]) kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))]) areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))]) order = scores.argsort()[::-1] keep = [] while order.size > 0: i = order[0] keep.append(i) oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre) inds = np.where(oks_ovr <= thresh)[0] order = order[inds + 1] return keep