未验证 提交 8ebccc9f 编写于 作者: Z zhiboniu 提交者: GitHub

update keypoint code citation (#4456)

上级 8a685c4b
......@@ -23,7 +23,7 @@ from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform
class HrHRNetPostProcess(object):
'''
"""
HrHRNet postprocess contains:
1) get topk keypoints in the output heatmap
2) sample the tagmap's value corresponding to each of the topk coordinate
......@@ -37,7 +37,7 @@ class HrHRNetPostProcess(object):
inputs(list[heatmap]): the output list of model, [heatmap, heatmap_maxpool, tagmap], heatmap_maxpool used to get topk
original_height, original_width (float): the original image size
'''
"""
def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.):
self.max_num_people = max_num_people
......@@ -212,7 +212,7 @@ class HRNetPostProcess(object):
return output_flipped
def get_max_preds(self, heatmaps):
'''get predictions from score maps
"""get predictions from score maps
Args:
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
......@@ -220,7 +220,7 @@ class HRNetPostProcess(object):
Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
'''
"""
assert isinstance(heatmaps,
np.ndarray), 'heatmaps should be numpy.ndarray'
assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
......@@ -286,6 +286,10 @@ class HRNetPostProcess(object):
return coord
def dark_postprocess(self, hm, coords, kernelsize):
"""
refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
"""
hm = self.gaussian_blur(hm, kernelsize)
hm = np.maximum(hm, 1e-10)
hm = np.log(hm)
......
......@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py
"""
import cv2
import numpy as np
......
......@@ -14,6 +14,7 @@
import cv2
import numpy as np
from keypoint_preprocess import get_affine_transform
def decode_image(im_file, im_info):
......@@ -263,90 +264,6 @@ class WarpAffine(object):
self.scale = scale
self.shift = shift
def _get_3rd_point(self, a, b):
assert len(
a) == 2, 'input of _get_3rd_point should be point with length of 2'
assert len(
b) == 2, 'input of _get_3rd_point should be point with length of 2'
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def rotate_point(self, pt, angle_rad):
    """Apply a 2D rotation to a point.

    Args:
        pt (list[float]): 2 dimensional point to be rotated
        angle_rad (float): rotation angle in radians
    Returns:
        list[float]: the rotated point ``[x', y']``.
    """
    assert len(pt) == 2
    x, y = pt[0], pt[1]
    cos_a = np.cos(angle_rad)
    sin_a = np.sin(angle_rad)
    # Standard rotation matrix [[cos, -sin], [sin, cos]] applied to (x, y).
    return [x * cos_a - y * sin_a, x * sin_a + y * cos_a]
def get_affine_transform(self,
                         center,
                         input_size,
                         rot,
                         output_size,
                         shift=(0., 0.),
                         inv=False):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Three point pairs are built (the center, a point offset from the center
    along the rotated "up" direction, and a third point perpendicular to
    those two) and passed to ``cv2.getAffineTransform``.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ] | list | tuple | scalar): Size of input
            feature (width, height). A scalar is expanded to a square
            ``[size, size]``.
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2

    # Only a true scalar is expanded to a square size. Tuples are sequences
    # just like lists; treating them as scalars would build a malformed
    # (2, 2) size array.
    if not isinstance(input_size, (np.ndarray, list, tuple)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    # Degrees -> radians.
    rot_rad = np.pi * rot / 180
    # Direction vectors: half the width "upwards", rotated on the source side.
    src_dir = self.rotate_point([0., src_w * -0.5], rot_rad)
    dst_dir = np.array([0., dst_w * -0.5])

    # Source triangle: shifted center, center + direction, perpendicular point.
    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = self._get_3rd_point(src[0, :], src[1, :])

    # Destination triangle, anchored at the middle of the output.
    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = self._get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
def __call__(self, im, im_info):
"""
Args:
......@@ -371,7 +288,7 @@ class WarpAffine(object):
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
trans_input = self.get_affine_transform(c, s, 0, [input_w, input_h])
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
......
......@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/open-mmlab/mmpose
"""
import os
import cv2
import numpy as np
......@@ -25,8 +27,7 @@ from ppdet.core.workspace import register, serializable
@serializable
class KeypointBottomUpBaseDataset(DetDataset):
"""Base class for bottom-up datasets. Adapted from
https://github.com/open-mmlab/mmpose
"""Base class for bottom-up datasets.
All datasets should subclass it.
All subclasses should overwrite:
......@@ -90,8 +91,7 @@ class KeypointBottomUpBaseDataset(DetDataset):
@register
@serializable
class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
"""COCO dataset for bottom-up pose estimation. Adapted from
https://github.com/open-mmlab/mmpose
"""COCO dataset for bottom-up pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
......@@ -262,8 +262,7 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
@register
@serializable
class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
"""CrowdPose dataset for bottom-up pose estimation. Adapted from
https://github.com/open-mmlab/mmpose
"""CrowdPose dataset for bottom-up pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
......@@ -387,9 +386,7 @@ class KeypointTopDownBaseDataset(DetDataset):
@register
@serializable
class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
"""COCO dataset for top-down pose estimation. Adapted from
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
"""COCO dataset for top-down pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
......@@ -582,9 +579,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
@register
@serializable
class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
"""MPII dataset for topdown pose estimation. Adapted from
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
"""MPII dataset for topdown pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
......
......@@ -682,6 +682,10 @@ class ToHeatmapsTopDown(object):
self.sigma = sigma
def __call__(self, records):
"""refer to
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
"""
joints = records['joints']
joints_vis = records['joints_vis']
num_joints = joints.shape[0]
......
......@@ -27,11 +27,10 @@ __all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']
class KeyPointTopDownCOCOEval(object):
'''
Adapted from
"""refer to
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
'''
"""
def __init__(self,
anno_file,
......@@ -286,7 +285,7 @@ class KeyPointTopDownMPIIEval(object):
return self.eval_results
def evaluate(self, outputs, savepath=None):
"""Evaluate PCKh for MPII dataset. Adapted from
"""Evaluate PCKh for MPII dataset. refer to
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
......
......@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/open-mmlab/mmpose
"""
import cv2
import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册