From e76236514896a5a8a90ab90a9bb4d46b87d2725b Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Mon, 26 Apr 2021 16:26:54 +0800 Subject: [PATCH] pose bottomup higherhrnet: model (#2638) --- ppdet/modeling/architectures/__init__.py | 2 + .../architectures/keypoint_hrhrnet.py | 274 ++++++++++++++++++ ppdet/modeling/backbones/hrnet.py | 1 + ppdet/modeling/heads/__init__.py | 2 + ppdet/modeling/heads/keypoint_hrhrnet_head.py | 109 +++++++ ppdet/modeling/keypoint_utils.py | 160 ++++++++++ ppdet/modeling/layers.py | 98 +++++++ ppdet/modeling/losses/__init__.py | 2 + ppdet/modeling/losses/keypoint_loss.py | 185 ++++++++++++ requirements.txt | 1 + 10 files changed, 834 insertions(+) create mode 100644 ppdet/modeling/architectures/keypoint_hrhrnet.py create mode 100644 ppdet/modeling/heads/keypoint_hrhrnet_head.py create mode 100644 ppdet/modeling/keypoint_utils.py create mode 100644 ppdet/modeling/losses/keypoint_loss.py diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index ae881607c..33dae8593 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -15,6 +15,7 @@ from . import fcos from . import solov2 from . import ttfnet from . import s2anet +from . import keypoint_hrhrnet from .meta_arch import * from .faster_rcnn import * @@ -26,3 +27,4 @@ from .fcos import * from .solov2 import * from .ttfnet import * from .s2anet import * +from .keypoint_hrhrnet import * diff --git a/ppdet/modeling/architectures/keypoint_hrhrnet.py b/ppdet/modeling/architectures/keypoint_hrhrnet.py new file mode 100644 index 000000000..79abbe489 --- /dev/null +++ b/ppdet/modeling/architectures/keypoint_hrhrnet.py @@ -0,0 +1,274 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from scipy.optimize import linear_sum_assignment +from collections import abc, defaultdict +import numpy as np +import paddle + +from ppdet.core.workspace import register, create, serializable +from .meta_arch import BaseArch +from .. import layers as L +from ..keypoint_utils import transpred + +__all__ = ['HigherHrnet'] + + +@register +class HigherHrnet(BaseArch): + __category__ = 'architecture' + + def __init__(self, + backbone='Hrnet', + hrhrnet_head='HigherHrnetHead', + post_process='HrHrnetPostProcess', + eval_flip=True, + flip_perm=None): + """ + HigherHrnet network, see https://arxiv.org/abs/ + + Args: + backbone (nn.Layer): backbone instance + hrhrnet_head (nn.Layer): keypoint_head instance + bbox_post_process (object): `BBoxPostProcess` instance + """ + super(HigherHrnet, self).__init__() + self.backbone = backbone + self.hrhrnet_head = hrhrnet_head + self.post_process = HrHrnetPostProcess() + self.flip = eval_flip + self.flip_perm = paddle.to_tensor(flip_perm) + self.deploy = False + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + # backbone + backbone = create(cfg['backbone']) + # head + kwargs = {'input_shape': backbone.out_shape} + hrhrnet_head = create(cfg['hrhrnet_head'], **kwargs) + post_process = create(cfg['post_process']) + + return { + 'backbone': backbone, + "hrhrnet_head": hrhrnet_head, + "post_process": post_process, + } + + def _forward(self): + batchsize = self.inputs['image'].shape[0] + if self.flip and not self.training and not self.deploy: + self.inputs['image'] = paddle.concat( + (self.inputs['image'], paddle.flip(self.inputs['image'], [3]))) + body_feats = self.backbone(self.inputs) + + if self.training: + return self.hrhrnet_head(body_feats, self.inputs) + else: + outputs = self.hrhrnet_head(body_feats) + if self.deploy: + return outputs, [1] + if self.flip: + outputs = [paddle.split(o, 2) for o in outputs] + output_rflip = [ + paddle.flip(paddle.gather(o[1], self.flip_perm, 1), [3]) + for o in outputs + ] + output1 = [o[0] for o in outputs] + heatmap = (output1[0] + output_rflip[0]) / 2. + tagmaps = [output1[1], output_rflip[1]] + outputs = [heatmap] + tagmaps + + res_lst = [] + bboxnums = [] + for idx in range(batchsize): + item = [o[idx:(idx + 1)] for o in outputs] + + h = self.inputs['im_shape'][idx, 0].numpy().item() + w = self.inputs['im_shape'][idx, 1].numpy().item() + kpts, scores = self.post_process(item, h, w) + res_lst.append([kpts, scores]) + bboxnums.append(1) + + return res_lst, bboxnums + + def get_loss(self): + return self._forward() + + def get_pred(self): + outputs = {} + res_lst, bboxnums = self._forward() + outputs['keypoint'] = res_lst + outputs['bbox_num'] = bboxnums + return outputs + + +@register +@serializable +class HrHrnetPostProcess(object): + def __init__(self, max_num_people=30, heat_thresh=0.2, tag_thresh=1.): + self.interpolate = L.Upsample(2, mode='bilinear') + self.pool = L.MaxPool(5, 1, 2) + self.max_num_people = max_num_people + self.heat_thresh = heat_thresh + self.tag_thresh = tag_thresh + + def lerp(self, j, y, x, heatmap): + H, W = heatmap.shape[-2:] + left = np.clip(x - 1, 0, W - 1) + right = np.clip(x + 1, 0, W - 1) + up = np.clip(y - 1, 0, H - 1) + down = np.clip(y + 1, 0, H - 1) + offset_y = np.where(heatmap[j, down, x] > heatmap[j, up, x], 0.25, + -0.25) + offset_x = np.where(heatmap[j, y, right] > heatmap[j, y, left], 0.25, + -0.25) + return offset_y + 0.5, offset_x + 0.5 + + def __call__(self, inputs, original_height, original_width): + + # resize to image size + inputs = [self.interpolate(x) for x in inputs] + # aggregate + heatmap = inputs[0] + if len(inputs) == 3: + tagmap = paddle.concat( + (inputs[1].unsqueeze(4), inputs[2].unsqueeze(4)), axis=4) + else: + tagmap = inputs[1].unsqueeze(4) + + N, J, H, W = heatmap.shape + assert N == 1, "only support batch size 1" + # topk + maximum = self.pool(heatmap) + maxmap = heatmap * (heatmap == maximum) + maxmap = maxmap.reshape([N, J, -1]) + heat_k, inds_k = maxmap.topk(self.max_num_people, axis=2) + heatmap = heatmap[0].cpu().detach().numpy() + tagmap = tagmap[0].cpu().detach().numpy() + heats = heat_k[0].cpu().detach().numpy() + inds_np = inds_k[0].cpu().detach().numpy() + y = inds_np // W + x = inds_np % W + tags = tagmap[np.arange(J)[None, :].repeat(self.max_num_people), + y.flatten(), x.flatten()].reshape(J, -1, tagmap.shape[-1]) + coords = np.stack((y, x), axis=2) + # threshold + mask = heats > self.heat_thresh + # cluster + cluster = defaultdict(lambda: { + 'coords': np.zeros((J, 2), dtype=np.float32), + 'scores': np.zeros(J, dtype=np.float32), + 'tags': [] + }) + for jid, m in enumerate(mask): + num_valid = m.sum() + if num_valid == 0: + continue + valid_inds = np.where(m)[0] + valid_tags = tags[jid, m, :] + if len(cluster) == 0: # initialize + for i in valid_inds: + tag = tags[jid, i] + key = tag[0] + cluster[key]['tags'].append(tag) + cluster[key]['scores'][jid] = heats[jid, i] + cluster[key]['coords'][jid] = coords[jid, i] + continue + candidates = list(cluster.keys())[:self.max_num_people] + centroids = [ + np.mean( + cluster[k]['tags'], axis=0) for k in candidates + ] + num_clusters = len(centroids) + # shape is (num_valid, num_clusters, tag_dim) + dist = valid_tags[:, None, :] - np.array(centroids)[None, ...] + l2_dist = np.linalg.norm(dist, ord=2, axis=2) + # modulate dist with heat value, see `use_detection_val` + cost = np.round(l2_dist) * 100 - heats[jid, m, None] + # pad the cost matrix, otherwise new pose are ignored + if num_valid > num_clusters: + cost = np.pad(cost, ((0, 0), (0, num_valid - num_clusters)), + constant_values=((0, 0), (0, 1e-10))) + rows, cols = linear_sum_assignment(cost) + for y, x in zip(rows, cols): + tag = tags[jid, y] + if y < num_valid and x < num_clusters and \ + l2_dist[y, x] < self.tag_thresh: + key = candidates[x] # merge to cluster + else: + key = tag[0] # initialize new cluster + cluster[key]['tags'].append(tag) + cluster[key]['scores'][jid] = heats[jid, y] + cluster[key]['coords'][jid] = coords[jid, y] + + # shape is [k, J, 2] and [k, J] + pose_tags = np.array([cluster[k]['tags'] for k in cluster]) + pose_coords = np.array([cluster[k]['coords'] for k in cluster]) + pose_scores = np.array([cluster[k]['scores'] for k in cluster]) + valid = pose_scores > 0 + + pose_kpts = np.zeros((pose_scores.shape[0], J, 3), dtype=np.float32) + if valid.sum() == 0: + return pose_kpts, pose_kpts + + # refine coords + valid_coords = pose_coords[valid].astype(np.int32) + y = valid_coords[..., 0].flatten() + x = valid_coords[..., 1].flatten() + _, j = np.nonzero(valid) + offsets = self.lerp(j, y, x, heatmap) + pose_coords[valid, 0] += offsets[0] + pose_coords[valid, 1] += offsets[1] + + # mean score before salvage + mean_score = pose_scores.mean(axis=1) + pose_kpts[valid, 2] = pose_scores[valid] + + # TODO can we remove the outermost loop altogether + # salvage missing joints + + if True: + for pid, coords in enumerate(pose_coords): + # vj = np.nonzero(valid[pid])[0] + # vyx = coords[valid[pid]].astype(np.int32) + # tag_mean = tagmap[vj, vyx[:, 0], vyx[:, 1]].mean(axis=0) + + tag_mean = np.array(pose_tags[pid]).mean( + axis=0) #TODO: replace tagmap sample by history record + + norm = np.sum((tagmap - tag_mean)**2, axis=3)**0.5 + score = heatmap - np.round(norm) # (J, H, W) + flat_score = score.reshape(J, -1) + max_inds = np.argmax(flat_score, axis=1) + max_scores = np.max(flat_score, axis=1) + salvage_joints = (pose_scores[pid] == 0) & (max_scores > 0) + if salvage_joints.sum() == 0: + continue + y = max_inds[salvage_joints] // W + x = max_inds[salvage_joints] % W + offsets = self.lerp(salvage_joints.nonzero()[0], y, x, heatmap) + y = y.astype(np.float32) + offsets[0] + x = x.astype(np.float32) + offsets[1] + pose_coords[pid][salvage_joints, 0] = y + pose_coords[pid][salvage_joints, 1] = x + pose_kpts[pid][salvage_joints, 2] = max_scores[salvage_joints] + pose_kpts[..., :2] = transpred(pose_coords[..., :2][..., ::-1], + original_height, original_width, + min(H, W)) + return pose_kpts, mean_score diff --git a/ppdet/modeling/backbones/hrnet.py b/ppdet/modeling/backbones/hrnet.py index f93f5fd9c..0955f8b22 100644 --- a/ppdet/modeling/backbones/hrnet.py +++ b/ppdet/modeling/backbones/hrnet.py @@ -688,6 +688,7 @@ class HRNet(nn.Layer): has_se=self.has_se, norm_decay=norm_decay, freeze_norm=freeze_norm, + multi_scale_output=len(return_idx) > 1, name="st4") def forward(self, inputs): diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index 9263aa812..a6dfcabda 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -23,6 +23,7 @@ from . import ttf_head from . import cascade_head from . import face_head from . import s2anet_head +from . import keypoint_hrhrnet_head from .bbox_head import * from .mask_head import * @@ -35,3 +36,4 @@ from .ttf_head import * from .cascade_head import * from .face_head import * from .s2anet_head import * +from .keypoint_hrhrnet_head import * diff --git a/ppdet/modeling/heads/keypoint_hrhrnet_head.py b/ppdet/modeling/heads/keypoint_hrhrnet_head.py new file mode 100644 index 000000000..08187f852 --- /dev/null +++ b/ppdet/modeling/heads/keypoint_hrhrnet_head.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register +from .. import layers as L +from ..backbones.hrnet import BasicBlock + + +@register +class HrHrnetHead(nn.Layer): + __inject__ = ['loss'] + + def __init__(self, num_joints, loss='HrHrnetLoss', swahr=False, width=32): + """ + Head for HigherHrnet network + + Args: + num_joints (int): number of keypoints + hrloss (object): HrHrnetLoss instance + swahr (bool): whether to use swahr + width (int): hrnet channel width + """ + super(HrHrnetHead, self).__init__() + self.loss = loss + + self.num_joints = num_joints + num_featout1 = num_joints * 2 + num_featout2 = num_joints + self.swahr = swahr + self.conv1 = L.Conv2d(width, num_featout1, 1, 1, 0, bias=True) + self.conv2 = L.Conv2d(width, num_featout2, 1, 1, 0, bias=True) + self.deconv = nn.Sequential( + L.ConvTranspose2d( + num_featout1 + width, width, 4, 2, 1, 0, bias=False), + L.BatchNorm2d(width), + L.ReLU()) + self.blocks = nn.Sequential(*(BasicBlock( + num_channels=width, + num_filters=width, + has_se=False, + freeze_norm=False, + name='HrHrnetHead_{}'.format(i)) for i in range(4))) + + self.interpolate = L.Upsample(2, mode='bilinear') + self.concat = L.Concat(dim=1) + if swahr: + self.scalelayer0 = nn.Sequential( + L.Conv2d( + width, num_joints, 1, 1, 0, bias=True), + L.BatchNorm2d(num_joints), + L.ReLU(), + L.Conv2d( + num_joints, + num_joints, + 9, + 1, + 4, + groups=num_joints, + bias=True)) + self.scalelayer1 = nn.Sequential( + L.Conv2d( + width, num_joints, 1, 1, 0, bias=True), + L.BatchNorm2d(num_joints), + L.ReLU(), + L.Conv2d( + num_joints, + num_joints, + 9, + 1, + 4, + groups=num_joints, + bias=True)) + + def forward(self, feats, targets=None): + x1 = feats[0] + xo1 = self.conv1(x1) + x2 = self.blocks(self.deconv(self.concat((x1, xo1)))) + xo2 = self.conv2(x2) + num_joints = self.num_joints + if self.training: + if self.swahr: + so1 = self.scalelayer0(x1) + so2 = self.scalelayer1(x2) + hrhrnet_outputs = ([xo1[:, :num_joints], so1], [xo2, so2], + xo1[:, num_joints:]) + return self.loss(hrhrnet_outputs, targets) + else: + hrhrnet_outputs = (xo1[:, :num_joints], xo2, + xo1[:, num_joints:]) + return self.loss(hrhrnet_outputs, targets) + + # averaged heatmap, upsampled tagmap + upsampled = self.interpolate(xo1) + avg = (upsampled[:, :num_joints] + xo2[:, :num_joints]) / 2 + return avg, upsampled[:, num_joints:] diff --git a/ppdet/modeling/keypoint_utils.py b/ppdet/modeling/keypoint_utils.py new file mode 100644 index 000000000..b65ea9cba --- /dev/null +++ b/ppdet/modeling/keypoint_utils.py @@ -0,0 +1,160 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np + + +def get_affine_mat_kernel(h, w, s, inv=False): + if w < h: + w_ = s + h_ = int(np.ceil((s / w * h) / 64.) * 64) + scale_w = w + scale_h = h_ / w_ * w + + else: + h_ = s + w_ = int(np.ceil((s / h * w) / 64.) * 64) + scale_h = h + scale_w = w_ / h_ * h + + center = np.array([np.round(w / 2.), np.round(h / 2.)]) + + size_resized = (w_, h_) + trans = get_affine_transform( + center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv) + + return trans, size_resized + + +def get_affine_transform(center, + input_size, + rot, + output_size, + shift=(0., 0.), + inv=False): + """Get the affine transform matrix, given the center/scale/rot/output_size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ]): Size of the destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: The transform matrix. + """ + assert len(center) == 2 + assert len(input_size) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + + scale_tmp = input_size + + shift = np.array(shift) + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = rotate_point([0., src_w * -0.5], rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def _get_3rd_point(a, b): + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): point(x,y) + b (np.ndarray): point(x,y) + + Returns: + np.ndarray: The 3rd point. + """ + assert len(a) == 2 + assert len(b) == 2 + direction = a - b + third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32) + + return third_pt + + +def rotate_point(pt, angle_rad): + """Rotate a point by an angle. + + Args: + pt (list[float]): 2 dimensional point to be rotated + angle_rad (float): rotation angle by radian + + Returns: + list[float]: Rotated point. + """ + assert len(pt) == 2 + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + new_x = pt[0] * cs - pt[1] * sn + new_y = pt[0] * sn + pt[1] * cs + rotated_pt = [new_x, new_y] + + return rotated_pt + + +def transpred(kpts, h, w, s): + trans, _ = get_affine_mat_kernel(h, w, s, inv=True) + + return warp_affine_joints(kpts[..., :2].copy(), trans) + + +def warp_affine_joints(joints, mat): + """Apply affine transformation defined by the transform matrix on the + joints. + + Args: + joints (np.ndarray[..., 2]): Origin coordinate of joints. + mat (np.ndarray[3, 2]): The affine matrix. + + Returns: + matrix (np.ndarray[..., 2]): Result coordinate of joints. + """ + joints = np.array(joints) + shape = joints.shape + joints = joints.reshape(-1, 2) + return np.dot(np.concatenate( + (joints, joints[:, 0:1] * 0 + 1), axis=1), + mat.T).reshape(shape) diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index c7b63eba4..57fb0a99f 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -950,3 +950,101 @@ class MaskMatrixNMS(object): cate_scores = paddle.gather(cate_scores, index=sort_inds) cate_labels = paddle.gather(cate_labels, index=sort_inds) return seg_preds, cate_scores, cate_labels + + +def Conv2d(in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + weight_init=Normal(std=0.001), + bias_init=Constant(0.)): + weight_attr = paddle.framework.ParamAttr(initializer=weight_init) + if bias: + bias_attr = paddle.framework.ParamAttr(initializer=bias_init) + else: + bias_attr = False + conv = nn.Conv2D( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + weight_attr=weight_attr, + bias_attr=bias_attr) + return conv + + +def ConvTranspose2d(in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + dilation=1, + weight_init=Normal(std=0.001), + bias_init=Constant(0.)): + weight_attr = paddle.framework.ParamAttr(initializer=weight_init) + if bias: + bias_attr = paddle.framework.ParamAttr(initializer=bias_init) + else: + bias_attr = False + conv = nn.Conv2DTranspose( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + weight_attr=weight_attr, + bias_attr=bias_attr) + return conv + + +def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True): + if not affine: + weight_attr = False + bias_attr = False + else: + weight_attr = None + bias_attr = None + batchnorm = nn.BatchNorm2D( + num_features, + momentum, + eps, + weight_attr=weight_attr, + bias_attr=bias_attr) + return batchnorm + + +def ReLU(): + return nn.ReLU() + + +def Upsample(scale_factor=None, mode='nearest', align_corners=False): + return nn.Upsample(None, scale_factor, mode, align_corners) + + +def MaxPool(kernel_size, stride, padding, ceil_mode=False): + return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode) + + +class Concat(nn.Layer): + def __init__(self, dim=0): + super(Concat, self).__init__() + self.dim = dim + + def forward(self, inputs): + return paddle.concat(inputs, axis=self.dim) + + def extra_repr(self): + return 'dim={}'.format(self.dim) diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 7a3816811..f4c914516 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -19,6 +19,7 @@ from . import ssd_loss from . import fcos_loss from . import solov2_loss from . import ctfocal_loss +from . import keypoint_loss from .yolo_loss import * from .iou_aware_loss import * @@ -27,3 +28,4 @@ from .ssd_loss import * from .fcos_loss import * from .solov2_loss import * from .ctfocal_loss import * +from .keypoint_loss import * diff --git a/ppdet/modeling/losses/keypoint_loss.py b/ppdet/modeling/losses/keypoint_loss.py new file mode 100644 index 000000000..21b5556d3 --- /dev/null +++ b/ppdet/modeling/losses/keypoint_loss.py @@ -0,0 +1,185 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from itertools import cycle, islice +from collections import abc +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register, serializable + +__all__ = ['HrHrnetLoss'] + + +@register +@serializable +class HrHrnetLoss(nn.Layer): + def __init__(self, num_joints, swahr): + """ + HrHrnetLoss layer + + Args: + num_joints (int): number of keypoints + """ + super(HrHrnetLoss, self).__init__() + if swahr: + self.heatmaploss = HeatMapSWAHRLoss(num_joints) + else: + self.heatmaploss = HeatMapLoss() + self.aeloss = AELoss() + self.ziploss = ZipLoss( + [self.heatmaploss, self.heatmaploss, self.aeloss]) + + def forward(self, inputs, records): + targets = [] + targets.append([records['heatmap_gt1x'], records['mask_1x']]) + targets.append([records['heatmap_gt2x'], records['mask_2x']]) + targets.append(records['tagmap']) + keypoint_losses = dict() + loss = self.ziploss(inputs, targets) + keypoint_losses['heatmap_loss'] = loss[0] + loss[1] + keypoint_losses['pull_loss'] = loss[2][0] + keypoint_losses['push_loss'] = loss[2][1] + keypoint_losses['loss'] = recursive_sum(loss) + return keypoint_losses + + +class HeatMapLoss(object): + def __init__(self, loss_factor=1.0): + super(HeatMapLoss, self).__init__() + self.loss_factor = loss_factor + + def __call__(self, preds, targets): + heatmap, mask = targets + loss = ((preds - heatmap)**2 * mask.cast('float').unsqueeze(1)) + loss = paddle.clip(loss, min=0, max=2).mean() + loss *= self.loss_factor + return loss + + +class HeatMapSWAHRLoss(object): + def __init__(self, num_joints, loss_factor=1.0): + super(HeatMapSWAHRLoss, self).__init__() + self.loss_factor = loss_factor + self.num_joints = num_joints + + def __call__(self, preds, targets): + heatmaps_gt, mask = targets + heatmaps_pred = preds[0] + scalemaps_pred = preds[1] + + heatmaps_scaled_gt = paddle.where(heatmaps_gt > 0, 0.5 * heatmaps_gt * ( + 1 + (1 + + (scalemaps_pred - 1.) * paddle.log(heatmaps_gt + 1e-10))**2), + heatmaps_gt) + + regularizer_loss = paddle.mean( + paddle.pow((scalemaps_pred - 1.) * (heatmaps_gt > 0).astype(float), + 2)) + omiga = 0.01 + # thres = 2**(-1/omiga), threshold for positive weight + hm_weight = heatmaps_scaled_gt**( + omiga + ) * paddle.abs(1 - heatmaps_pred) + paddle.abs(heatmaps_pred) * ( + 1 - heatmaps_scaled_gt**(omiga)) + + loss = (((heatmaps_pred - heatmaps_scaled_gt)**2) * + mask.cast('float').unsqueeze(1)) * hm_weight + loss = loss.mean() + loss = self.loss_factor * (loss + 1.0 * regularizer_loss) + return loss + + +class AELoss(object): + def __init__(self, pull_factor=0.001, push_factor=0.001): + super(AELoss, self).__init__() + self.pull_factor = pull_factor + self.push_factor = push_factor + + def apply_single(self, pred, tagmap): + if tagmap.numpy()[:, :, 3].sum() == 0: + return (paddle.zeros([1]), paddle.zeros([1])) + nonzero = paddle.nonzero(tagmap[:, :, 3] > 0) + if nonzero.shape[0] == 0: + return (paddle.zeros([1]), paddle.zeros([1])) + p_inds = paddle.unique(nonzero[:, 0]) + num_person = p_inds.shape[0] + if num_person == 0: + return (paddle.zeros([1]), paddle.zeros([1])) + + pull = 0 + tagpull_num = 0 + embs_all = [] + person_unvalid = 0 + for person_idx in p_inds.numpy(): + valid_single = tagmap[person_idx.item()] + validkpts = paddle.nonzero(valid_single[:, 3] > 0) + valid_single = paddle.index_select(valid_single, validkpts) + emb = paddle.gather_nd(pred, valid_single[:, :3]) + if emb.shape[0] == 1: + person_unvalid += 1 + mean = paddle.mean(emb, axis=0) + embs_all.append(mean) + pull += paddle.mean(paddle.pow(emb - mean, 2), axis=0) + tagpull_num += emb.shape[0] + pull /= max(num_person - person_unvalid, 1) + if num_person < 2: + return pull, paddle.zeros([1]) + + embs_all = paddle.stack(embs_all) + A = embs_all.expand([num_person, num_person]) + B = A.transpose([1, 0]) + diff = A - B + + diff = paddle.pow(diff, 2) + push = paddle.exp(-diff) + push = paddle.sum(push) - num_person + + push /= 2 * num_person * (num_person - 1) + return pull, push + + def __call__(self, preds, tagmaps): + bs = preds.shape[0] + losses = [self.apply_single(preds[i], tagmaps[i]) for i in range(bs)] + pull = self.pull_factor * sum(loss[0] for loss in losses) / len(losses) + push = self.push_factor * sum(loss[1] for loss in losses) / len(losses) + return pull, push + + +class ZipLoss(object): + def __init__(self, loss_funcs): + super(ZipLoss, self).__init__() + self.loss_funcs = loss_funcs + + def __call__(self, inputs, targets): + assert len(self.loss_funcs) == len(targets) >= len(inputs) + + def zip_repeat(*args): + longest = max(map(len, args)) + filled = [islice(cycle(x), longest) for x in args] + return zip(*filled) + + return tuple( + fn(x, y) + for x, y, fn in zip_repeat(inputs, targets, self.loss_funcs)) + + +def recursive_sum(inputs): + if isinstance(inputs, abc.Sequence): + return sum([recursive_sum(x) for x in inputs]) + return inputs diff --git a/requirements.txt b/requirements.txt index 8ce34b5f0..ac2135f27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ shapely scipy terminaltables pycocotools +xtcocotools==1.6 setuptools>=42.0.0 -- GitLab