未验证 提交 8ebccc9f 编写于 作者: Z zhiboniu 提交者: GitHub

update keypoint code citation (#4456)

上级 8a685c4b
...@@ -23,7 +23,7 @@ from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform ...@@ -23,7 +23,7 @@ from keypoint_preprocess import get_affine_mat_kernel, get_affine_transform
class HrHRNetPostProcess(object): class HrHRNetPostProcess(object):
''' """
HrHRNet postprocess contain: HrHRNet postprocess contain:
1) get topk keypoints in the output heatmap 1) get topk keypoints in the output heatmap
2) sample the tagmap's value corresponding to each of the topk coordinate 2) sample the tagmap's value corresponding to each of the topk coordinate
...@@ -37,7 +37,7 @@ class HrHRNetPostProcess(object): ...@@ -37,7 +37,7 @@ class HrHRNetPostProcess(object):
inputs(list[heatmap]): the output list of modle, [heatmap, heatmap_maxpool, tagmap], heatmap_maxpool used to get topk inputs(list[heatmap]): the output list of modle, [heatmap, heatmap_maxpool, tagmap], heatmap_maxpool used to get topk
original_height, original_width (float): the original image size original_height, original_width (float): the original image size
''' """
def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.): def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.):
self.max_num_people = max_num_people self.max_num_people = max_num_people
...@@ -212,7 +212,7 @@ class HRNetPostProcess(object): ...@@ -212,7 +212,7 @@ class HRNetPostProcess(object):
return output_flipped return output_flipped
def get_max_preds(self, heatmaps): def get_max_preds(self, heatmaps):
'''get predictions from score maps """get predictions from score maps
Args: Args:
heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
...@@ -220,7 +220,7 @@ class HRNetPostProcess(object): ...@@ -220,7 +220,7 @@ class HRNetPostProcess(object):
Returns: Returns:
preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints
''' """
assert isinstance(heatmaps, assert isinstance(heatmaps,
np.ndarray), 'heatmaps should be numpy.ndarray' np.ndarray), 'heatmaps should be numpy.ndarray'
assert heatmaps.ndim == 4, 'batch_images should be 4-ndim' assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'
...@@ -286,6 +286,10 @@ class HRNetPostProcess(object): ...@@ -286,6 +286,10 @@ class HRNetPostProcess(object):
return coord return coord
def dark_postprocess(self, hm, coords, kernelsize): def dark_postprocess(self, hm, coords, kernelsize):
"""
refer to https://github.com/ilovepose/DarkPose/lib/core/inference.py
"""
hm = self.gaussian_blur(hm, kernelsize) hm = self.gaussian_blur(hm, kernelsize)
hm = np.maximum(hm, 1e-10) hm = np.maximum(hm, 1e-10)
hm = np.log(hm) hm = np.log(hm)
......
...@@ -11,7 +11,9 @@ ...@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
this code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py
"""
import cv2 import cv2
import numpy as np import numpy as np
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import cv2 import cv2
import numpy as np import numpy as np
from keypoint_preprocess import get_affine_transform
def decode_image(im_file, im_info): def decode_image(im_file, im_info):
...@@ -263,90 +264,6 @@ class WarpAffine(object): ...@@ -263,90 +264,6 @@ class WarpAffine(object):
self.scale = scale self.scale = scale
self.shift = shift self.shift = shift
def _get_3rd_point(self, a, b):
assert len(
a) == 2, 'input of _get_3rd_point should be point with length of 2'
assert len(
b) == 2, 'input of _get_3rd_point should be point with length of 2'
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def rotate_point(self, pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def get_affine_transform(self,
center,
input_size,
rot,
output_size,
shift=(0., 0.),
inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
input_size (np.ndarray[2, ]): Size of input feature (width, height).
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = self.rotate_point([0., src_w * -0.5], rot_rad)
dst_dir = np.array([0., dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = self._get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = self._get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def __call__(self, im, im_info): def __call__(self, im, im_info):
""" """
Args: Args:
...@@ -371,7 +288,7 @@ class WarpAffine(object): ...@@ -371,7 +288,7 @@ class WarpAffine(object):
input_h, input_w = self.input_h, self.input_w input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32) c = np.array([w / 2., h / 2.], dtype=np.float32)
trans_input = self.get_affine_transform(c, s, 0, [input_w, input_h]) trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h)) img = cv2.resize(img, (w, h))
inp = cv2.warpAffine( inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR) img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
......
...@@ -11,7 +11,9 @@ ...@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
this code is base on https://github.com/open-mmlab/mmpose
"""
import os import os
import cv2 import cv2
import numpy as np import numpy as np
...@@ -25,8 +27,7 @@ from ppdet.core.workspace import register, serializable ...@@ -25,8 +27,7 @@ from ppdet.core.workspace import register, serializable
@serializable @serializable
class KeypointBottomUpBaseDataset(DetDataset): class KeypointBottomUpBaseDataset(DetDataset):
"""Base class for bottom-up datasets. Adapted from """Base class for bottom-up datasets.
https://github.com/open-mmlab/mmpose
All datasets should subclass it. All datasets should subclass it.
All subclasses should overwrite: All subclasses should overwrite:
...@@ -90,8 +91,7 @@ class KeypointBottomUpBaseDataset(DetDataset): ...@@ -90,8 +91,7 @@ class KeypointBottomUpBaseDataset(DetDataset):
@register @register
@serializable @serializable
class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
"""COCO dataset for bottom-up pose estimation. Adapted from """COCO dataset for bottom-up pose estimation.
https://github.com/open-mmlab/mmpose
The dataset loads raw features and apply specified transforms The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information. to return a dict containing the image tensors and other information.
...@@ -262,8 +262,7 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset): ...@@ -262,8 +262,7 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
@register @register
@serializable @serializable
class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset): class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
"""CrowdPose dataset for bottom-up pose estimation. Adapted from """CrowdPose dataset for bottom-up pose estimation.
https://github.com/open-mmlab/mmpose
The dataset loads raw features and apply specified transforms The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information. to return a dict containing the image tensors and other information.
...@@ -387,9 +386,7 @@ class KeypointTopDownBaseDataset(DetDataset): ...@@ -387,9 +386,7 @@ class KeypointTopDownBaseDataset(DetDataset):
@register @register
@serializable @serializable
class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
"""COCO dataset for top-down pose estimation. Adapted from """COCO dataset for top-down pose estimation.
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
The dataset loads raw features and apply specified transforms The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information. to return a dict containing the image tensors and other information.
...@@ -582,9 +579,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset): ...@@ -582,9 +579,7 @@ class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
@register @register
@serializable @serializable
class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset): class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
"""MPII dataset for topdown pose estimation. Adapted from """MPII dataset for topdown pose estimation.
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
The dataset loads raw features and apply specified transforms The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information. to return a dict containing the image tensors and other information.
......
...@@ -682,6 +682,10 @@ class ToHeatmapsTopDown(object): ...@@ -682,6 +682,10 @@ class ToHeatmapsTopDown(object):
self.sigma = sigma self.sigma = sigma
def __call__(self, records): def __call__(self, records):
"""refer to
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
"""
joints = records['joints'] joints = records['joints']
joints_vis = records['joints_vis'] joints_vis = records['joints_vis']
num_joints = joints.shape[0] num_joints = joints.shape[0]
......
...@@ -27,11 +27,10 @@ __all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval'] ...@@ -27,11 +27,10 @@ __all__ = ['KeyPointTopDownCOCOEval', 'KeyPointTopDownMPIIEval']
class KeyPointTopDownCOCOEval(object): class KeyPointTopDownCOCOEval(object):
''' """refer to
Adapted from
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License. Copyright (c) Microsoft, under the MIT License.
''' """
def __init__(self, def __init__(self,
anno_file, anno_file,
...@@ -286,7 +285,7 @@ class KeyPointTopDownMPIIEval(object): ...@@ -286,7 +285,7 @@ class KeyPointTopDownMPIIEval(object):
return self.eval_results return self.eval_results
def evaluate(self, outputs, savepath=None): def evaluate(self, outputs, savepath=None):
"""Evaluate PCKh for MPII dataset. Adapted from """Evaluate PCKh for MPII dataset. refer to
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License. Copyright (c) Microsoft, under the MIT License.
......
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
this code is based on https://github.com/open-mmlab/mmpose
"""
import cv2 import cv2
import numpy as np import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册